""" Main scanning orchestrator for DNSRecon. Coordinates data gathering from multiple providers and builds the infrastructure graph. """ import threading import traceback from typing import List, Set, Dict, Any, Tuple from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError from collections import defaultdict from core.graph_manager import GraphManager, NodeType, RelationshipType from core.logger import get_forensic_logger, new_session from utils.helpers import _is_valid_ip, _is_valid_domain from providers.crtsh_provider import CrtShProvider from providers.dns_provider import DNSProvider from providers.shodan_provider import ShodanProvider from providers.virustotal_provider import VirusTotalProvider class ScanStatus: """Enumeration of scan statuses.""" IDLE = "idle" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" STOPPED = "stopped" class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. Now supports per-session configuration for multi-user isolation. """ def __init__(self, session_config=None): """Initialize scanner with session-specific configuration.""" print("Initializing Scanner instance...") try: # Use provided session config or create default if session_config is None: from core.session_config import create_session_config session_config = create_session_config() self.config = session_config self.graph = GraphManager() self.providers = [] self.status = ScanStatus.IDLE self.current_target = None self.current_depth = 0 self.max_depth = 2 self.stop_event = threading.Event() self.scan_thread = None # Scanning progress tracking self.total_indicators_found = 0 self.indicators_processed = 0 self.current_indicator = "" # Concurrent processing configuration self.max_workers = self.config.max_concurrent_requests self.executor = None # Initialize providers with session config print("Calling _initialize_providers with session config...") self._initialize_providers() # Initialize logger print("Initializing forensic logger...") self.logger = get_forensic_logger() print("Scanner initialization complete") except Exception as e: print(f"ERROR: Scanner initialization failed: {e}") traceback.print_exc() raise def _initialize_providers(self) -> None: """Initialize all available providers based on session configuration.""" self.providers = [] print("Initializing providers with session config...") # Always add free providers free_providers = [ ('crtsh', CrtShProvider), ('dns', DNSProvider) ] for provider_name, provider_class in free_providers: if self.config.is_provider_enabled(provider_name): try: # Pass session config to provider provider = provider_class(session_config=self.config) if provider.is_available(): # Set the stop event for cancellation support provider.set_stop_event(self.stop_event) self.providers.append(provider) print(f"✓ {provider_name.title()} provider initialized successfully for session") else: print(f"✗ {provider_name.title()} provider is not available") except Exception as e: print(f"✗ Failed to initialize {provider_name.title()} provider: {e}") traceback.print_exc() # Add API key-dependent providers api_providers = [ ('shodan', ShodanProvider), ('virustotal', VirusTotalProvider) ] for provider_name, provider_class in api_providers: if self.config.is_provider_enabled(provider_name): try: # Pass session config to provider provider = provider_class(session_config=self.config) if provider.is_available(): # Set the stop event for cancellation support provider.set_stop_event(self.stop_event) self.providers.append(provider) print(f"✓ {provider_name.title()} provider initialized successfully for session") else: print(f"✗ {provider_name.title()} provider is not available (API key required)") except Exception as e: print(f"✗ Failed to initialize {provider_name.title()} provider: {e}") traceback.print_exc() print(f"Initialized {len(self.providers)} providers for session") def update_session_config(self, new_config) -> None: """ Update session configuration and reinitialize providers. Args: new_config: New SessionConfig instance """ print("Updating session configuration...") self.config = new_config self.max_workers = self.config.max_concurrent_requests self._initialize_providers() print("Session configuration updated") def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool: """ Start a new reconnaissance scan. Forcefully cleans up any previous scan thread before starting. """ print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") print(f"Initial scanner status: {self.status}") # If a thread is still alive from a previous scan, we must wait for it to die. if self.scan_thread and self.scan_thread.is_alive(): print("A previous scan thread is still alive. Sending termination signal and waiting...") self.stop_scan() self.scan_thread.join(10.0) # Wait up to 10 seconds if self.scan_thread.is_alive(): print("ERROR: The previous scan thread is unresponsive and could not be stopped. Please restart the application.") self.status = ScanStatus.FAILED return False print("Previous scan thread terminated successfully.") # Reset state for the new scan self.status = ScanStatus.IDLE print(f"Scanner state is now clean for a new scan.") try: # Check if we have any providers if not hasattr(self, 'providers') or not self.providers: print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan") return False print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}") if clear_graph: self.graph.clear() self.current_target = target_domain.lower().strip() self.max_depth = max_depth self.current_depth = 0 self.stop_event.clear() self.total_indicators_found = 0 self.indicators_processed = 0 self.current_indicator = self.current_target # Start new forensic session print(f"Starting new forensic session for scanner {id(self)}...") self.logger = new_session() # Start scan in separate thread print(f"Starting scan thread for scanner {id(self)}...") self.scan_thread = threading.Thread( target=self._execute_scan, args=(self.current_target, max_depth), daemon=True ) self.scan_thread.start() print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") return True except Exception as e: print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") traceback.print_exc() self.status = ScanStatus.FAILED return False def _execute_scan(self, target_domain: str, max_depth: int) -> None: """ Execute the reconnaissance scan with concurrent provider queries. Args: target_domain: Target domain to investigate max_depth: Maximum recursion depth """ print(f"_execute_scan started for {target_domain} with depth {max_depth}") self.executor = ThreadPoolExecutor(max_workers=self.max_workers) try: print("Setting status to RUNNING") self.status = ScanStatus.RUNNING # Log scan start enabled_providers = [provider.get_name() for provider in self.providers] self.logger.log_scan_start(target_domain, max_depth, enabled_providers) print(f"Logged scan start with providers: {enabled_providers}") # Initialize with target domain print(f"Adding target domain '{target_domain}' as initial node") self.graph.add_node(target_domain, NodeType.DOMAIN) # BFS-style exploration current_level_targets = {target_domain} processed_targets = set() print("Starting BFS exploration...") for depth in range(max_depth + 1): if self.stop_event.is_set(): print(f"Stop requested at depth {depth}") break self.current_depth = depth print(f"Processing depth level {depth} with {len(current_level_targets)} targets") if not current_level_targets: print("No targets to process at this level") break self.total_indicators_found += len(current_level_targets) target_results = self._process_targets_concurrent(current_level_targets, processed_targets) next_level_targets = set() for target, new_targets in target_results: processed_targets.add(target) if depth < max_depth: for new_target in new_targets: if new_target not in processed_targets: next_level_targets.add(new_target) current_level_targets = next_level_targets print(f"Completed depth {depth}, {len(next_level_targets)} targets for next level") except Exception as e: print(f"ERROR: Scan execution failed with error: {e}") traceback.print_exc() self.status = ScanStatus.FAILED self.logger.logger.error(f"Scan failed: {e}") finally: if self.stop_event.is_set(): self.status = ScanStatus.STOPPED print("Scan completed with STOPPED status") else: self.status = ScanStatus.COMPLETED print("Scan completed with COMPLETED status") self.logger.log_scan_complete() self.executor.shutdown(wait=False, cancel_futures=True) stats = self.graph.get_statistics() print("Final scan statistics:") print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}") print(f" - Total edges: {stats['basic_metrics']['total_edges']}") print(f" - Targets processed: {len(processed_targets)}") def _process_targets_concurrent(self, targets: Set[str], processed_targets: Set[str]) -> List[Tuple[str, Set[str]]]: """Process multiple targets (domains or IPs) concurrently using a thread pool.""" results = [] targets_to_process = targets - processed_targets if not targets_to_process: return results print(f"Processing {len(targets_to_process)} targets concurrently with {self.max_workers} workers") future_to_target = { self.executor.submit(self._query_providers_for_target, target): target for target in targets_to_process } for future in as_completed(future_to_target): if self.stop_event.is_set(): future.cancel() continue target = future_to_target[future] try: new_targets = future.result() results.append((target, new_targets)) self.indicators_processed += 1 print(f"Completed processing target: {target} (found {len(new_targets)} new targets)") except (Exception, CancelledError) as e: print(f"Error processing target {target}: {e}") return results def _query_providers_for_target(self, target: str) -> Set[str]: """ Query all enabled providers for information about a target (domain or IP) and collect comprehensive metadata. Creates appropriate node types and relationships based on discovered data. """ is_ip = _is_valid_ip(target) target_type = NodeType.IP if is_ip else NodeType.DOMAIN print(f"Querying {len(self.providers)} providers for {target_type.value}: {target}") new_targets = set() all_relationships = [] if not self.providers or self.stop_event.is_set(): return new_targets with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor: future_to_provider = { provider_executor.submit( self._safe_provider_query, provider, target, is_ip ): provider for provider in self.providers } for future in as_completed(future_to_provider): if self.stop_event.is_set(): future.cancel() continue provider = future_to_provider[future] try: relationships = future.result() print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for {target}") for rel in relationships: source, rel_target, rel_type, confidence, raw_data = rel enhanced_rel = (source, rel_target, rel_type, confidence, raw_data, provider.get_name()) all_relationships.append(enhanced_rel) except (Exception, CancelledError) as e: print(f"Provider {provider.get_name()} failed for {target}: {e}") # NEW Step 2: Group all targets by type and identify large entities discovered_targets_by_type = defaultdict(set) for _, rel_target, _, _, _, _ in all_relationships: if _is_valid_domain(rel_target): discovered_targets_by_type[NodeType.DOMAIN].add(rel_target) elif _is_valid_ip(rel_target): discovered_targets_by_type[NodeType.IP].add(rel_target) targets_to_skip = set() for node_type, targets in discovered_targets_by_type.items(): if len(targets) > self.config.large_entity_threshold: print(f"Large number of {node_type.value}s ({len(targets)}) found for {target}. Creating a large entity node.") first_rel = next((r for r in all_relationships if r[1] in targets), None) if first_rel: self._handle_large_entity(target, list(targets), first_rel[2], first_rel[5]) targets_to_skip.update(targets) # Step 3: Process all relationships to create/update nodes and edges target_metadata = defaultdict(lambda: defaultdict(list)) dns_records_to_create = {} for source, rel_target, rel_type, confidence, raw_data, provider_name in all_relationships: if self.stop_event.is_set(): break # Special handling for crt.sh to distribute certificate metadata if provider_name == 'crtsh' and 'domain_certificates' in raw_data: domain_certs = raw_data.get('domain_certificates', {}) for cert_domain, cert_summary in domain_certs.items(): if _is_valid_domain(cert_domain) and cert_domain not in targets_to_skip: self.graph.add_node(cert_domain, NodeType.DOMAIN, metadata={'certificate_data': cert_summary}) # General metadata collection self._collect_node_metadata(source, provider_name, rel_type, rel_target, raw_data, target_metadata[source]) # Add nodes and edges to the graph if rel_target in targets_to_skip: continue if _is_valid_ip(rel_target): self.graph.add_node(rel_target, NodeType.IP) if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): print(f"Added IP relationship: {source} -> {rel_target} ({rel_type.relationship_name})") if rel_type in [RelationshipType.A_RECORD, RelationshipType.AAAA_RECORD]: new_targets.add(rel_target) elif rel_target.startswith('AS') and rel_target[2:].isdigit(): self.graph.add_node(rel_target, NodeType.ASN) if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): print(f"Added ASN relationship: {source} -> {rel_target} ({rel_type.relationship_name})") elif _is_valid_domain(rel_target): self.graph.add_node(rel_target, NodeType.DOMAIN) if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): print(f"Added domain relationship: {source} -> {rel_target} ({rel_type.relationship_name})") recurse_types = [ RelationshipType.CNAME_RECORD, RelationshipType.MX_RECORD, RelationshipType.SAN_CERTIFICATE, RelationshipType.NS_RECORD, RelationshipType.PASSIVE_DNS ] if rel_type in recurse_types: new_targets.add(rel_target) else: # Handle DNS record content dns_record_types = [ RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD, RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD, RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD, RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD, RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD ] if rel_type in dns_record_types: record_type = rel_type.relationship_name.upper().replace('_RECORD', '') record_content = rel_target.strip() content_hash = hash(record_content) & 0x7FFFFFFF dns_record_id = f"{record_type}:{content_hash}" if dns_record_id not in dns_records_to_create: dns_records_to_create[dns_record_id] = { 'content': record_content, 'type': record_type, 'domains': set(), 'raw_data': raw_data, 'provider_name': provider_name, 'confidence': confidence } dns_records_to_create[dns_record_id]['domains'].add(source) # Step 4: Update the source node with its collected metadata if target in target_metadata: self.graph.add_node(target, target_type, metadata=dict(target_metadata[target])) # Step 5: Create DNS record nodes and edges for dns_record_id, record_info in dns_records_to_create.items(): record_metadata = { 'record_type': record_info['type'], 'content': record_info['content'], 'content_hash': dns_record_id.split(':')[1], 'associated_domains': list(record_info['domains']), 'source_data': record_info['raw_data'] } self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata) for domain_name in record_info['domains']: self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD, record_info['confidence'], record_info['provider_name'], record_info['raw_data']) return new_targets def _collect_node_metadata(self, node_id: str, provider_name: str, rel_type: RelationshipType, target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None: """ Collect and organize metadata for a node based on provider responses. """ if provider_name == 'dns': record_type = raw_data.get('query_type', 'UNKNOWN') value = raw_data.get('value', target) if record_type in ['TXT', 'SPF', 'CAA']: dns_entry = f"{record_type}: {value}" else: dns_entry = f"{record_type}: {value}" if dns_entry not in metadata.get('dns_records', []): metadata.setdefault('dns_records', []).append(dns_entry) elif provider_name == 'crtsh': if rel_type == RelationshipType.SAN_CERTIFICATE: domain_certs = raw_data.get('domain_certificates', {}) if node_id in domain_certs: cert_summary = domain_certs[node_id] metadata['certificate_data'] = cert_summary metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False) if target not in metadata.get('related_domains_san', []): metadata.setdefault('related_domains_san', []).append(target) shared_certs = raw_data.get('shared_certificates', []) if shared_certs and 'shared_certificate_details' not in metadata: metadata['shared_certificate_details'] = shared_certs elif provider_name == 'shodan': for key, value in raw_data.items(): if key not in metadata.get('shodan', {}) or not metadata.get('shodan', {}).get(key): metadata.setdefault('shodan', {})[key] = value elif provider_name == 'virustotal': for key, value in raw_data.items(): if key not in metadata.get('virustotal', {}) or not metadata.get('virustotal', {}).get(key): metadata.setdefault('virustotal', {})[key] = value if rel_type == RelationshipType.PASSIVE_DNS: passive_entry = f"Passive DNS: {target}" if passive_entry not in metadata.get('passive_dns', []): metadata.setdefault('passive_dns', []).append(passive_entry) if rel_type == RelationshipType.ASN_MEMBERSHIP: metadata['asn_data'] = { 'asn': target, 'description': raw_data.get('org', ''), 'isp': raw_data.get('isp', ''), 'country': raw_data.get('country', '') } def _handle_large_entity(self, source: str, targets: list, rel_type: RelationshipType, provider_name: str): """ Handles the creation of a large entity node when a threshold is exceeded. """ print(f"Large number of {rel_type.name} relationships for {source}. Creating a large entity node.") entity_name = f"Large collection of {rel_type.name} for {source}" node_type = 'unknown' if targets: if _is_valid_domain(targets[0]): node_type = 'domain' elif _is_valid_ip(targets[0]): node_type = 'ip' self.graph.add_node(entity_name, NodeType.LARGE_ENTITY, metadata={"count": len(targets), "nodes": targets, "node_type": node_type}) self.graph.add_edge(source, entity_name, rel_type, 0.9, provider_name, {"info": "Aggregated node"}) def _safe_provider_query(self, provider, target: str, is_ip: bool) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """Safely query a provider for a target with error handling.""" if self.stop_event.is_set(): return [] try: if is_ip: return provider.query_ip(target) else: return provider.query_domain(target) except Exception as e: print(f"Provider {provider.get_name()} query failed for {target}: {e}") return [] def stop_scan(self) -> bool: """ Request immediate scan termination. Acts on the thread's liveness, not just the 'RUNNING' status. """ try: if not self.scan_thread or not self.scan_thread.is_alive(): print("No active scan thread to stop.") # Cleanup state if inconsistent if self.status == ScanStatus.RUNNING: self.status = ScanStatus.STOPPED return False print("=== INITIATING IMMEDIATE SCAN TERMINATION ===") self.status = ScanStatus.STOPPED self.stop_event.set() if self.executor: print("Shutting down executor with immediate cancellation...") self.executor.shutdown(wait=False, cancel_futures=True) print("Termination signal sent. The scan thread will stop shortly.") return True except Exception as e: print(f"ERROR: Exception in stop_scan: {e}") traceback.print_exc() return False def _force_stop_completion(self): """Force completion of stop operation after timeout.""" if self.status == ScanStatus.RUNNING: print("Forcing scan termination after timeout") self.status = ScanStatus.STOPPED self.logger.log_scan_complete() def get_scan_status(self) -> Dict[str, Any]: """ Get current scan status and progress. """ try: return { 'status': self.status, 'target_domain': self.current_target, 'current_depth': self.current_depth, 'max_depth': self.max_depth, 'current_indicator': self.current_indicator, 'total_indicators_found': self.total_indicators_found, 'indicators_processed': self.indicators_processed, 'progress_percentage': self._calculate_progress(), 'enabled_providers': [provider.get_name() for provider in self.providers], 'graph_statistics': self.graph.get_statistics() } except Exception as e: print(f"ERROR: Exception in get_scan_status: {e}") traceback.print_exc() return { 'status': 'error', 'target_domain': None, 'current_depth': 0, 'max_depth': 0, 'current_indicator': '', 'total_indicators_found': 0, 'indicators_processed': 0, 'progress_percentage': 0.0, 'enabled_providers': [], 'graph_statistics': {} } def _calculate_progress(self) -> float: """Calculate scan progress percentage.""" if self.total_indicators_found == 0: return 0.0 return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100) def get_graph_data(self) -> Dict[str, Any]: """ Get current graph data for visualization. """ return self.graph.get_graph_data() def export_results(self) -> Dict[str, Any]: """ Export complete scan results including graph and audit trail. """ graph_data = self.graph.export_json() audit_trail = self.logger.export_audit_trail() provider_stats = {} for provider in self.providers: provider_stats[provider.get_name()] = provider.get_statistics() export_data = { 'scan_metadata': { 'target_domain': self.current_target, 'max_depth': self.max_depth, 'final_status': self.status, 'total_indicators_processed': self.indicators_processed, 'enabled_providers': list(provider_stats.keys()) }, 'graph_data': graph_data, 'forensic_audit': audit_trail, 'provider_statistics': provider_stats, 'scan_summary': self.logger.get_forensic_summary() } return export_data def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]: """ Get statistics for all providers. """ stats = {} for provider in self.providers: stats[provider.get_name()] = provider.get_statistics() return stats