# dnsrecon/core/scanner.py
import threading
import traceback
import zlib
from typing import List, Set, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
from collections import defaultdict
from datetime import datetime, timezone

from core.graph_manager import GraphManager, NodeType, RelationshipType
from core.logger import get_forensic_logger, new_session
from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.crtsh_provider import CrtShProvider
from providers.dns_provider import DNSProvider
from providers.shodan_provider import ShodanProvider


class ScanStatus:
    """Enumeration of scan statuses."""
    IDLE = "idle"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"


class Scanner:
    """
    Main scanning orchestrator for DNSRecon passive reconnaissance.
    REFACTORED: Simplified recursion with forensic provider state tracking.
    """

    def __init__(self, session_config=None):
        """Initialize scanner with session-specific configuration.

        Args:
            session_config: Optional session configuration object; when None,
                a default configuration is created.

        Raises:
            Exception: Re-raises any error that occurs during initialization
                after printing a traceback for diagnostics.
        """
        print("Initializing Scanner instance...")
        try:
            # Use provided session config or create default
            if session_config is None:
                from core.session_config import create_session_config
                session_config = create_session_config()

            self.config = session_config
            self.graph = GraphManager()
            self.providers = []
            self.status = ScanStatus.IDLE
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()
            self.scan_thread = None

            # Scanning progress tracking
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = ""

            # Concurrent processing configuration
            self.max_workers = self.config.max_concurrent_requests
            self.executor = None

            # Provider eligibility mapping: which target kinds each provider
            # may be queried for.
            self.provider_eligibility = {
                'dns': {'domains': True, 'ips': True},
                'crtsh': {'domains': True, 'ips': False},
                'shodan': {'domains': True, 'ips': True},
                'virustotal': {'domains': False, 'ips': False}  # Disabled as requested
            }

            # Initialize providers with session config
            print("Calling _initialize_providers with session config...")
            self._initialize_providers()

            # Initialize logger
            print("Initializing forensic logger...")
            self.logger = get_forensic_logger()

            print("Scanner initialization complete")

        except Exception as e:
            print(f"ERROR: Scanner initialization failed: {e}")
            traceback.print_exc()
            raise

    def __getstate__(self):
        """Prepare object for pickling by excluding unpicklable attributes."""
        state = self.__dict__.copy()
        # Remove unpicklable threading objects
        unpicklable_attrs = [
            'stop_event',
            'scan_thread',
            'executor'
        ]
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
        return state

    def __setstate__(self, state):
        """Restore object after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Reconstruct threading objects
        self.stop_event = threading.Event()
        self.scan_thread = None
        self.executor = None

    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration.

        Only providers that are both enabled in the session configuration and
        report themselves available are added to ``self.providers``.
        """
        self.providers = []
        print("Initializing providers with session config...")

        # Provider classes mapping
        provider_classes = {
            'dns': DNSProvider,
            'crtsh': CrtShProvider,
            'shodan': ShodanProvider,
            # Skip virustotal as requested
        }

        for provider_name, provider_class in provider_classes.items():
            if self.config.is_provider_enabled(provider_name):
                try:
                    provider = provider_class(session_config=self.config)
                    if provider.is_available():
                        # Share the scanner's stop event so providers can abort
                        # in-flight work when the scan is terminated.
                        provider.set_stop_event(self.stop_event)
                        self.providers.append(provider)
                        print(f"✓ {provider_name.title()} provider initialized successfully for session")
                    else:
                        print(f"✗ {provider_name.title()} provider is not available")
                except Exception as e:
                    print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
                    traceback.print_exc()

        print(f"Initialized {len(self.providers)} providers for session")

    def update_session_config(self, new_config) -> None:
        """Update session configuration and reinitialize providers."""
        print("Updating session configuration...")
        self.config = new_config
        self.max_workers = self.config.max_concurrent_requests
        self._initialize_providers()
        print("Session configuration updated")

    def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool:
        """Start a new reconnaissance scan with forensic tracking.

        Args:
            target_domain: Domain to scan (normalized to lowercase).
            max_depth: Maximum recursion depth for discovered targets.
            clear_graph: When True, the existing graph is cleared first.

        Returns:
            True if the scan thread was started, False otherwise.
        """
        print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
        print(f"Initial scanner status: {self.status}")

        # Clean up previous scan thread if needed
        if self.scan_thread and self.scan_thread.is_alive():
            print("A previous scan thread is still alive. Sending termination signal and waiting...")
            self.stop_scan()
            self.scan_thread.join(10.0)
            if self.scan_thread.is_alive():
                print("ERROR: The previous scan thread is unresponsive and could not be stopped.")
                self.status = ScanStatus.FAILED
                return False
            print("Previous scan thread terminated successfully.")

        # Reset state for new scan
        self.status = ScanStatus.IDLE
        print("Scanner state is now clean for a new scan.")

        try:
            if not hasattr(self, 'providers') or not self.providers:
                print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
                return False

            print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")

            if clear_graph:
                self.graph.clear()

            self.current_target = target_domain.lower().strip()
            self.max_depth = max_depth
            self.current_depth = 0
            self.stop_event.clear()
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = self.current_target

            # Start new forensic session
            print(f"Starting new forensic session for scanner {id(self)}...")
            self.logger = new_session()

            # Start scan in separate thread
            print(f"Starting scan thread for scanner {id(self)}...")
            self.scan_thread = threading.Thread(
                target=self._execute_scan,
                args=(self.current_target, max_depth),
                daemon=True
            )
            self.scan_thread.start()

            print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===")
            return True

        except Exception as e:
            print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            return False

    def _execute_scan(self, target_domain: str, max_depth: int) -> None:
        """Execute the reconnaissance scan with simplified recursion and forensic tracking.

        Runs breadth-first over discovered targets, one depth level at a time,
        until ``max_depth`` is reached, no new targets remain, or a stop is
        requested.
        """
        print(f"_execute_scan started for {target_domain} with depth {max_depth}")
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        processed_targets = set()

        try:
            self.status = ScanStatus.RUNNING
            enabled_providers = [provider.get_name() for provider in self.providers]
            self.logger.log_scan_start(target_domain, max_depth, enabled_providers)
            self.graph.add_node(target_domain, NodeType.DOMAIN)
            self._initialize_provider_states(target_domain)

            current_level_targets = {target_domain}
            all_discovered_targets = {target_domain}

            for depth in range(max_depth + 1):
                if self.stop_event.is_set():
                    print(f"Stop requested at depth {depth}")
                    break

                self.current_depth = depth
                targets_to_process = current_level_targets - processed_targets
                if not targets_to_process:
                    print("No new targets to process at this level.")
                    break

                print(f"Processing depth level {depth} with {len(targets_to_process)} new targets")
                self.total_indicators_found += len(targets_to_process)

                target_results = self._process_targets_concurrent_forensic(
                    targets_to_process, processed_targets, all_discovered_targets, depth
                )
                processed_targets.update(targets_to_process)

                next_level_targets = set()
                for _target, new_targets in target_results:
                    all_discovered_targets.update(new_targets)
                    next_level_targets.update(new_targets)

                current_level_targets = next_level_targets

        except Exception as e:
            print(f"ERROR: Scan execution failed with error: {e}")
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self.logger.logger.error(f"Scan failed: {e}")
        finally:
            if self.stop_event.is_set():
                self.status = ScanStatus.STOPPED
            elif self.status != ScanStatus.FAILED:
                # BUGFIX: previously the finally block unconditionally set
                # COMPLETED, clobbering the FAILED status recorded in the
                # except block above. Only a scan that finished without an
                # error is marked COMPLETED.
                self.status = ScanStatus.COMPLETED
            self.logger.log_scan_complete()
            self.executor.shutdown(wait=False, cancel_futures=True)
            stats = self.graph.get_statistics()
            print("Final scan statistics:")
            print(f"  - Total nodes: {stats['basic_metrics']['total_nodes']}")
            print(f"  - Total edges: {stats['basic_metrics']['total_edges']}")
            print(f"  - Targets processed: {len(processed_targets)}")

    def _initialize_provider_states(self, target: str) -> None:
        """Initialize provider states for forensic tracking."""
        if not self.graph.graph.has_node(target):  # Fix: Use .graph.has_node()
            return
        node_data = self.graph.graph.nodes[target]
        if 'metadata' not in node_data:
            node_data['metadata'] = {}
        if 'provider_states' not in node_data['metadata']:
            node_data['metadata']['provider_states'] = {}

    def _should_recurse_on_target(self, target: str, processed_targets: Set[str],
                                  all_discovered: Set[str]) -> bool:
        """
        Simplified recursion logic: only recurse on valid IPs and domains that haven't been processed.
        FORENSIC PRINCIPLE: Clear, simple rules for what gets recursed.
        """
        # Don't recurse on already processed targets
        if target in processed_targets:
            return False

        # Only recurse on valid IPs and domains
        if not (_is_valid_ip(target) or _is_valid_domain(target)):
            return False

        # Don't recurse on targets contained in large entities
        if self._is_in_large_entity(target):
            return False

        return True

    def _is_in_large_entity(self, target: str) -> bool:
        """Check if a target is contained within a large entity node."""
        for node_id, node_data in self.graph.graph.nodes(data=True):
            if node_data.get('type') == NodeType.LARGE_ENTITY.value:
                metadata = node_data.get('metadata', {})
                contained_nodes = metadata.get('nodes', [])
                if target in contained_nodes:
                    return True
        return False

    def _process_targets_concurrent_forensic(self, targets: Set[str], processed_targets: Set[str],
                                             all_discovered: Set[str],
                                             current_depth: int) -> List[Tuple[str, Set[str]]]:
        """Process multiple targets concurrently with forensic provider state tracking.

        Returns:
            A list of ``(target, new_targets)`` tuples for each target that
            completed successfully.
        """
        results = []
        targets_to_process = targets - processed_targets

        if not targets_to_process:
            return results

        print(f"Processing {len(targets_to_process)} targets concurrently with forensic tracking")

        future_to_target = {
            self.executor.submit(self._query_providers_forensic, target, current_depth): target
            for target in targets_to_process
        }

        for future in as_completed(future_to_target):
            if self.stop_event.is_set():
                future.cancel()
                continue

            target = future_to_target[future]
            try:
                new_targets = future.result()
                results.append((target, new_targets))
                self.indicators_processed += 1
                print(f"Completed processing target: {target} (found {len(new_targets)} new targets)")
            except (Exception, CancelledError) as e:
                print(f"Error processing target {target}: {e}")
                self._log_target_processing_error(target, str(e))

        return results

    def _query_providers_forensic(self, target: str, current_depth: int) -> Set[str]:
        """
        Query providers for a target with forensic state tracking and simplified recursion.
        REFACTORED: Simplified logic with complete forensic audit trail.
        """
        is_ip = _is_valid_ip(target)
        target_type = NodeType.IP if is_ip else NodeType.DOMAIN
        print(f"Querying providers for {target_type.value}: {target} at depth {current_depth}")

        # Initialize node and provider states
        self.graph.add_node(target, target_type)
        self._initialize_provider_states(target)

        new_targets = set()
        target_metadata = defaultdict(lambda: defaultdict(list))

        # Determine eligible providers for this target
        eligible_providers = self._get_eligible_providers(target, is_ip)

        if not eligible_providers:
            self._log_no_eligible_providers(target, is_ip)
            return new_targets

        # Query each eligible provider with forensic tracking
        with ThreadPoolExecutor(max_workers=len(eligible_providers)) as provider_executor:
            future_to_provider = {
                provider_executor.submit(
                    self._query_single_provider_forensic, provider, target, is_ip, current_depth
                ): provider
                for provider in eligible_providers
            }

            for future in as_completed(future_to_provider):
                if self.stop_event.is_set():
                    future.cancel()
                    continue

                provider = future_to_provider[future]
                try:
                    provider_results = future.result()
                    if provider_results:
                        discovered_targets = self._process_provider_results_forensic(
                            target, provider, provider_results, target_metadata, current_depth
                        )
                        new_targets.update(discovered_targets)
                except (Exception, CancelledError) as e:
                    self._log_provider_error(target, provider.get_name(), str(e))

        for node_id, metadata_dict in target_metadata.items():
            if self.graph.graph.has_node(node_id):
                node_is_ip = _is_valid_ip(node_id)
                node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN
                # This call updates the existing node with the new metadata
                self.graph.add_node(node_id, node_type_to_add, metadata=metadata_dict)

        return new_targets

    def _get_eligible_providers(self, target: str, is_ip: bool) -> List:
        """Get providers eligible for querying this target.

        A provider is eligible when its eligibility map allows this target
        kind and it has not already been queried for this exact target.
        """
        eligible = []
        target_key = 'ips' if is_ip else 'domains'

        for provider in self.providers:
            provider_name = provider.get_name()
            if provider_name in self.provider_eligibility:
                if self.provider_eligibility[provider_name][target_key]:
                    # Check if we already queried this provider for this target
                    if not self._already_queried_provider(target, provider_name):
                        eligible.append(provider)
                    else:
                        print(f"Skipping {provider_name} for {target} - already queried")

        return eligible

    def _already_queried_provider(self, target: str, provider_name: str) -> bool:
        """Check if we already queried a provider for a target."""
        if not self.graph.graph.has_node(target):  # Fix: Use .graph.has_node()
            return False
        node_data = self.graph.graph.nodes[target]
        provider_states = node_data.get('metadata', {}).get('provider_states', {})
        return provider_name in provider_states

    def _query_single_provider_forensic(self, provider, target: str, is_ip: bool,
                                        current_depth: int) -> List:
        """Query a single provider with complete forensic logging.

        Raises:
            Exception: Re-raises provider failures after recording the failed
                state, so the caller can log the error per provider.
        """
        provider_name = provider.get_name()
        start_time = datetime.now(timezone.utc)

        print(f"Querying {provider_name} for {target}")

        # Log attempt
        self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}")

        try:
            # Perform the query
            if is_ip:
                results = provider.query_ip(target)
            else:
                results = provider.query_domain(target)

            # Track successful state
            self._update_provider_state(target, provider_name, 'success', len(results), None, start_time)

            print(f"✓ {provider_name} returned {len(results)} results for {target}")
            return results

        except Exception as e:
            # Track failed state
            self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
            print(f"✗ {provider_name} failed for {target}: {e}")
            raise

    def _update_provider_state(self, target: str, provider_name: str, status: str,
                               results_count: int, error: str, start_time: datetime) -> None:
        """Update provider state in node metadata for forensic tracking."""
        if not self.graph.graph.has_node(target):  # Fix: Use .graph.has_node()
            return

        node_data = self.graph.graph.nodes[target]
        if 'metadata' not in node_data:
            node_data['metadata'] = {}
        if 'provider_states' not in node_data['metadata']:
            node_data['metadata']['provider_states'] = {}

        node_data['metadata']['provider_states'][provider_name] = {
            'status': status,
            'timestamp': start_time.isoformat(),
            'results_count': results_count,
            'error': error,
            'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
        }

        # Log to forensic trail
        self.logger.logger.info(f"Provider state updated: {target} -> {provider_name} -> {status} ({results_count} results)")

    def _process_provider_results_forensic(self, target: str, provider, results: List,
                                           target_metadata: Dict, current_depth: int) -> Set[str]:
        """Process provider results with large entity detection and forensic logging.

        Returns:
            The set of newly discovered IP/domain targets eligible for further
            recursion. Large entities block recursion and return an empty set.
        """
        provider_name = provider.get_name()
        discovered_targets = set()

        # Check for large entity threshold per provider
        if len(results) > self.config.large_entity_threshold:
            print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}")
            self._create_large_entity(target, provider_name, results, current_depth)
            # Large entities block recursion - return empty set
            return discovered_targets

        # Process each relationship
        dns_records_to_create = {}

        for source, rel_target, rel_type, confidence, raw_data in results:
            if self.stop_event.is_set():
                break

            # Enhanced forensic logging for each relationship
            self.logger.log_relationship_discovery(
                source_node=source,
                target_node=rel_target,
                relationship_type=rel_type.relationship_name,
                confidence_score=confidence,
                provider=provider_name,
                raw_data=raw_data,
                discovery_method=f"{provider_name}_query_depth_{current_depth}"
            )

            # Collect metadata for source node
            self._collect_node_metadata_forensic(source, provider_name, rel_type, rel_target, raw_data, target_metadata[source])

            # Add nodes and edges based on target type
            if _is_valid_ip(rel_target):
                self.graph.add_node(rel_target, NodeType.IP)
                if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added IP relationship: {source} -> {rel_target} ({rel_type.relationship_name})")
                discovered_targets.add(rel_target)

            elif rel_target.startswith('AS') and rel_target[2:].isdigit():
                self.graph.add_node(rel_target, NodeType.ASN)
                if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added ASN relationship: {source} -> {rel_target} ({rel_type.relationship_name})")

            elif _is_valid_domain(rel_target):
                self.graph.add_node(rel_target, NodeType.DOMAIN)
                if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added domain relationship: {source} -> {rel_target} ({rel_type.relationship_name})")
                discovered_targets.add(rel_target)
                # *** NEW: Enrich the newly discovered domain ***
                self._collect_node_metadata_forensic(rel_target, provider_name, rel_type, source, raw_data, target_metadata[rel_target])

            else:
                # Handle DNS record content
                self._handle_dns_record_content(source, rel_target, rel_type, confidence,
                                                raw_data, provider_name, dns_records_to_create)

        # Create DNS record nodes
        self._create_dns_record_nodes(dns_records_to_create)

        return discovered_targets

    def _create_large_entity(self, source: str, provider_name: str, results: List,
                             current_depth: int) -> None:
        """Create a large entity node for forensic tracking."""
        # BUGFIX: use a deterministic CRC32 digest instead of built-in hash(),
        # which is salted per process (PYTHONHASHSEED) and would yield
        # different forensic node IDs on every run.
        source_digest = zlib.crc32(source.encode('utf-8')) & 0x7FFFFFFF
        entity_id = f"large_entity_{provider_name}_{source_digest}"

        # Extract targets from results
        targets = [rel[1] for rel in results if len(rel) > 1]

        # Determine node type
        node_type = 'unknown'
        if targets:
            if _is_valid_domain(targets[0]):
                node_type = 'domain'
            elif _is_valid_ip(targets[0]):
                node_type = 'ip'

        # Create large entity metadata
        metadata = {
            'count': len(targets),
            'nodes': targets,
            'node_type': node_type,
            'source_provider': provider_name,
            'discovery_depth': current_depth,
            'threshold_exceeded': self.config.large_entity_threshold,
            'forensic_note': f'Large entity created due to {len(targets)} results from {provider_name}'
        }

        # Create the node and edge
        self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, metadata=metadata)

        # Use first result's relationship type for the edge
        if results:
            rel_type = results[0][2]
            self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
                                {'large_entity_info': f'Contains {len(targets)} {node_type}s'})

        # Forensic logging
        self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")
        print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}")

    def _collect_node_metadata_forensic(self, node_id: str, provider_name: str,
                                        rel_type: RelationshipType, target: str,
                                        raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        """Collect and organize metadata for forensic tracking with enhanced logging."""
        self.logger.logger.debug(f"Collecting metadata for {node_id} from {provider_name}: {rel_type.relationship_name}")

        if provider_name == 'dns':
            record_type = raw_data.get('query_type', 'UNKNOWN')
            value = raw_data.get('value', target)
            dns_entry = f"{record_type}: {value}"
            if dns_entry not in metadata.get('dns_records', []):
                metadata.setdefault('dns_records', []).append(dns_entry)

        elif provider_name == 'crtsh':
            if rel_type == RelationshipType.SAN_CERTIFICATE:
                domain_certs = raw_data.get('domain_certificates', {})
                if node_id in domain_certs:
                    cert_summary = domain_certs[node_id]
                    metadata['certificate_data'] = cert_summary
                    metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False)
                if target not in metadata.get('related_domains_san', []):
                    metadata.setdefault('related_domains_san', []).append(target)

        elif provider_name == 'shodan':
            # Keep the first non-empty value seen for each Shodan key.
            for key, value in raw_data.items():
                if key not in metadata.get('shodan', {}) or not metadata.get('shodan', {}).get(key):
                    metadata.setdefault('shodan', {})[key] = value
            if rel_type == RelationshipType.ASN_MEMBERSHIP:
                metadata['asn_data'] = {
                    'asn': target,
                    'description': raw_data.get('org', ''),
                    'isp': raw_data.get('isp', ''),
                    'country': raw_data.get('country', '')
                }

    def _handle_dns_record_content(self, source: str, rel_target: str, rel_type: RelationshipType,
                                   confidence: float, raw_data: Dict[str, Any],
                                   provider_name: str, dns_records: Dict) -> None:
        """Handle DNS record content with forensic tracking.

        Deduplicates identical record content across domains by accumulating
        the set of associated domains per record ID in ``dns_records``.
        """
        dns_record_types = [
            RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD,
            RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD,
            RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD,
            RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD,
            RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD
        ]

        if rel_type in dns_record_types:
            record_type = rel_type.relationship_name.upper().replace('_RECORD', '')
            record_content = rel_target.strip()
            # BUGFIX: deterministic CRC32 digest instead of salted built-in
            # hash(), so record node IDs are reproducible across runs.
            content_hash = zlib.crc32(record_content.encode('utf-8')) & 0x7FFFFFFF
            dns_record_id = f"{record_type}:{content_hash}"

            if dns_record_id not in dns_records:
                dns_records[dns_record_id] = {
                    'content': record_content,
                    'type': record_type,
                    'domains': set(),
                    'raw_data': raw_data,
                    'provider_name': provider_name,
                    'confidence': confidence
                }

            dns_records[dns_record_id]['domains'].add(source)

    def _create_dns_record_nodes(self, dns_records: Dict) -> None:
        """Create DNS record nodes with forensic metadata."""
        for dns_record_id, record_info in dns_records.items():
            record_metadata = {
                'record_type': record_info['type'],
                'content': record_info['content'],
                'content_hash': dns_record_id.split(':')[1],
                'associated_domains': list(record_info['domains']),
                'source_data': record_info['raw_data'],
                'forensic_note': f"DNS record created from {record_info['provider_name']} query"
            }

            self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata)

            for domain_name in record_info['domains']:
                self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD,
                                    record_info['confidence'], record_info['provider_name'],
                                    record_info['raw_data'])

            # Forensic logging for DNS record creation
            self.logger.logger.info(f"DNS record node created: {dns_record_id} for {len(record_info['domains'])} domains")

    def _log_target_processing_error(self, target: str, error: str) -> None:
        """Log target processing errors for forensic trail."""
        self.logger.logger.error(f"Target processing failed for {target}: {error}")

    def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
        """Log provider query errors for forensic trail."""
        self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")

    def _log_no_eligible_providers(self, target: str, is_ip: bool) -> None:
        """Log when no providers are eligible for a target."""
        target_type = 'IP' if is_ip else 'domain'
        self.logger.logger.warning(f"No eligible providers for {target_type}: {target}")

    def stop_scan(self) -> bool:
        """Request immediate scan termination with forensic logging.

        Returns:
            True if a termination signal was sent to an active scan,
            False when there was no running scan or an error occurred.
        """
        try:
            if not self.scan_thread or not self.scan_thread.is_alive():
                print("No active scan thread to stop.")
                if self.status == ScanStatus.RUNNING:
                    self.status = ScanStatus.STOPPED
                return False

            print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
            self.logger.logger.info("Scan termination requested by user")
            self.status = ScanStatus.STOPPED
            self.stop_event.set()

            if self.executor:
                print("Shutting down executor with immediate cancellation...")
                self.executor.shutdown(wait=False, cancel_futures=True)

            print("Termination signal sent. The scan thread will stop shortly.")
            return True

        except Exception as e:
            print(f"ERROR: Exception in stop_scan: {e}")
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False

    def get_scan_status(self) -> Dict[str, Any]:
        """Get current scan status with forensic information."""
        try:
            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'total_indicators_found': self.total_indicators_found,
                'indicators_processed': self.indicators_processed,
                'progress_percentage': self._calculate_progress(),
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics()
            }
        except Exception as e:
            print(f"ERROR: Exception in get_scan_status: {e}")
            traceback.print_exc()
            # Return a safe, fully-populated fallback so callers never KeyError.
            return {
                'status': 'error',
                'target_domain': None,
                'current_depth': 0,
                'max_depth': 0,
                'current_indicator': '',
                'total_indicators_found': 0,
                'indicators_processed': 0,
                'progress_percentage': 0.0,
                'enabled_providers': [],
                'graph_statistics': {}
            }

    def _calculate_progress(self) -> float:
        """Calculate scan progress percentage."""
        if self.total_indicators_found == 0:
            return 0.0
        return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100)

    def get_graph_data(self) -> Dict[str, Any]:
        """Get current graph data for visualization."""
        return self.graph.get_graph_data()

    def export_results(self) -> Dict[str, Any]:
        """Export complete scan results with forensic audit trail."""
        graph_data = self.graph.export_json()
        audit_trail = self.logger.export_audit_trail()

        provider_stats = {}
        for provider in self.providers:
            provider_stats[provider.get_name()] = provider.get_statistics()

        export_data = {
            'scan_metadata': {
                'target_domain': self.current_target,
                'max_depth': self.max_depth,
                'final_status': self.status,
                'total_indicators_processed': self.indicators_processed,
                'enabled_providers': list(provider_stats.keys()),
                'forensic_note': 'Refactored scanner with simplified recursion and forensic tracking'
            },
            'graph_data': graph_data,
            'forensic_audit': audit_trail,
            'provider_statistics': provider_stats,
            'scan_summary': self.logger.get_forensic_summary()
        }

        return export_data

    def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Get statistics for all providers with forensic information."""
        stats = {}
        for provider in self.providers:
            stats[provider.get_name()] = provider.get_statistics()
        return stats