# dnsrecon/core/scanner.py import threading import traceback import time import os import importlib from typing import List, Set, Dict, Any, Tuple from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future from collections import defaultdict, deque from datetime import datetime, timezone from core.graph_manager import GraphManager, NodeType from core.logger import get_forensic_logger, new_session from core.task_manager import TaskManager, TaskType, ReconTask from utils.helpers import _is_valid_ip, _is_valid_domain from providers.base_provider import BaseProvider class ScanStatus: """Enumeration of scan statuses.""" IDLE = "idle" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" STOPPED = "stopped" class Scanner: """ Enhanced scanning orchestrator for DNSRecon passive reconnaissance. Now uses task-based completion model with comprehensive retry logic. """ def __init__(self, session_config=None): """Initialize scanner with session-specific configuration and task management.""" print("Initializing Enhanced Scanner instance...") try: # Use provided session config or create default if session_config is None: from core.session_config import create_session_config session_config = create_session_config() self.config = session_config self.graph = GraphManager() self.providers = [] self.status = ScanStatus.IDLE self.current_target = None self.current_depth = 0 self.max_depth = 2 self.stop_event = threading.Event() self.scan_thread = None self.session_id = None # Will be set by session manager self.current_scan_id = None # Track current scan ID # Task-based execution components self.task_manager = None # Will be initialized when needed self.max_workers = self.config.max_concurrent_requests # Enhanced progress tracking self.total_indicators_found = 0 self.indicators_processed = 0 self.current_indicator = "" self.scan_start_time = None self.scan_end_time = None # Initialize providers with session config print("Calling _initialize_providers with session config...") self._initialize_providers() # Initialize logger print("Initializing forensic logger...") self.logger = get_forensic_logger() print("Enhanced Scanner initialization complete") except Exception as e: print(f"ERROR: Enhanced Scanner initialization failed: {e}") traceback.print_exc() raise def __getstate__(self): """Prepare object for pickling by excluding unpicklable attributes.""" state = self.__dict__.copy() # Remove unpicklable threading objects unpicklable_attrs = [ 'stop_event', 'scan_thread', 'task_manager' ] for attr in unpicklable_attrs: if attr in state: del state[attr] # Handle providers separately to ensure they're picklable if 'providers' in state: # The providers should be picklable now, but let's ensure clean state for provider in state['providers']: if hasattr(provider, '_stop_event'): provider._stop_event = None return state def __setstate__(self, state): """Restore object after unpickling by reconstructing threading objects.""" self.__dict__.update(state) # Reconstruct threading objects self.stop_event = threading.Event() self.scan_thread = None self.task_manager = None # Re-set stop events for providers if hasattr(self, 'providers'): for provider in self.providers: if hasattr(provider, 'set_stop_event'): provider.set_stop_event(self.stop_event) def _is_stop_requested(self) -> bool: """ Enhanced stop signal checking that handles both local and Redis-based signals. """ # Check local threading event first (fastest) if self.stop_event.is_set(): return True # Check Redis-based stop signal if session ID is available if self.session_id: try: from core.session_manager import session_manager return session_manager.is_stop_requested(self.session_id) except Exception as e: print(f"Error checking Redis stop signal: {e}") # Fall back to local event return self.stop_event.is_set() return False def _set_stop_signal(self) -> None: """ Set stop signal both locally and in Redis. """ # Set local event self.stop_event.set() # Set Redis signal if session ID is available if self.session_id: try: from core.session_manager import session_manager session_manager.set_stop_signal(self.session_id) except Exception as e: print(f"Error setting Redis stop signal: {e}") def _initialize_providers(self) -> None: """Initialize all available providers based on session configuration.""" self.providers = [] print("Initializing providers with session config...") provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') print(f"Looking for providers in: {provider_dir}") if not os.path.exists(provider_dir): print(f"ERROR: Provider directory does not exist: {provider_dir}") return provider_files = [f for f in os.listdir(provider_dir) if f.endswith('_provider.py') and not f.startswith('base')] print(f"Found provider files: {provider_files}") for filename in provider_files: module_name = f"providers.{filename[:-3]}" print(f"Attempting to load module: {module_name}") try: module = importlib.import_module(module_name) print(f" ✓ Module {module_name} loaded successfully") # Find provider classes in the module provider_classes_found = [] for attribute_name in dir(module): attribute = getattr(module, attribute_name) if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: provider_classes_found.append((attribute_name, attribute)) print(f" Found provider classes: {[name for name, _ in provider_classes_found]}") for class_name, provider_class in provider_classes_found: try: # Create temporary instance to get provider name temp_provider = provider_class(session_config=self.config) provider_name = temp_provider.get_name() print(f" Provider {class_name} -> name: {provider_name}") # Check if enabled in config is_enabled = self.config.is_provider_enabled(provider_name) print(f" Provider {provider_name} enabled: {is_enabled}") if is_enabled: # Check if available (has API keys, etc.) is_available = temp_provider.is_available() print(f" Provider {provider_name} available: {is_available}") if is_available: # Set stop event and add to providers list temp_provider.set_stop_event(self.stop_event) self.providers.append(temp_provider) print(f" ✓ {temp_provider.get_display_name()} provider initialized successfully") else: print(f" - {temp_provider.get_display_name()} provider is not available (missing API key or other requirement)") else: print(f" - {temp_provider.get_display_name()} provider is disabled in config") except Exception as e: print(f" ✗ Failed to initialize provider class {class_name}: {e}") import traceback traceback.print_exc() except Exception as e: print(f" ✗ Failed to load module {module_name}: {e}") import traceback traceback.print_exc() print(f"Total providers initialized: {len(self.providers)}") for provider in self.providers: print(f" - {provider.get_display_name()} ({provider.get_name()})") if len(self.providers) == 0: print("WARNING: No providers were initialized!") elif len(self.providers) == 1 and self.providers[0].get_name() == 'dns': print("WARNING: Only DNS provider initialized - other providers may have failed to load") def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool: """Start a new reconnaissance scan with task-based completion model.""" print(f"=== STARTING ENHANCED SCAN IN SCANNER {id(self)} ===") print(f"Session ID: {self.session_id}") print(f"Initial scanner status: {self.status}") print(f"Clear graph: {clear_graph}") # Generate scan ID based on clear_graph behavior import uuid if clear_graph: # NEW SCAN: Generate new ID and terminate existing scan print("NEW SCAN: Generating new scan ID and terminating existing scan") self.current_scan_id = str(uuid.uuid4())[:8] # Aggressive cleanup of previous scan if self.scan_thread and self.scan_thread.is_alive(): print("Terminating previous scan thread...") self._set_stop_signal() if self.task_manager: self.task_manager.stop_execution() self.scan_thread.join(timeout=8.0) if self.scan_thread.is_alive(): print("WARNING: Previous scan thread did not terminate cleanly") else: # ADD TO GRAPH: Keep existing scan ID if scan is running, or generate new one if self.status == ScanStatus.RUNNING and self.current_scan_id: print(f"ADD TO GRAPH: Keeping existing scan ID {self.current_scan_id}") # Don't terminate existing scan - we're adding to it else: print("ADD TO GRAPH: No active scan, generating new scan ID") self.current_scan_id = str(uuid.uuid4())[:8] print(f"Using scan ID: {self.current_scan_id}") # Reset state for new scan (but preserve graph if clear_graph=False) if clear_graph or self.status != ScanStatus.RUNNING: self.status = ScanStatus.IDLE self._update_session_state() try: if not hasattr(self, 'providers') or not self.providers: print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan") return False print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}") if clear_graph: self.graph.clear() self.current_target = target_domain.lower().strip() self.max_depth = max_depth self.current_depth = 0 # Clear stop signals only if starting new scan if clear_graph or self.status != ScanStatus.RUNNING: self.stop_event.clear() if self.session_id: from core.session_manager import session_manager session_manager.clear_stop_signal(self.session_id) self.total_indicators_found = 0 self.indicators_processed = 0 self.current_indicator = self.current_target self.scan_start_time = datetime.now(timezone.utc) self.scan_end_time = None self._update_session_state() # Initialize forensic session only for new scans if clear_graph: self.logger = new_session() # Start task-based scan thread print(f"Starting task-based scan thread with scan ID {self.current_scan_id}...") self.scan_thread = threading.Thread( target=self._execute_task_based_scan, args=(self.current_target, max_depth, self.current_scan_id), daemon=True ) self.scan_thread.start() print(f"=== ENHANCED SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") return True except Exception as e: print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") traceback.print_exc() self.status = ScanStatus.FAILED self.scan_end_time = datetime.now(timezone.utc) self._update_session_state() return False def _execute_task_based_scan(self, target_domain: str, max_depth: int, scan_id: str) -> None: """Execute the reconnaissance scan using the task-based completion model.""" print(f"_execute_task_based_scan started for {target_domain} with depth {max_depth}, scan ID {scan_id}") try: self.status = ScanStatus.RUNNING self._update_session_state() enabled_providers = [provider.get_name() for provider in self.providers] self.logger.log_scan_start(target_domain, max_depth, enabled_providers) # Initialize task manager self.task_manager = TaskManager( self.providers, self.graph, self.logger, max_concurrent_tasks=self.max_workers ) # Add initial target to graph self.graph.add_node(target_domain, NodeType.DOMAIN) # Start task execution self.task_manager.start_execution(max_workers=self.max_workers) # Track processed targets to avoid duplicates processed_targets = set() # Task queue for breadth-first processing target_queue = deque([(target_domain, 0)]) # (target, depth) while target_queue: # Abort if scan ID changed (new scan started) if self.current_scan_id != scan_id: print(f"Scan aborted - ID mismatch (current: {self.current_scan_id}, expected: {scan_id})") break if self._is_stop_requested(): print("Stop requested, terminating task-based scan.") break target, depth = target_queue.popleft() if target in processed_targets or depth > max_depth: continue self.current_depth = depth self.current_indicator = target self._update_session_state() print(f"Processing target: {target} at depth {depth}") # Create tasks for all eligible providers task_ids = self.task_manager.create_provider_tasks(target, depth, self.providers) if task_ids: print(f"Created {len(task_ids)} tasks for target {target}") self.total_indicators_found += len(task_ids) self._update_session_state() processed_targets.add(target) # Wait for current batch of tasks to complete before processing next depth # This ensures we get all relationships before expanding further timeout_per_batch = 60 # 60 seconds per batch batch_start = time.time() while time.time() - batch_start < timeout_per_batch: if self._is_stop_requested() or self.current_scan_id != scan_id: break progress_report = self.task_manager.get_progress_report() stats = progress_report['statistics'] # Update progress tracking self.indicators_processed = stats['succeeded'] + stats['failed_permanent'] self._update_session_state() # Check if current batch is complete current_batch_complete = ( stats['pending'] == 0 and stats['running'] == 0 and stats['failed_retrying'] == 0 ) if current_batch_complete: print(f"Batch complete for {target}: {stats['succeeded']} succeeded, {stats['failed_permanent']} failed") break time.sleep(1.0) # Check every second # Collect new targets from completed successful tasks if depth < max_depth: new_targets = self._collect_new_targets_from_completed_tasks() for new_target in new_targets: if new_target not in processed_targets: target_queue.append((new_target, depth + 1)) print(f"Added new target for next depth: {new_target}") # Wait for all remaining tasks to complete print("Waiting for all tasks to complete...") final_completion = self.task_manager.wait_for_completion(timeout_seconds=300) if not final_completion: print("WARNING: Some tasks did not complete within timeout") # Final progress update final_report = self.task_manager.get_progress_report() final_stats = final_report['statistics'] print(f"Final task statistics:") print(f" - Total tasks: {final_stats['total_tasks']}") print(f" - Succeeded: {final_stats['succeeded']}") print(f" - Failed permanently: {final_stats['failed_permanent']}") print(f" - Completion rate: {final_stats['completion_rate']:.1f}%") # Determine final scan status if self.current_scan_id == scan_id: if self._is_stop_requested(): self.status = ScanStatus.STOPPED elif final_stats['failed_permanent'] > 0 and final_stats['succeeded'] == 0: self.status = ScanStatus.FAILED elif final_stats['completion_rate'] < 50.0: # Less than 50% success rate self.status = ScanStatus.FAILED else: self.status = ScanStatus.COMPLETED self.scan_end_time = datetime.now(timezone.utc) self._update_session_state() self.logger.log_scan_complete() else: print(f"Scan completed but ID mismatch - not updating final status") except Exception as e: print(f"ERROR: Task-based scan execution failed: {e}") traceback.print_exc() self.status = ScanStatus.FAILED self.scan_end_time = datetime.now(timezone.utc) self.logger.logger.error(f"Task-based scan failed: {e}") finally: # Clean up task manager if self.task_manager: self.task_manager.stop_execution() # Final statistics graph_stats = self.graph.get_statistics() print("Final scan statistics:") print(f" - Total nodes: {graph_stats['basic_metrics']['total_nodes']}") print(f" - Total edges: {graph_stats['basic_metrics']['total_edges']}") print(f" - Targets processed: {len(processed_targets)}") def _collect_new_targets_from_completed_tasks(self) -> Set[str]: """Collect new targets from successfully completed tasks.""" new_targets = set() if not self.task_manager: return new_targets # Get task summaries to find successful tasks task_summaries = self.task_manager.task_queue.get_task_summaries() for task_summary in task_summaries: if task_summary['status'] == 'succeeded': task_id = task_summary['task_id'] task = self.task_manager.task_queue.tasks.get(task_id) if task and task.result and task.result.data: task_new_targets = task.result.data.get('new_targets', []) for target in task_new_targets: if _is_valid_domain(target) or _is_valid_ip(target): new_targets.add(target) return new_targets def _update_session_state(self) -> None: """ Update the scanner state in Redis for GUI updates. """ if self.session_id: try: from core.session_manager import session_manager success = session_manager.update_session_scanner(self.session_id, self) if not success: print(f"WARNING: Failed to update session state for {self.session_id}") except Exception as e: print(f"ERROR: Failed to update session state: {e}") def stop_scan(self) -> bool: """Request immediate scan termination with task manager cleanup.""" try: print("=== INITIATING ENHANCED SCAN TERMINATION ===") self.logger.logger.info("Enhanced scan termination requested by user") # Invalidate current scan ID to prevent stale updates old_scan_id = self.current_scan_id self.current_scan_id = None print(f"Invalidated scan ID {old_scan_id}") # Set stop signals self._set_stop_signal() self.status = ScanStatus.STOPPED self.scan_end_time = datetime.now(timezone.utc) # Immediately update GUI with stopped status self._update_session_state() # Stop task manager if running if self.task_manager: print("Stopping task manager...") self.task_manager.stop_execution() print("Enhanced termination signals sent. The scan will stop as soon as possible.") return True except Exception as e: print(f"ERROR: Exception in enhanced stop_scan: {e}") self.logger.logger.error(f"Error during enhanced scan termination: {e}") traceback.print_exc() return False def get_scan_status(self) -> Dict[str, Any]: """Get current scan status with enhanced task-based information.""" try: status = { 'status': self.status, 'target_domain': self.current_target, 'current_depth': self.current_depth, 'max_depth': self.max_depth, 'current_indicator': self.current_indicator, 'total_indicators_found': self.total_indicators_found, 'indicators_processed': self.indicators_processed, 'progress_percentage': self._calculate_progress(), 'enabled_providers': [provider.get_name() for provider in self.providers], 'graph_statistics': self.graph.get_statistics(), 'scan_duration_seconds': self._calculate_scan_duration(), 'scan_start_time': self.scan_start_time.isoformat() if self.scan_start_time else None, 'scan_end_time': self.scan_end_time.isoformat() if self.scan_end_time else None } # Add task manager statistics if available if self.task_manager: progress_report = self.task_manager.get_progress_report() status['task_statistics'] = progress_report['statistics'] status['task_details'] = { 'is_running': progress_report['is_running'], 'worker_count': progress_report['worker_count'], 'failed_tasks_count': len(progress_report['failed_tasks']) } # Update indicators processed from task statistics task_stats = progress_report['statistics'] status['indicators_processed'] = task_stats['succeeded'] + task_stats['failed_permanent'] # Recalculate progress based on task completion if task_stats['total_tasks'] > 0: task_completion_rate = (task_stats['succeeded'] + task_stats['failed_permanent']) / task_stats['total_tasks'] status['progress_percentage'] = min(100.0, task_completion_rate * 100.0) return status except Exception as e: print(f"ERROR: Exception in get_scan_status: {e}") traceback.print_exc() return { 'status': 'error', 'target_domain': None, 'current_depth': 0, 'max_depth': 0, 'current_indicator': '', 'total_indicators_found': 0, 'indicators_processed': 0, 'progress_percentage': 0.0, 'enabled_providers': [], 'graph_statistics': {}, 'scan_duration_seconds': 0, 'error': str(e) } def _calculate_progress(self) -> float: """Calculate scan progress percentage.""" if self.total_indicators_found == 0: return 0.0 return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100) def _calculate_scan_duration(self) -> float: """Calculate scan duration in seconds.""" if not self.scan_start_time: return 0.0 end_time = self.scan_end_time or datetime.now(timezone.utc) duration = (end_time - self.scan_start_time).total_seconds() return round(duration, 2) def get_graph_data(self) -> Dict[str, Any]: """Get current graph data for visualization.""" return self.graph.get_graph_data() def export_results(self) -> Dict[str, Any]: """Export complete scan results with enhanced task-based audit trail.""" graph_data = self.graph.export_json() audit_trail = self.logger.export_audit_trail() provider_stats = {} for provider in self.providers: provider_stats[provider.get_name()] = provider.get_statistics() export_data = { 'scan_metadata': { 'target_domain': self.current_target, 'max_depth': self.max_depth, 'final_status': self.status, 'total_indicators_processed': self.indicators_processed, 'enabled_providers': list(provider_stats.keys()), 'session_id': self.session_id, 'scan_id': self.current_scan_id, 'scan_duration_seconds': self._calculate_scan_duration(), 'scan_start_time': self.scan_start_time.isoformat() if self.scan_start_time else None, 'scan_end_time': self.scan_end_time.isoformat() if self.scan_end_time else None }, 'graph_data': graph_data, 'forensic_audit': audit_trail, 'provider_statistics': provider_stats, 'scan_summary': self.logger.get_forensic_summary() } # Add task execution details if available if self.task_manager: progress_report = self.task_manager.get_progress_report() export_data['task_execution'] = { 'statistics': progress_report['statistics'], 'failed_tasks': progress_report['failed_tasks'], 'execution_summary': { 'total_tasks_created': progress_report['statistics']['total_tasks'], 'success_rate': progress_report['statistics']['completion_rate'], 'average_retries': self._calculate_average_retries(progress_report) } } return export_data def _calculate_average_retries(self, progress_report: Dict[str, Any]) -> float: """Calculate average retry attempts across all tasks.""" if not self.task_manager or not hasattr(self.task_manager.task_queue, 'tasks'): return 0.0 total_attempts = 0 task_count = 0 for task in self.task_manager.task_queue.tasks.values(): if hasattr(task, 'execution_history'): total_attempts += len(task.execution_history) task_count += 1 return round(total_attempts / task_count, 2) if task_count > 0 else 0.0 def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]: """Get statistics for all providers with enhanced cache information.""" stats = {} for provider in self.providers: provider_stats = provider.get_statistics() # Add cache performance metrics if hasattr(provider, 'cache'): cache_performance = { 'cache_enabled': True, 'cache_directory': provider.cache.cache_dir, 'cache_expiry_hours': provider.cache.cache_expiry / 3600 } provider_stats.update(cache_performance) stats[provider.get_name()] = provider_stats return stats def get_provider_info(self) -> Dict[str, Dict[str, Any]]: """Get information about all available providers with enhanced details.""" info = {} provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') for filename in os.listdir(provider_dir): if filename.endswith('_provider.py') and not filename.startswith('base'): module_name = f"providers.{filename[:-3]}" try: module = importlib.import_module(module_name) for attribute_name in dir(module): attribute = getattr(module, attribute_name) if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: provider_class = attribute # Instantiate to get metadata, even if not fully configured temp_provider = provider_class(session_config=self.config) provider_name = temp_provider.get_name() # Find the actual provider instance if it exists, to get live stats live_provider = next((p for p in self.providers if p.get_name() == provider_name), None) provider_info = { 'display_name': temp_provider.get_display_name(), 'requires_api_key': temp_provider.requires_api_key(), 'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(), 'enabled': self.config.is_provider_enabled(provider_name), 'rate_limit': self.config.get_rate_limit(provider_name), 'eligibility': temp_provider.get_eligibility() } # Add cache information if provider has caching if live_provider and hasattr(live_provider, 'cache'): provider_info['cache_info'] = { 'cache_enabled': True, 'cache_directory': live_provider.cache.cache_dir, 'cache_expiry_hours': live_provider.cache.cache_expiry / 3600 } info[provider_name] = provider_info except Exception as e: print(f"✗ Failed to get info for provider from {filename}: {e}") traceback.print_exc() return info