# dnsrecon/core/scanner.py
"""
Main scanning orchestrator for DNSRecon.
Coordinates data gathering from multiple providers and builds the infrastructure graph.
"""
import threading
import traceback
import ipaddress  # stdlib; used for IP validation (supports IPv4 and IPv6)
from typing import List, Set, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
from collections import defaultdict
from core.graph_manager import GraphManager, NodeType, RelationshipType
from core.logger import get_forensic_logger, new_session
from providers.crtsh_provider import CrtShProvider
from providers.dns_provider import DNSProvider
from providers.shodan_provider import ShodanProvider
from providers.virustotal_provider import VirusTotalProvider
from config import config


class ScanStatus:
    """Enumeration of scan statuses."""
    IDLE = "idle"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"


class Scanner:
    """
    Main scanning orchestrator for DNSRecon passive reconnaissance.
    Manages multi-provider data gathering and graph construction with concurrent processing.
    """

    def __init__(self):
        """Initialize scanner with all available providers and empty graph."""
        print("Initializing Scanner instance...")
        try:
            self.graph = GraphManager()
            self.providers = []
            self.status = ScanStatus.IDLE
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()  # Use a threading.Event for safer signaling
            self.scan_thread = None

            # Scanning progress tracking
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = ""

            # Concurrent processing configuration
            self.max_workers = config.max_concurrent_requests
            self.executor = None  # Keep a reference to the executor

            # Initialize providers
            print("Calling _initialize_providers...")
            self._initialize_providers()

            # Initialize logger
            print("Initializing forensic logger...")
            self.logger = get_forensic_logger()
            print("Scanner initialization complete")
        except Exception as e:
            print(f"ERROR: Scanner initialization failed: {e}")
            traceback.print_exc()
            raise

    def _initialize_providers(self) -> None:
        """Initialize all available providers based on configuration."""
        self.providers = []
        print("Initializing providers...")

        # Always add free providers
        free_providers = [
            ('crtsh', CrtShProvider),
            ('dns', DNSProvider)
        ]
        for provider_name, provider_class in free_providers:
            if config.is_provider_enabled(provider_name):
                try:
                    provider = provider_class()
                    if provider.is_available():
                        self.providers.append(provider)
                        print(f"{provider_name.title()} provider initialized successfully")
                    else:
                        print(f"{provider_name.title()} provider is not available")
                except Exception as e:
                    print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
                    traceback.print_exc()

        # Add API key-dependent providers
        api_providers = [
            ('shodan', ShodanProvider),
            ('virustotal', VirusTotalProvider)
        ]
        for provider_name, provider_class in api_providers:
            if config.is_provider_enabled(provider_name):
                try:
                    provider = provider_class()
                    if provider.is_available():
                        self.providers.append(provider)
                        print(f"{provider_name.title()} provider initialized successfully")
                    else:
                        print(f"{provider_name.title()} provider is not available (API key required)")
                except Exception as e:
                    print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
                    traceback.print_exc()

        print(f"Initialized {len(self.providers)} providers")

    def start_scan(self, target_domain: str, max_depth: int = 2) -> bool:
        """
        Start a new reconnaissance scan with concurrent processing.

        Args:
            target_domain: Initial domain to investigate
            max_depth: Maximum recursion depth

        Returns:
            bool: True if scan started successfully
        """
        print(f"Scanner.start_scan called with target='{target_domain}', depth={max_depth}")
        try:
            if self.status == ScanStatus.RUNNING:
                print("Scan already running, rejecting new scan")
                return False

            # Check if we have any providers
            if not self.providers:
                print("No providers available, cannot start scan")
                return False

            # Stop any existing scan thread
            if self.scan_thread and self.scan_thread.is_alive():
                print("Stopping existing scan thread...")
                self.stop_event.set()
                self.scan_thread.join(timeout=5.0)
                if self.scan_thread.is_alive():
                    print("WARNING: Could not stop existing thread")
                    return False

            # Reset state
            print("Resetting scanner state...")
            self.graph.clear()
            self.current_target = target_domain.lower().strip()
            self.max_depth = max_depth
            self.current_depth = 0
            self.stop_event.clear()
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = self.current_target

            # Start new forensic session
            print("Starting new forensic session...")
            self.logger = new_session()

            # Start scan in separate thread
            print("Starting scan thread...")
            self.scan_thread = threading.Thread(
                target=self._execute_scan_async,
                args=(self.current_target, max_depth),
                daemon=True
            )
            self.scan_thread.start()
            return True
        except Exception as e:
            print(f"ERROR: Exception in start_scan: {e}")
            traceback.print_exc()
            return False

    def _execute_scan_async(self, target_domain: str, max_depth: int) -> None:
        """
        Execute the reconnaissance scan asynchronously with concurrent provider queries.

        Args:
            target_domain: Target domain to investigate
            max_depth: Maximum recursion depth
        """
        print(f"_execute_scan_async started for {target_domain} with depth {max_depth}")
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        # Initialize these before the try block so the summary in `finally`
        # cannot raise a NameError if the scan fails early.
        processed_domains = set()
        processed_ips = set()
        all_discovered_ips = set()
        try:
            print("Setting status to RUNNING")
            self.status = ScanStatus.RUNNING

            # Log scan start
            enabled_providers = [provider.get_name() for provider in self.providers]
            self.logger.log_scan_start(target_domain, max_depth, enabled_providers)
            print(f"Logged scan start with providers: {enabled_providers}")

            # Initialize with target domain
            print(f"Adding target domain '{target_domain}' as initial node")
            self.graph.add_node(target_domain, NodeType.DOMAIN)

            # BFS-style exploration
            current_level_domains = {target_domain}
            print("Starting BFS exploration...")
            for depth in range(max_depth + 1):
                if self.stop_event.is_set():
                    print(f"Stop requested at depth {depth}")
                    break
                self.current_depth = depth
                print(f"Processing depth level {depth} with {len(current_level_domains)} domains")
                if not current_level_domains:
                    print("No domains to process at this level")
                    break
                self.total_indicators_found += len(current_level_domains)
                next_level_domains = set()
                domain_results = self._process_domains_concurrent(current_level_domains, processed_domains)
                for domain, discovered_domains, discovered_ips in domain_results:
                    if self.stop_event.is_set():
                        break
                    processed_domains.add(domain)
                    all_discovered_ips.update(discovered_ips)
                    if depth < max_depth:
                        for discovered_domain in discovered_domains:
                            if discovered_domain not in processed_domains:
                                next_level_domains.add(discovered_domain)
                                print(f"Adding {discovered_domain} to next level")
                if self.stop_event.is_set():
                    break
                # Only query IPs that have not been processed at an earlier depth,
                # to avoid re-issuing the same provider requests every level.
                new_ips = all_discovered_ips - processed_ips
                if new_ips:
                    print(f"Processing {len(new_ips)} discovered IP addresses")
                    self._process_ips_concurrent(new_ips)
                    processed_ips.update(new_ips)
                current_level_domains = next_level_domains
                print(f"Completed depth {depth}, {len(next_level_domains)} domains for next level")
        except Exception as e:
            print(f"ERROR: Scan execution failed with error: {e}")
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self.logger.logger.error(f"Scan failed: {e}")
        finally:
            if self.stop_event.is_set():
                self.status = ScanStatus.STOPPED
                print("Scan completed with STOPPED status")
            elif self.status != ScanStatus.FAILED:
                # Do not overwrite a FAILED status set in the except block above.
                self.status = ScanStatus.COMPLETED
                print("Scan completed with COMPLETED status")
            self.logger.log_scan_complete()
            self.executor.shutdown(wait=False, cancel_futures=True)
            stats = self.graph.get_statistics()
            print("Final scan statistics:")
            print(f"  - Total nodes: {stats['basic_metrics']['total_nodes']}")
            print(f"  - Total edges: {stats['basic_metrics']['total_edges']}")
            print(f"  - Domains processed: {len(processed_domains)}")
            print(f"  - IPs discovered: {len(all_discovered_ips)}")

    def _process_domains_concurrent(self, domains: Set[str], processed_domains: Set[str]) -> List[Tuple[str, Set[str], Set[str]]]:
        """
        Process multiple domains concurrently using the shared thread pool.
        """
        results = []
        domains_to_process = domains - processed_domains
        if not domains_to_process:
            return results
        print(f"Processing {len(domains_to_process)} domains concurrently with {self.max_workers} workers")
        future_to_domain = {
            self.executor.submit(self._query_providers_for_domain, domain): domain
            for domain in domains_to_process
        }
        for future in as_completed(future_to_domain):
            if self.stop_event.is_set():
                future.cancel()  # Best-effort: futures that are already running cannot be cancelled
                continue
            domain = future_to_domain[future]
            try:
                discovered_domains, discovered_ips = future.result()
                results.append((domain, discovered_domains, discovered_ips))
                self.indicators_processed += 1
                print(f"Completed processing domain: {domain} ({len(discovered_domains)} domains, {len(discovered_ips)} IPs)")
            except (Exception, CancelledError) as e:
                print(f"Error processing domain {domain}: {e}")
        return results

    def _process_ips_concurrent(self, ips: Set[str]) -> None:
        """
        Process multiple IP addresses concurrently.
        """
        if not ips or self.stop_event.is_set():
            return
        print(f"Processing {len(ips)} IP addresses concurrently")
        future_to_ip = {
            self.executor.submit(self._query_providers_for_ip, ip): ip
            for ip in ips
        }
        for future in as_completed(future_to_ip):
            if self.stop_event.is_set():
                future.cancel()
                continue
            ip = future_to_ip[future]
            try:
                future.result()  # Just wait for completion
                print(f"Completed processing IP: {ip}")
            except (Exception, CancelledError) as e:
                print(f"Error processing IP {ip}: {e}")

    def _query_providers_for_domain(self, domain: str) -> Tuple[Set[str], Set[str]]:
        """
        Query all enabled providers for information about a domain.
        """
        print(f"Querying {len(self.providers)} providers for domain: {domain}")
        discovered_domains = set()
        discovered_ips = set()
        # Each entry is (relationship_tuple, provider_name), so provider
        # attribution survives the aggregation across providers below.
        relationships_by_type = defaultdict(list)
        if not self.providers or self.stop_event.is_set():
            return discovered_domains, discovered_ips
        with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
            future_to_provider = {
                provider_executor.submit(self._safe_provider_query_domain, provider, domain): provider
                for provider in self.providers
            }
            for future in as_completed(future_to_provider):
                if self.stop_event.is_set():
                    future.cancel()
                    continue
                provider = future_to_provider[future]
                try:
                    relationships = future.result()
                    print(f"Provider {provider.get_name()} returned {len(relationships)} relationships")
                    for rel in relationships:
                        relationships_by_type[rel[2]].append((rel, provider.get_name()))
                except (Exception, CancelledError) as e:
                    print(f"Provider {provider.get_name()} failed for {domain}: {e}")
        for rel_type, tagged_relationships in relationships_by_type.items():
            if len(tagged_relationships) > config.large_entity_threshold and rel_type == RelationshipType.SAN_CERTIFICATE:
                # Attribute the aggregated node to the provider of the first relationship.
                self._handle_large_entity(domain, [rel for rel, _ in tagged_relationships], rel_type, tagged_relationships[0][1])
            else:
                for (source, target, edge_type, confidence, raw_data), provider_name in tagged_relationships:
                    # Determine if the target should create a new node
                    create_node = edge_type in [
                        RelationshipType.A_RECORD,
                        RelationshipType.AAAA_RECORD,
                        RelationshipType.CNAME_RECORD,
                        RelationshipType.MX_RECORD,
                        RelationshipType.NS_RECORD,
                        RelationshipType.PTR_RECORD,
                        RelationshipType.SAN_CERTIFICATE
                    ]
                    # Determine if the target should be subject to recursion
                    recurse = edge_type in [
                        RelationshipType.A_RECORD,
                        RelationshipType.AAAA_RECORD,
                        RelationshipType.CNAME_RECORD,
                        RelationshipType.MX_RECORD,
                        RelationshipType.SAN_CERTIFICATE
                    ]
                    if create_node:
                        target_node_type = NodeType.IP if self._is_valid_ip(target) else NodeType.DOMAIN
                        self.graph.add_node(target, target_node_type)
                        if self.graph.add_edge(source, target, edge_type, confidence, provider_name, raw_data):
                            print(f"Added relationship: {source} -> {target} ({edge_type.relationship_name})")
                    else:
                        # For records that don't create nodes, we still want to log the relationship
                        self.logger.log_relationship_discovery(
                            source_node=source,
                            target_node=target,
                            relationship_type=edge_type.relationship_name,
                            confidence_score=confidence,
                            provider=provider_name,
                            raw_data=raw_data,
                            discovery_method=f"dns_{edge_type.name.lower()}_record"
                        )
                    if recurse:
                        if self._is_valid_ip(target):
                            discovered_ips.add(target)
                        elif self._is_valid_domain(target):
                            discovered_domains.add(target)
        print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs")
        return discovered_domains, discovered_ips

    def _handle_large_entity(self, source_domain: str, relationships: list, rel_type: RelationshipType, provider_name: str):
        """
        Handles the creation of a large entity node when a threshold is exceeded.
        """
        print(f"Large number of {rel_type.name} relationships for {source_domain}. Creating a large entity node.")
        entity_name = f"Large collection of {rel_type.name} for {source_domain}"
        self.graph.add_node(entity_name, NodeType.LARGE_ENTITY, metadata={"count": len(relationships)})
        self.graph.add_edge(source_domain, entity_name, rel_type, 0.9, provider_name, {"info": "Aggregated node"})

    def _query_providers_for_ip(self, ip: str) -> None:
        """
        Query all enabled providers for information about an IP address.
        """
        print(f"Querying {len(self.providers)} providers for IP: {ip}")
        if not self.providers or self.stop_event.is_set():
            return
        with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
            future_to_provider = {
                provider_executor.submit(self._safe_provider_query_ip, provider, ip): provider
                for provider in self.providers
            }
            for future in as_completed(future_to_provider):
                if self.stop_event.is_set():
                    future.cancel()
                    continue
                provider = future_to_provider[future]
                try:
                    relationships = future.result()
                    print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for IP {ip}")
                    for source, target, rel_type, confidence, raw_data in relationships:
                        if self._is_valid_domain(target):
                            target_node_type = NodeType.DOMAIN
                        elif target.startswith('AS'):
                            target_node_type = NodeType.ASN
                        else:
                            target_node_type = NodeType.IP
                        self.graph.add_node(source, NodeType.IP)
                        self.graph.add_node(target, target_node_type)
                        if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data):
                            print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")
                except (Exception, CancelledError) as e:
                    print(f"Provider {provider.get_name()} failed for IP {ip}: {e}")

    def _safe_provider_query_domain(self, provider, domain: str):
        """Safely query provider for domain with error handling."""
        if self.stop_event.is_set():
            return []
        try:
            return provider.query_domain(domain)
        except Exception as e:
            print(f"Provider {provider.get_name()} query_domain failed: {e}")
            return []

    def _safe_provider_query_ip(self, provider, ip: str):
        """Safely query provider for IP with error handling."""
        if self.stop_event.is_set():
            return []
        try:
            return provider.query_ip(ip)
        except Exception as e:
            print(f"Provider {provider.get_name()} query_ip failed: {e}")
            return []

    def stop_scan(self) -> bool:
        """
        Request scan termination.
        """
        try:
            if self.status == ScanStatus.RUNNING:
                self.stop_event.set()
                print("Scan stop requested")
                return True
            print("No active scan to stop")
            return False
        except Exception as e:
            print(f"ERROR: Exception in stop_scan: {e}")
            traceback.print_exc()
            return False

    def get_scan_status(self) -> Dict[str, Any]:
        """
        Get current scan status and progress.
        """
        try:
            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'total_indicators_found': self.total_indicators_found,
                'indicators_processed': self.indicators_processed,
                'progress_percentage': self._calculate_progress(),
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics()
            }
        except Exception as e:
            print(f"ERROR: Exception in get_scan_status: {e}")
            traceback.print_exc()
            return {
                'status': 'error',
                'target_domain': None,
                'current_depth': 0,
                'max_depth': 0,
                'current_indicator': '',
                'total_indicators_found': 0,
                'indicators_processed': 0,
                'progress_percentage': 0.0,
                'enabled_providers': [],
                'graph_statistics': {}
            }

    def _calculate_progress(self) -> float:
        """Calculate scan progress percentage."""
        if self.total_indicators_found == 0:
            return 0.0
        return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100)

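    # Worked example (illustrative values, not from a real scan): with 12 indicators
    # found and 3 processed, progress is min(100.0, (3 / 12) * 100) = 25.0. Because
    # total_indicators_found grows as each BFS level is discovered, the min() cap is
    # defensive, keeping the reported value at or below 100% mid-scan.
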
    def get_graph_data(self) -> Dict[str, Any]:
        """
        Get current graph data for visualization.
        """
        return self.graph.get_graph_data()

    def export_results(self) -> Dict[str, Any]:
        """
        Export complete scan results including graph and audit trail.
        """
        graph_data = self.graph.export_json()
        audit_trail = self.logger.export_audit_trail()
        provider_stats = {}
        for provider in self.providers:
            provider_stats[provider.get_name()] = provider.get_statistics()
        export_data = {
            'scan_metadata': {
                'target_domain': self.current_target,
                'max_depth': self.max_depth,
                'final_status': self.status,
                'total_indicators_processed': self.indicators_processed,
                'enabled_providers': list(provider_stats.keys())
            },
            'graph_data': graph_data,
            'forensic_audit': audit_trail,
            'provider_statistics': provider_stats,
            'scan_summary': self.logger.get_forensic_summary()
        }
        return export_data

    def remove_provider(self, provider_name: str) -> bool:
        """
        Remove a provider from the scanner.
        """
        for i, provider in enumerate(self.providers):
            if provider.get_name() == provider_name:
                self.providers.pop(i)
                return True
        return False

    def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
        """
        Get statistics for all providers.
        """
        stats = {}
        for provider in self.providers:
            stats[provider.get_name()] = provider.get_statistics()
        return stats

    def _is_valid_domain(self, domain: str) -> bool:
        """
        Basic domain validation.
        """
        if not domain or len(domain) > 253:
            return False
        parts = domain.split('.')
        if len(parts) < 2:
            return False
        for part in parts:
            if not part or len(part) > 63:
                return False
            if not part.replace('-', '').replace('_', '').isalnum():
                return False
        return True

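    # Illustrative behaviour of the check above: "example.com" -> True,
    # "_dmarc.example.com" -> True (underscores allowed for service labels),
    # "localhost" -> False (fewer than two labels). This is deliberately a
    # permissive sanity filter, not full RFC 1035 validation.
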
    def _is_valid_ip(self, ip: str) -> bool:
        """
        Basic IP address validation (IPv4 and IPv6).
        """
        # Delegate to the stdlib so IPv6 addresses (e.g. from AAAA records) are
        # recognized as IPs instead of falling through the old IPv4-only check.
        try:
            ipaddress.ip_address(ip.strip())
            return True
        except (ValueError, AttributeError):
            return False

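    # Illustrative behaviour: "192.168.1.1" -> True, "2001:db8::1" -> True,
    # "999.1.1.1" -> False, "example.com" -> False, None -> False.
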

class ScannerProxy:
    """Lazily create the global Scanner on first attribute access."""

    def __init__(self):
        self._scanner = None
        print("ScannerProxy initialized")

    def __getattr__(self, name):
        if self._scanner is None:
            print("Creating new Scanner instance...")
            self._scanner = Scanner()
            print("Scanner instance created")
        return getattr(self._scanner, name)


# Global scanner instance
scanner = ScannerProxy()
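

# A minimal usage sketch (illustrative only, not part of the original module): it
# assumes this module's imports resolve, that at least the free providers are
# enabled in config, and that "example.com" is a placeholder target.
if __name__ == "__main__":
    import json
    import time

    if scanner.start_scan("example.com", max_depth=1):
        # Poll the background scan thread until it finishes, then export results.
        while scanner.scan_thread and scanner.scan_thread.is_alive():
            time.sleep(1.0)
        print(json.dumps(scanner.export_results(), indent=2, default=str))
    else:
        print("Scan could not be started (already running or no providers available)")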