# dnsrecon-reduced/core/scanner.py import threading import traceback import time import os import importlib import redis from typing import List, Set, Dict, Any, Tuple, Optional from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future from collections import defaultdict from queue import PriorityQueue from datetime import datetime, timezone from core.graph_manager import GraphManager, NodeType from core.logger import get_forensic_logger, new_session from utils.helpers import _is_valid_ip, _is_valid_domain from providers.base_provider import BaseProvider from core.rate_limiter import GlobalRateLimiter class ScanStatus: """Enumeration of scan statuses.""" IDLE = "idle" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" STOPPED = "stopped" class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. """ def __init__(self, session_config=None): """Initialize scanner with session-specific configuration.""" print("Initializing Scanner instance...") try: # Use provided session config or create default if session_config is None: from core.session_config import create_session_config session_config = create_session_config() self.config = session_config self.graph = GraphManager() self.providers = [] self.status = ScanStatus.IDLE self.current_target = None self.current_depth = 0 self.max_depth = 2 self.stop_event = threading.Event() self.scan_thread = None self.session_id: Optional[str] = None # Will be set by session manager self.task_queue = PriorityQueue() self.target_retries = defaultdict(int) self.scan_failed_due_to_retries = False # **NEW**: Track currently processing tasks to prevent processing after stop self.currently_processing = set() self.processing_lock = threading.Lock() # Scanning progress tracking self.total_indicators_found = 0 self.indicators_processed = 0 self.indicators_completed = 0 self.tasks_re_enqueued = 0 self.total_tasks_ever_enqueued = 0 self.current_indicator = "" # Concurrent processing configuration self.max_workers = self.config.max_concurrent_requests self.executor = None # Initialize providers with session config print("Calling _initialize_providers with session config...") self._initialize_providers() # Initialize logger print("Initializing forensic logger...") self.logger = get_forensic_logger() # Initialize global rate limiter self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) print("Scanner initialization complete") except Exception as e: print(f"ERROR: Scanner initialization failed: {e}") traceback.print_exc() raise def _is_stop_requested(self) -> bool: """ Check if stop is requested using both local and Redis-based signals. This ensures reliable termination across process boundaries. """ # Check local threading event first (fastest) if self.stop_event.is_set(): return True # Check Redis-based stop signal if session ID is available if self.session_id: try: from core.session_manager import session_manager return session_manager.is_stop_requested(self.session_id) except Exception as e: print(f"Error checking Redis stop signal: {e}") # Fall back to local event return self.stop_event.is_set() return False def _set_stop_signal(self) -> None: """ Set stop signal both locally and in Redis. """ # Set local event self.stop_event.set() # Set Redis signal if session ID is available if self.session_id: try: from core.session_manager import session_manager session_manager.set_stop_signal(self.session_id) except Exception as e: print(f"Error setting Redis stop signal: {e}") def __getstate__(self): """Prepare object for pickling by excluding unpicklable attributes.""" state = self.__dict__.copy() # Remove unpicklable threading objects unpicklable_attrs = [ 'stop_event', 'scan_thread', 'executor', 'processing_lock', 'task_queue', 'rate_limiter', 'logger' ] for attr in unpicklable_attrs: if attr in state: del state[attr] # Handle providers separately to ensure they're picklable if 'providers' in state: for provider in state['providers']: if hasattr(provider, '_stop_event'): provider._stop_event = None return state def __setstate__(self, state): """Restore object after unpickling by reconstructing threading objects.""" self.__dict__.update(state) # Reconstruct threading objects self.stop_event = threading.Event() self.scan_thread = None self.executor = None self.processing_lock = threading.Lock() self.task_queue = PriorityQueue() self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) self.logger = get_forensic_logger() if not hasattr(self, 'providers') or not self.providers: print("Providers not found after loading session, re-initializing...") self._initialize_providers() if not hasattr(self, 'currently_processing'): self.currently_processing = set() # Re-set stop events for providers if hasattr(self, 'providers'): for provider in self.providers: if hasattr(provider, 'set_stop_event'): provider.set_stop_event(self.stop_event) def _initialize_providers(self) -> None: """Initialize all available providers based on session configuration.""" self.providers = [] print("Initializing providers with session config...") provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') for filename in os.listdir(provider_dir): if filename.endswith('_provider.py') and not filename.startswith('base'): module_name = f"providers.{filename[:-3]}" try: module = importlib.import_module(module_name) for attribute_name in dir(module): attribute = getattr(module, attribute_name) if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: provider_class = attribute provider = provider_class(name=attribute_name, session_config=self.config) provider_name = provider.get_name() if self.config.is_provider_enabled(provider_name): if provider.is_available(): provider.set_stop_event(self.stop_event) self.providers.append(provider) print(f"✓ {provider.get_display_name()} provider initialized successfully for session") else: print(f"✗ {provider.get_display_name()} provider is not available") except Exception as e: print(f"✗ Failed to initialize provider from {filename}: {e}") traceback.print_exc() print(f"Initialized {len(self.providers)} providers for session") def update_session_config(self, new_config) -> None: """Update session configuration and reinitialize providers.""" print("Updating session configuration...") self.config = new_config self.max_workers = self.config.max_concurrent_requests self._initialize_providers() print("Session configuration updated") def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool: """Start a new reconnaissance scan with proper cleanup of previous scans.""" print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") print(f"Session ID: {self.session_id}") print(f"Initial scanner status: {self.status}") self.total_tasks_ever_enqueued = 0 # **IMPROVED**: More aggressive cleanup of previous scan if self.scan_thread and self.scan_thread.is_alive(): print("A previous scan thread is still alive. Forcing termination...") # Set stop signals immediately self._set_stop_signal() self.status = ScanStatus.STOPPED # Clear all processing state with self.processing_lock: self.currently_processing.clear() self.task_queue = PriorityQueue() # Shutdown executor aggressively if self.executor: print("Shutting down executor forcefully...") self.executor.shutdown(wait=False, cancel_futures=True) self.executor = None # Wait for thread termination with shorter timeout print("Waiting for previous scan thread to terminate...") self.scan_thread.join(5.0) # Reduced from 10 seconds if self.scan_thread.is_alive(): print("WARNING: Previous scan thread is still alive after 5 seconds") # Continue anyway, but log the issue self.logger.logger.warning("Previous scan thread failed to terminate cleanly") # Reset state for new scan with proper forensic logging print("Resetting scanner state for new scan...") self.status = ScanStatus.IDLE self.stop_event.clear() # **NEW**: Clear Redis stop signal explicitly if self.session_id: from core.session_manager import session_manager session_manager.clear_stop_signal(self.session_id) with self.processing_lock: self.currently_processing.clear() self.task_queue = PriorityQueue() self.target_retries.clear() self.scan_failed_due_to_retries = False # Update session state immediately for GUI feedback self._update_session_state() print("Scanner state reset complete.") try: if not hasattr(self, 'providers') or not self.providers: print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan") return False print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}") if clear_graph: self.graph.clear() if force_rescan_target and self.graph.graph.has_node(force_rescan_target): print(f"Forcing rescan of {force_rescan_target}, clearing provider states.") node_data = self.graph.graph.nodes[force_rescan_target] if 'metadata' in node_data and 'provider_states' in node_data['metadata']: node_data['metadata']['provider_states'] = {} self.current_target = target.lower().strip() self.max_depth = max_depth self.current_depth = 0 self.total_indicators_found = 0 self.indicators_processed = 0 self.indicators_completed = 0 self.tasks_re_enqueued = 0 self.current_indicator = self.current_target # Update GUI with scan preparation state self._update_session_state() # Start new forensic session print(f"Starting new forensic session for scanner {id(self)}...") self.logger = new_session() # Start scan in a separate thread print(f"Starting scan thread for scanner {id(self)}...") self.scan_thread = threading.Thread( target=self._execute_scan, args=(self.current_target, max_depth), daemon=True ) self.scan_thread.start() print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") return True except Exception as e: print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") traceback.print_exc() self.status = ScanStatus.FAILED self._update_session_state() return False def _get_priority(self, provider_name): rate_limit = self.config.get_rate_limit(provider_name) if rate_limit > 90: return 1 # Highest priority elif rate_limit > 50: return 2 else: return 3 # Lowest priority def _execute_scan(self, target: str, max_depth: int) -> None: """Execute the reconnaissance scan with proper termination handling.""" print(f"_execute_scan started for {target} with depth {max_depth}") self.executor = ThreadPoolExecutor(max_workers=self.max_workers) processed_tasks = set() # Initial task population for the main target is_ip = _is_valid_ip(target) initial_providers = self._get_eligible_providers(target, is_ip, False) for provider in initial_providers: provider_name = provider.get_name() self.task_queue.put((self._get_priority(provider_name), (provider_name, target, 0))) self.total_tasks_ever_enqueued += 1 try: self.status = ScanStatus.RUNNING self._update_session_state() enabled_providers = [provider.get_name() for provider in self.providers] self.logger.log_scan_start(target, max_depth, enabled_providers) # Determine initial node type node_type = NodeType.IP if is_ip else NodeType.DOMAIN self.graph.add_node(target, node_type) self._initialize_provider_states(target) # Better termination checking in main loop while not self.task_queue.empty() and not self._is_stop_requested(): try: priority, (provider_name, target_item, depth) = self.task_queue.get() except IndexError: # Queue became empty during processing break task_tuple = (provider_name, target_item) if task_tuple in processed_tasks: continue if depth > max_depth: continue if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60): self.task_queue.put((priority + 1, (provider_name, target_item, depth))) # Postpone continue with self.processing_lock: if self._is_stop_requested(): print(f"Stop requested before processing {target_item}") break self.currently_processing.add(target_item) try: self.current_depth = depth self.current_indicator = target_item self._update_session_state() if self._is_stop_requested(): print(f"Stop requested during processing setup for {target_item}") break provider = next((p for p in self.providers if p.get_name() == provider_name), None) if provider: new_targets, large_entity_members, success = self._query_single_provider_for_target(provider, target_item, depth) if self._is_stop_requested(): print(f"Stop requested after querying providers for {target_item}") break if not success: self.target_retries[task_tuple] += 1 if self.target_retries[task_tuple] <= self.config.max_retries_per_target: print(f"Re-queueing task {task_tuple} (attempt {self.target_retries[task_tuple]})") self.task_queue.put((priority, (provider_name, target_item, depth))) self.tasks_re_enqueued += 1 self.total_tasks_ever_enqueued += 1 else: print(f"ERROR: Max retries exceeded for task {task_tuple}") self.scan_failed_due_to_retries = True self._log_target_processing_error(str(task_tuple), "Max retries exceeded") else: processed_tasks.add(task_tuple) self.indicators_completed += 1 if not self._is_stop_requested(): all_new_targets = new_targets.union(large_entity_members) for new_target in all_new_targets: is_ip_new = _is_valid_ip(new_target) eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False) for p_new in eligible_providers_new: p_name_new = p_new.get_name() if (p_name_new, new_target) not in processed_tasks: new_depth = depth + 1 if new_target in new_targets else depth self.task_queue.put((self._get_priority(p_name_new), (p_name_new, new_target, new_depth))) self.total_tasks_ever_enqueued += 1 finally: with self.processing_lock: self.currently_processing.discard(target_item) if self._is_stop_requested(): print("Scan terminated due to stop request") self.logger.logger.info("Scan terminated by user request") elif self.task_queue.empty(): print("Scan completed - no more targets to process") self.logger.logger.info("Scan completed - all targets processed") except Exception as e: print(f"ERROR: Scan execution failed with error: {e}") traceback.print_exc() self.status = ScanStatus.FAILED self.logger.logger.error(f"Scan failed: {e}") finally: with self.processing_lock: self.currently_processing.clear() if self._is_stop_requested(): self.status = ScanStatus.STOPPED elif self.scan_failed_due_to_retries: self.status = ScanStatus.FAILED else: self.status = ScanStatus.COMPLETED self._update_session_state() self.logger.log_scan_complete() if self.executor: self.executor.shutdown(wait=False, cancel_futures=True) self.executor = None stats = self.graph.get_statistics() print("Final scan statistics:") print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}") print(f" - Total edges: {stats['basic_metrics']['total_edges']}") print(f" - Tasks processed: {len(processed_tasks)}") def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: if self._is_stop_requested(): print(f"Stop requested before querying {provider.get_name()} for {target}") return set(), set(), False is_ip = _is_valid_ip(target) target_type = NodeType.IP if is_ip else NodeType.DOMAIN print(f"Querying {provider.get_name()} for {target_type.value}: {target} at depth {depth}") self.graph.add_node(target, target_type) self._initialize_provider_states(target) new_targets = set() large_entity_members = set() node_attributes = defaultdict(lambda: defaultdict(list)) provider_successful = True try: provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth) if provider_results is None: provider_successful = False elif not self._is_stop_requested(): discovered, is_large_entity = self._process_provider_results_forensic( target, provider, provider_results, node_attributes, depth ) if is_large_entity: large_entity_members.update(discovered) else: new_targets.update(discovered) else: print(f"Stop requested after processing results from {provider.get_name()}") except Exception as e: provider_successful = False self._log_provider_error(target, provider.get_name(), str(e)) if not self._is_stop_requested(): for node_id, attributes in node_attributes.items(): if self.graph.graph.has_node(node_id): node_is_ip = _is_valid_ip(node_id) node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN self.graph.add_node(node_id, node_type_to_add, attributes=attributes) return new_targets, large_entity_members, provider_successful def stop_scan(self) -> bool: """Request immediate scan termination with proper cleanup.""" try: print("=== INITIATING IMMEDIATE SCAN TERMINATION ===") self.logger.logger.info("Scan termination requested by user") # **IMPROVED**: More aggressive stop signal setting self._set_stop_signal() self.status = ScanStatus.STOPPED # **NEW**: Clear processing state immediately with self.processing_lock: currently_processing_copy = self.currently_processing.copy() self.currently_processing.clear() print(f"Cleared {len(currently_processing_copy)} currently processing targets: {currently_processing_copy}") # **IMPROVED**: Clear task queue and log what was discarded discarded_tasks = [] while not self.task_queue.empty(): discarded_tasks.append(self.task_queue.get()) self.task_queue = PriorityQueue() print(f"Discarded {len(discarded_tasks)} pending tasks") # **IMPROVED**: Aggressively shut down executor if self.executor: print("Shutting down executor with immediate cancellation...") try: # Cancel all pending futures self.executor.shutdown(wait=False, cancel_futures=True) print("Executor shutdown completed") except Exception as e: print(f"Error during executor shutdown: {e}") # Immediately update GUI with stopped status self._update_session_state() print("Termination signals sent. The scan will stop as soon as possible.") return True except Exception as e: print(f"ERROR: Exception in stop_scan: {e}") self.logger.logger.error(f"Error during scan termination: {e}") traceback.print_exc() return False def _update_session_state(self) -> None: """ Update the scanner state in Redis for GUI updates. This ensures the web interface sees real-time updates. """ if self.session_id: try: from core.session_manager import session_manager success = session_manager.update_session_scanner(self.session_id, self) if not success: print(f"WARNING: Failed to update session state for {self.session_id}") except Exception as e: print(f"ERROR: Failed to update session state: {e}") def get_scan_status(self) -> Dict[str, Any]: """Get current scan status with processing information.""" try: with self.processing_lock: currently_processing_count = len(self.currently_processing) currently_processing_list = list(self.currently_processing) return { 'status': self.status, 'target_domain': self.current_target, 'current_depth': self.current_depth, 'max_depth': self.max_depth, 'current_indicator': self.current_indicator, 'indicators_processed': self.indicators_processed, 'indicators_completed': self.indicators_completed, 'tasks_re_enqueued': self.tasks_re_enqueued, 'progress_percentage': self._calculate_progress(), 'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued, 'enabled_providers': [provider.get_name() for provider in self.providers], 'graph_statistics': self.graph.get_statistics(), 'task_queue_size': self.task_queue.qsize(), 'currently_processing_count': currently_processing_count, 'currently_processing': currently_processing_list[:5] } except Exception as e: print(f"ERROR: Exception in get_scan_status: {e}") traceback.print_exc() return { 'status': 'error', 'target_domain': None, 'current_depth': 0, 'max_depth': 0, 'current_indicator': '', 'indicators_processed': 0, 'indicators_completed': 0, 'tasks_re_enqueued': 0, 'progress_percentage': 0.0, 'enabled_providers': [], 'graph_statistics': {}, 'task_queue_size': 0, 'currently_processing_count': 0, 'currently_processing': [] } def _initialize_provider_states(self, target: str) -> None: """Initialize provider states for forensic tracking.""" if not self.graph.graph.has_node(target): return node_data = self.graph.graph.nodes[target] if 'metadata' not in node_data: node_data['metadata'] = {} if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {} def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List: """Get providers eligible for querying this target.""" if dns_only: return [p for p in self.providers if p.get_name() == 'dns'] eligible = [] target_key = 'ips' if is_ip else 'domains' for provider in self.providers: if provider.get_eligibility().get(target_key): if not self._already_queried_provider(target, provider.get_name()): eligible.append(provider) else: print(f"Skipping {provider.get_name()} for {target} - already queried") return eligible def _already_queried_provider(self, target: str, provider_name: str) -> bool: """Check if we already successfully queried a provider for a target.""" if not self.graph.graph.has_node(target): return False node_data = self.graph.graph.nodes[target] provider_states = node_data.get('metadata', {}).get('provider_states', {}) # A provider has been successfully queried if a state exists and its status is 'success' provider_state = provider_states.get(provider_name) return provider_state is not None and provider_state.get('status') == 'success' def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> Optional[List]: """Query a single provider with stop signal checking.""" provider_name = provider.get_name() start_time = datetime.now(timezone.utc) if self._is_stop_requested(): print(f"Stop requested before querying {provider_name} for {target}") return None print(f"Querying {provider_name} for {target}") self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}") try: if is_ip: results = provider.query_ip(target) else: results = provider.query_domain(target) if self._is_stop_requested(): print(f"Stop requested after querying {provider_name} for {target}") return None self._update_provider_state(target, provider_name, 'success', len(results), None, start_time) print(f"✓ {provider_name} returned {len(results)} results for {target}") return results except Exception as e: self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time) print(f"✗ {provider_name} failed for {target}: {e}") return None def _update_provider_state(self, target: str, provider_name: str, status: str, results_count: int, error: Optional[str], start_time: datetime) -> None: """Update provider state in node metadata for forensic tracking.""" if not self.graph.graph.has_node(target): return node_data = self.graph.graph.nodes[target] if 'metadata' not in node_data: node_data['metadata'] = {} if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {} node_data['metadata']['provider_states'][provider_name] = { 'status': status, 'timestamp': start_time.isoformat(), 'results_count': results_count, 'error': error, 'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000 } self.logger.logger.info(f"Provider state updated: {target} -> {provider_name} -> {status} ({results_count} results)") def _process_provider_results_forensic(self, target: str, provider, results: List, node_attributes: Dict, current_depth: int) -> Tuple[Set[str], bool]: """Process provider results, returns (discovered_targets, is_large_entity).""" provider_name = provider.get_name() discovered_targets = set() if self._is_stop_requested(): print(f"Stop requested before processing results from {provider_name} for {target}") return discovered_targets, False if len(results) > self.config.large_entity_threshold: print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}") members = self._create_large_entity(target, provider_name, results, current_depth) return members, True for i, (source, rel_target, rel_type, confidence, raw_data) in enumerate(results): if i % 5 == 0 and self._is_stop_requested(): # Check more frequently print(f"Stop requested while processing results from {provider_name} for {target}") break self.logger.log_relationship_discovery( source_node=source, target_node=rel_target, relationship_type=rel_type, confidence_score=confidence, provider=provider_name, raw_data=raw_data, discovery_method=f"{provider_name}_query_depth_{current_depth}" ) self._collect_node_attributes(source, provider_name, rel_type, rel_target, raw_data, node_attributes[source]) if isinstance(rel_target, list): # If the target is a list, iterate and process each item for single_target in rel_target: if _is_valid_ip(single_target): self.graph.add_node(single_target, NodeType.IP) if self.graph.add_edge(source, single_target, rel_type, confidence, provider_name, raw_data): print(f"Added IP relationship: {source} -> {single_target} ({rel_type})") discovered_targets.add(single_target) elif _is_valid_domain(single_target): self.graph.add_node(single_target, NodeType.DOMAIN) if self.graph.add_edge(source, single_target, rel_type, confidence, provider_name, raw_data): print(f"Added domain relationship: {source} -> {single_target} ({rel_type})") discovered_targets.add(single_target) self._collect_node_attributes(single_target, provider_name, rel_type, source, raw_data, node_attributes[single_target]) elif _is_valid_ip(rel_target): self.graph.add_node(rel_target, NodeType.IP) if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): print(f"Added IP relationship: {source} -> {rel_target} ({rel_type})") discovered_targets.add(rel_target) elif rel_target.startswith('AS') and rel_target[2:].isdigit(): self.graph.add_node(rel_target, NodeType.ASN) if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): print(f"Added ASN relationship: {source} -> {rel_target} ({rel_type})") elif _is_valid_domain(rel_target): self.graph.add_node(rel_target, NodeType.DOMAIN) if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): print(f"Added domain relationship: {source} -> {rel_target} ({rel_type})") discovered_targets.add(rel_target) self._collect_node_attributes(rel_target, provider_name, rel_type, source, raw_data, node_attributes[rel_target]) else: self._collect_node_attributes(source, provider_name, rel_type, rel_target, raw_data, node_attributes[source]) return discovered_targets, False def _create_large_entity(self, source: str, provider_name: str, results: List, current_depth: int) -> Set[str]: """Create a large entity node and returns the members for DNS processing.""" entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}" targets = [rel[1] for rel in results if len(rel) > 1] node_type = 'unknown' if targets: if _is_valid_domain(targets[0]): node_type = 'domain' elif _is_valid_ip(targets[0]): node_type = 'ip' # We still create the nodes so they exist in the graph, they are just not processed for edges yet. for target in targets: self.graph.add_node(target, NodeType.DOMAIN if node_type == 'domain' else NodeType.IP) attributes = { 'count': len(targets), 'nodes': targets, 'node_type': node_type, 'source_provider': provider_name, 'discovery_depth': current_depth, 'threshold_exceeded': self.config.large_entity_threshold, } description = f'Large entity created due to {len(targets)} results from {provider_name}' self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes, description=description) if results: rel_type = results[0][2] self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, {'large_entity_info': f'Contains {len(targets)} {node_type}s'}) self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}") print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}") return set(targets) def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool: """ Extracts a node from a large entity, re-creates its original edge, and re-queues it for full scanning. """ if not self.graph.graph.has_node(large_entity_id): print(f"ERROR: Large entity {large_entity_id} not found.") return False # 1. Get the original source node that discovered the large entity predecessors = list(self.graph.graph.predecessors(large_entity_id)) if not predecessors: print(f"ERROR: No source node found for large entity {large_entity_id}.") return False source_node_id = predecessors[0] # Get the original edge data to replicate it for the extracted node original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id) if not original_edge_data: print(f"ERROR: Could not find original edge data from {source_node_id} to {large_entity_id}.") return False # 2. Modify the graph data structure first success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract) if not success: print(f"ERROR: Node {node_id_to_extract} could not be removed from {large_entity_id}'s attributes.") return False # 3. Create the direct edge from the original source to the newly extracted node print(f"Re-creating direct edge from {source_node_id} to extracted node {node_id_to_extract}") self.graph.add_edge( source_id=source_node_id, target_id=node_id_to_extract, relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'), confidence_score=original_edge_data.get('confidence_score', 0.85), # Slightly lower confidence source_provider=original_edge_data.get('source_provider', 'unknown'), raw_data={'context': f'Extracted from large entity {large_entity_id}'} ) # 4. Re-queue the extracted node for full processing by all eligible providers print(f"Re-queueing extracted node {node_id_to_extract} for full reconnaissance...") is_ip = _is_valid_ip(node_id_to_extract) current_depth = self.graph.graph.nodes[large_entity_id].get('attributes', {}).get('discovery_depth', 0) eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False) for provider in eligible_providers: provider_name = provider.get_name() self.task_queue.put((self._get_priority(provider_name), (provider_name, node_id_to_extract, current_depth))) self.total_tasks_ever_enqueued += 1 # 5. If the scanner is not running, we need to kickstart it to process this one item. if self.status != ScanStatus.RUNNING: print("Scanner is idle. Starting a mini-scan to process the extracted node.") self.status = ScanStatus.RUNNING self._update_session_state() if not self.scan_thread or not self.scan_thread.is_alive(): self.scan_thread = threading.Thread( target=self._execute_scan, args=(self.current_target, self.max_depth), daemon=True ) self.scan_thread.start() print(f"Successfully extracted and re-queued {node_id_to_extract} from {large_entity_id}.") return True def _collect_node_attributes(self, node_id: str, provider_name: str, rel_type: str, target: str, raw_data: Dict[str, Any], attributes: Dict[str, Any]) -> None: """Collect and organize attributes for a node.""" self.logger.logger.debug(f"Collecting attributes for {node_id} from {provider_name}: {rel_type}") if provider_name == 'dns': record_type = raw_data.get('query_type', 'UNKNOWN') value = raw_data.get('value', target) dns_entry = f"{record_type}: {value}" if dns_entry not in attributes.get('dns_records', []): attributes.setdefault('dns_records', []).append(dns_entry) elif provider_name == 'crtsh': if rel_type == "san_certificate": domain_certs = raw_data.get('domain_certificates', {}) if node_id in domain_certs: cert_summary = domain_certs[node_id] attributes['certificates'] = cert_summary if target not in attributes.get('related_domains_san', []): attributes.setdefault('related_domains_san', []).append(target) elif provider_name == 'shodan': shodan_attributes = attributes.setdefault('shodan', {}) for key, value in raw_data.items(): if key not in shodan_attributes or not shodan_attributes.get(key): shodan_attributes[key] = value if rel_type == "asn_membership": attributes['asn'] = { 'id': target, 'description': raw_data.get('org', ''), 'isp': raw_data.get('isp', ''), 'country': raw_data.get('country', '') } record_type_name = rel_type if record_type_name not in attributes: attributes[record_type_name] = [] if isinstance(target, list): attributes[record_type_name].extend(target) else: if target not in attributes[record_type_name]: attributes[record_type_name].append(target) def _log_target_processing_error(self, target: str, error: str) -> None: """Log target processing errors for forensic trail.""" self.logger.logger.error(f"Target processing failed for {target}: {error}") def _log_provider_error(self, target: str, provider_name: str, error: str) -> None: """Log provider query errors for forensic trail.""" self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}") def _log_no_eligible_providers(self, target: str, is_ip: bool) -> None: """Log when no providers are eligible for a target.""" target_type = 'IP' if is_ip else 'domain' self.logger.logger.warning(f"No eligible providers for {target_type}: {target}") def _calculate_progress(self) -> float: """Calculate scan progress percentage based on task completion.""" if self.total_tasks_ever_enqueued == 0: return 0.0 return min(100.0, (self.indicators_completed / self.total_tasks_ever_enqueued) * 100) def get_graph_data(self) -> Dict[str, Any]: """Get current graph data for visualization.""" return self.graph.get_graph_data() def export_results(self) -> Dict[str, Any]: """Export complete scan results with forensic audit trail.""" graph_data = self.graph.export_json() audit_trail = self.logger.export_audit_trail() provider_stats = {} for provider in self.providers: provider_stats[provider.get_name()] = provider.get_statistics() export_data = { 'scan_metadata': { 'target_domain': self.current_target, 'max_depth': self.max_depth, 'final_status': self.status, 'total_indicators_processed': self.indicators_processed, 'enabled_providers': list(provider_stats.keys()), 'session_id': self.session_id }, 'graph_data': graph_data, 'forensic_audit': audit_trail, 'provider_statistics': provider_stats, 'scan_summary': self.logger.get_forensic_summary() } return export_data def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]: """Get statistics for all providers with forensic information.""" stats = {} for provider in self.providers: stats[provider.get_name()] = provider.get_statistics() return stats def get_provider_info(self) -> Dict[str, Dict[str, Any]]: """Get information about all available providers.""" info = {} provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') for filename in os.listdir(provider_dir): if filename.endswith('_provider.py') and not filename.startswith('base'): module_name = f"providers.{filename[:-3]}" try: module = importlib.import_module(module_name) for attribute_name in dir(module): attribute = getattr(module, attribute_name) if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: provider_class = attribute # Instantiate to get metadata, even if not fully configured temp_provider = provider_class(name=attribute_name, session_config=self.config) provider_name = temp_provider.get_name() # Find the actual provider instance if it exists, to get live stats live_provider = next((p for p in self.providers if p.get_name() == provider_name), None) info[provider_name] = { 'display_name': temp_provider.get_display_name(), 'requires_api_key': temp_provider.requires_api_key(), 'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(), 'enabled': self.config.is_provider_enabled(provider_name), 'rate_limit': self.config.get_rate_limit(provider_name), } except Exception as e: print(f"✗ Failed to get info for provider from {filename}: {e}") traceback.print_exc() return info