# dnsrecon-reduced/core/scanner.py import threading import traceback import os import importlib import redis import time import math import random # Imported for jitter from typing import List, Set, Dict, Any, Tuple, Optional from concurrent.futures import ThreadPoolExecutor from collections import defaultdict from queue import PriorityQueue from datetime import datetime, timezone from core.graph_manager import GraphManager, NodeType from core.logger import get_forensic_logger, new_session from core.provider_result import ProviderResult from utils.helpers import _is_valid_ip, _is_valid_domain from utils.export_manager import export_manager from providers.base_provider import BaseProvider from providers.correlation_provider import CorrelationProvider from core.rate_limiter import GlobalRateLimiter class ScanStatus: """Enumeration of scan statuses.""" IDLE = "idle" RUNNING = "running" COMPLETED = "completed" FAILED = "failed" STOPPED = "stopped" class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. UNIFIED: Combines comprehensive features with improved display formatting. 
""" def __init__(self, session_config=None): """Initialize scanner with session-specific configuration.""" try: # Use provided session config or create default if session_config is None: from core.session_config import create_session_config session_config = create_session_config() self.config = session_config self.graph = GraphManager() self.providers = [] self.status = ScanStatus.IDLE self.current_target = None self.current_depth = 0 self.max_depth = 2 self.stop_event = threading.Event() self.scan_thread = None self.session_id: Optional[str] = None # Will be set by session manager self.task_queue = PriorityQueue() self.target_retries = defaultdict(int) self.scan_failed_due_to_retries = False self.initial_targets = set() # Thread-safe processing tracking (from Document 1) self.currently_processing = set() self.processing_lock = threading.Lock() # Display-friendly processing list (from Document 2) self.currently_processing_display = [] # Scanning progress tracking self.total_indicators_found = 0 self.indicators_processed = 0 self.indicators_completed = 0 self.tasks_re_enqueued = 0 self.tasks_skipped = 0 # BUGFIX: Initialize tasks_skipped self.total_tasks_ever_enqueued = 0 self.current_indicator = "" self.last_task_from_queue = None # Concurrent processing configuration self.max_workers = self.config.max_concurrent_requests self.executor = None # Status logger thread with improved formatting self.status_logger_thread = None self.status_logger_stop_event = threading.Event() # Initialize providers with session config self._initialize_providers() # Initialize logger self.logger = get_forensic_logger() # Initialize global rate limiter self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) except Exception as e: print(f"ERROR: Scanner initialization failed: {e}") traceback.print_exc() raise def _is_stop_requested(self) -> bool: """ Check if stop is requested using both local and Redis-based signals. This ensures reliable termination across process boundaries. 
""" if self.stop_event.is_set(): return True if self.session_id: try: from core.session_manager import session_manager return session_manager.is_stop_requested(self.session_id) except Exception as e: # Fall back to local event return self.stop_event.is_set() return self.stop_event.is_set() def _set_stop_signal(self) -> None: """ Set stop signal both locally and in Redis. """ self.stop_event.set() if self.session_id: try: from core.session_manager import session_manager session_manager.set_stop_signal(self.session_id) except Exception as e: pass def __getstate__(self): """Prepare object for pickling by excluding unpicklable attributes.""" state = self.__dict__.copy() unpicklable_attrs = [ 'stop_event', 'scan_thread', 'executor', 'processing_lock', 'task_queue', 'rate_limiter', 'logger', 'status_logger_thread', 'status_logger_stop_event' ] for attr in unpicklable_attrs: if attr in state: del state[attr] if 'providers' in state: for provider in state['providers']: if hasattr(provider, '_stop_event'): provider._stop_event = None return state def __setstate__(self, state): """Restore object after unpickling by reconstructing threading objects.""" self.__dict__.update(state) self.stop_event = threading.Event() self.scan_thread = None self.executor = None self.processing_lock = threading.Lock() self.task_queue = PriorityQueue() self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) self.logger = get_forensic_logger() self.status_logger_thread = None self.status_logger_stop_event = threading.Event() if not hasattr(self, 'providers') or not self.providers: self._initialize_providers() if not hasattr(self, 'currently_processing'): self.currently_processing = set() if not hasattr(self, 'currently_processing_display'): self.currently_processing_display = [] if hasattr(self, 'providers'): for provider in self.providers: if hasattr(provider, 'set_stop_event'): provider.set_stop_event(self.stop_event) def _initialize_providers(self) -> None: """Initialize all available 
providers based on session configuration.""" self.providers = [] provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===") for filename in os.listdir(provider_dir): if filename.endswith('_provider.py') and not filename.startswith('base'): module_name = f"providers.{filename[:-3]}" try: print(f"Loading provider module: {module_name}") module = importlib.import_module(module_name) for attribute_name in dir(module): attribute = getattr(module, attribute_name) if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: provider_class = attribute # FIXED: Pass the 'name' argument during initialization provider = provider_class(name=attribute_name, session_config=self.config) provider_name = provider.get_name() print(f" Provider: {provider_name}") print(f" Class: {provider_class.__name__}") print(f" Config enabled: {self.config.is_provider_enabled(provider_name)}") print(f" Requires API key: {provider.requires_api_key()}") if provider.requires_api_key(): api_key = self.config.get_api_key(provider_name) print(f" API key present: {'Yes' if api_key else 'No'}") if api_key: print(f" API key preview: {api_key[:8]}...") if self.config.is_provider_enabled(provider_name): is_available = provider.is_available() print(f" Available: {is_available}") if is_available: provider.set_stop_event(self.stop_event) if isinstance(provider, CorrelationProvider): provider.set_graph_manager(self.graph) self.providers.append(provider) print(f" ✓ Added to scanner") else: print(f" ✗ Not available - skipped") else: print(f" ✗ Disabled in config - skipped") except Exception as e: print(f" ERROR loading {module_name}: {e}") traceback.print_exc() print(f"=== PROVIDER INITIALIZATION COMPLETE ===") print(f"Active providers: {[p.get_name() for p in self.providers]}") print(f"Provider count: {len(self.providers)}") print("=" * 50) def _status_logger_thread(self): """Periodically 
prints a clean, formatted scan status to the terminal.""" HEADER = "\033[95m" CYAN = "\033[96m" GREEN = "\033[92m" YELLOW = "\033[93m" BLUE = "\033[94m" ENDC = "\033[0m" BOLD = "\033[1m" last_status_str = "" while not self.status_logger_stop_event.is_set(): try: with self.processing_lock: in_flight_tasks = list(self.currently_processing) self.currently_processing_display = in_flight_tasks.copy() status_str = ( f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | " f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | " f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | " f"{GREEN}Completed: {self.indicators_completed}{ENDC} | " f"Skipped: {self.tasks_skipped} | " f"Rescheduled: {self.tasks_re_enqueued}" ) if status_str != last_status_str: print(f"\n{'-'*80}") print(status_str) if self.last_task_from_queue: # Unpack the new time-based queue item _, p, (pn, ti, d) = self.last_task_from_queue print(f"{BLUE}Last task dequeued -> Prio:{p} | Provider:{pn} | Target:'{ti}' | Depth:{d}{ENDC}") if in_flight_tasks: print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}") display_tasks = [f" - {p}: {t}" for p, t in in_flight_tasks[:3]] print("\n".join(display_tasks)) if len(in_flight_tasks) > 3: print(f" ... 
and {len(in_flight_tasks) - 3} more") print(f"{'-'*80}") last_status_str = status_str except Exception: pass time.sleep(2) def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool: if self.scan_thread and self.scan_thread.is_alive(): self.logger.logger.info("Stopping existing scan before starting new one") self._set_stop_signal() self.status = ScanStatus.STOPPED # Clean up processing state with self.processing_lock: self.currently_processing.clear() self.currently_processing_display = [] # Clear task queue while not self.task_queue.empty(): try: self.task_queue.get_nowait() except: break # Shutdown executor if self.executor: try: self.executor.shutdown(wait=False, cancel_futures=True) except: pass finally: self.executor = None # Wait for scan thread to finish (with timeout) self.scan_thread.join(timeout=5.0) if self.scan_thread.is_alive(): self.logger.logger.warning("Previous scan thread did not terminate cleanly") self.status = ScanStatus.IDLE self.stop_event.clear() if self.session_id: from core.session_manager import session_manager session_manager.clear_stop_signal(self.session_id) with self.processing_lock: self.currently_processing.clear() self.currently_processing_display = [] self.task_queue = PriorityQueue() self.target_retries.clear() self.scan_failed_due_to_retries = False self.tasks_skipped = 0 self.last_task_from_queue = None self._update_session_state() try: if not hasattr(self, 'providers') or not self.providers: self.logger.logger.error("No providers available for scanning") return False available_providers = [p for p in self.providers if p.is_available()] if not available_providers: self.logger.logger.error("No providers are currently available/configured") return False if clear_graph: self.graph.clear() self.initial_targets.clear() if force_rescan_target and self.graph.graph.has_node(force_rescan_target): try: node_data = self.graph.graph.nodes[force_rescan_target] if 
'metadata' in node_data and 'provider_states' in node_data['metadata']: node_data['metadata']['provider_states'] = {} self.logger.logger.info(f"Cleared provider states for forced rescan of {force_rescan_target}") except Exception as e: self.logger.logger.warning(f"Error clearing provider states for {force_rescan_target}: {e}") target = target.lower().strip() if not target: self.logger.logger.error("Empty target provided") return False from utils.helpers import is_valid_target if not is_valid_target(target): self.logger.logger.error(f"Invalid target format: {target}") return False self.current_target = target self.initial_targets.add(self.current_target) self.max_depth = max(1, min(5, max_depth)) # Clamp depth between 1-5 self.current_depth = 0 self.total_indicators_found = 0 self.indicators_processed = 0 self.indicators_completed = 0 self.tasks_re_enqueued = 0 self.total_tasks_ever_enqueued = 0 self.current_indicator = self.current_target self._update_session_state() self.logger = new_session() try: self.scan_thread = threading.Thread( target=self._execute_scan, args=(self.current_target, self.max_depth), daemon=True, name=f"ScanThread-{self.session_id or 'default'}" ) self.scan_thread.start() self.status_logger_stop_event.clear() self.status_logger_thread = threading.Thread( target=self._status_logger_thread, daemon=True, name=f"StatusLogger-{self.session_id or 'default'}" ) self.status_logger_thread.start() self.logger.logger.info(f"Scan started successfully for {target} with depth {self.max_depth}") return True except Exception as e: self.logger.logger.error(f"Error starting scan threads: {e}") self.status = ScanStatus.FAILED self._update_session_state() return False except Exception as e: self.logger.logger.error(f"Error in scan startup: {e}") traceback.print_exc() self.status = ScanStatus.FAILED self._update_session_state() return False def _get_priority(self, provider_name): if provider_name == 'correlation': return 100 # Highest priority number = lowest 
priority (runs last) rate_limit = self.config.get_rate_limit(provider_name) # Handle edge cases if rate_limit <= 0: return 90 # Very low priority for invalid/disabled providers if provider_name == 'dns': return 1 # DNS is fastest, should run first elif provider_name == 'shodan': return 3 # Shodan is medium speed, good priority elif provider_name == 'crtsh': return 5 # crt.sh is slower, lower priority else: # For any other providers, use rate limit as a guide if rate_limit >= 100: return 2 # High rate limit = high priority elif rate_limit >= 50: return 4 # Medium-high rate limit = medium-high priority elif rate_limit >= 20: return 6 # Medium rate limit = medium priority elif rate_limit >= 5: return 8 # Low rate limit = low priority else: return 10 # Very low rate limit = very low priority def _execute_scan(self, target: str, max_depth: int) -> None: self.executor = ThreadPoolExecutor(max_workers=self.max_workers) processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping is_ip = _is_valid_ip(target) initial_providers = self._get_eligible_providers(target, is_ip, False) # FIXED: Filter out correlation provider from initial providers initial_providers = [p for p in initial_providers if not isinstance(p, CorrelationProvider)] for provider in initial_providers: provider_name = provider.get_name() priority = self._get_priority(provider_name) self.task_queue.put((time.time(), priority, (provider_name, target, 0))) self.total_tasks_ever_enqueued += 1 try: self.status = ScanStatus.RUNNING self._update_session_state() enabled_providers = [provider.get_name() for provider in self.providers] self.logger.log_scan_start(target, max_depth, enabled_providers) node_type = NodeType.IP if is_ip else NodeType.DOMAIN self.graph.add_node(target, node_type) self._initialize_provider_states(target) consecutive_empty_iterations = 0 max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion # PHASE 1: Run all non-correlation providers 
print(f"\n=== PHASE 1: Running non-correlation providers ===") while not self._is_stop_requested(): queue_empty = self.task_queue.empty() with self.processing_lock: no_active_processing = len(self.currently_processing) == 0 if queue_empty and no_active_processing: consecutive_empty_iterations += 1 if consecutive_empty_iterations >= max_empty_iterations: break # Phase 1 complete time.sleep(0.1) continue else: consecutive_empty_iterations = 0 # Process tasks (same logic as before, but correlations are filtered out) try: run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1) # Skip correlation tasks during Phase 1 if provider_name == 'correlation': continue # Check if task is ready to run current_time = time.time() if run_at > current_time: self.task_queue.put((run_at, priority, (provider_name, target_item, depth))) time.sleep(min(0.5, run_at - current_time)) continue except: # Queue is empty or timeout occurred time.sleep(0.1) continue self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth)) # Skip if already processed task_tuple = (provider_name, target_item, depth) if task_tuple in processed_tasks: self.tasks_skipped += 1 self.indicators_completed += 1 continue # Skip if depth exceeded if depth > max_depth: self.tasks_skipped += 1 self.indicators_completed += 1 continue # Rate limiting with proper time-based deferral if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60): defer_until = time.time() + 60 self.task_queue.put((defer_until, priority, (provider_name, target_item, depth))) self.tasks_re_enqueued += 1 continue # Thread-safe processing state management with self.processing_lock: if self._is_stop_requested(): break processing_key = (provider_name, target_item) if processing_key in self.currently_processing: self.tasks_skipped += 1 self.indicators_completed += 1 continue self.currently_processing.add(processing_key) try: self.current_depth = depth 
self.current_indicator = target_item self._update_session_state() if self._is_stop_requested(): break provider = next((p for p in self.providers if p.get_name() == provider_name), None) if provider and not isinstance(provider, CorrelationProvider): new_targets, _, success = self._process_provider_task(provider, target_item, depth) if self._is_stop_requested(): break if not success: retry_key = (provider_name, target_item, depth) self.target_retries[retry_key] += 1 if self.target_retries[retry_key] <= self.config.max_retries_per_target: retry_count = self.target_retries[retry_key] backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1)) retry_at = time.time() + backoff_delay self.task_queue.put((retry_at, priority, (provider_name, target_item, depth))) self.tasks_re_enqueued += 1 self.logger.logger.debug(f"Retrying {provider_name}:{target_item} in {backoff_delay:.1f}s (attempt {retry_count})") else: self.scan_failed_due_to_retries = True self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded") else: processed_tasks.add(task_tuple) self.indicators_completed += 1 # Enqueue new targets with proper depth tracking if not self._is_stop_requested(): for new_target in new_targets: is_ip_new = _is_valid_ip(new_target) eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False) # FIXED: Filter out correlation providers in Phase 1 eligible_providers_new = [p for p in eligible_providers_new if not isinstance(p, CorrelationProvider)] for p_new in eligible_providers_new: p_name_new = p_new.get_name() new_depth = depth + 1 new_task_tuple = (p_name_new, new_target, new_depth) if new_task_tuple not in processed_tasks and new_depth <= max_depth: new_priority = self._get_priority(p_name_new) self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth))) self.total_tasks_ever_enqueued += 1 else: self.logger.logger.warning(f"Provider {provider_name} not found in active 
providers") self.tasks_skipped += 1 self.indicators_completed += 1 finally: with self.processing_lock: processing_key = (provider_name, target_item) self.currently_processing.discard(processing_key) # PHASE 2: Run correlations on all discovered nodes if not self._is_stop_requested(): print(f"\n=== PHASE 2: Running correlation analysis ===") self._run_correlation_phase(max_depth, processed_tasks) except Exception as e: traceback.print_exc() self.status = ScanStatus.FAILED self.logger.logger.error(f"Scan failed: {e}") finally: # Comprehensive cleanup (same as before) with self.processing_lock: self.currently_processing.clear() self.currently_processing_display = [] while not self.task_queue.empty(): try: self.task_queue.get_nowait() except: break if self._is_stop_requested(): self.status = ScanStatus.STOPPED elif self.scan_failed_due_to_retries: self.status = ScanStatus.FAILED else: self.status = ScanStatus.COMPLETED self.status_logger_stop_event.set() if self.status_logger_thread and self.status_logger_thread.is_alive(): self.status_logger_thread.join(timeout=2.0) self._update_session_state() self.logger.log_scan_complete() if self.executor: try: self.executor.shutdown(wait=False, cancel_futures=True) except Exception as e: self.logger.logger.warning(f"Error shutting down executor: {e}") finally: self.executor = None def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None: """ PHASE 2: Run correlation analysis on all discovered nodes. This ensures correlations run after all other providers have completed. 
""" correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None) if not correlation_provider: print("No correlation provider found - skipping correlation phase") return # Get all nodes from the graph for correlation analysis all_nodes = list(self.graph.graph.nodes()) correlation_tasks = [] print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes") for node_id in all_nodes: if self._is_stop_requested(): break # Determine appropriate depth for correlation (use 0 for simplicity) correlation_depth = 0 task_tuple = ('correlation', node_id, correlation_depth) # Don't re-process already processed correlation tasks if task_tuple not in processed_tasks: priority = self._get_priority('correlation') self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth))) correlation_tasks.append(task_tuple) self.total_tasks_ever_enqueued += 1 print(f"Enqueued {len(correlation_tasks)} correlation tasks") # Process correlation tasks consecutive_empty_iterations = 0 max_empty_iterations = 20 # Shorter timeout for correlation phase while not self._is_stop_requested() and correlation_tasks: queue_empty = self.task_queue.empty() with self.processing_lock: no_active_processing = len(self.currently_processing) == 0 if queue_empty and no_active_processing: consecutive_empty_iterations += 1 if consecutive_empty_iterations >= max_empty_iterations: break time.sleep(0.1) continue else: consecutive_empty_iterations = 0 try: run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1) # Only process correlation tasks in this phase if provider_name != 'correlation': continue except: time.sleep(0.1) continue task_tuple = (provider_name, target_item, depth) # Skip if already processed if task_tuple in processed_tasks: self.tasks_skipped += 1 self.indicators_completed += 1 if task_tuple in correlation_tasks: correlation_tasks.remove(task_tuple) continue with self.processing_lock: if 
self._is_stop_requested(): break processing_key = (provider_name, target_item) if processing_key in self.currently_processing: self.tasks_skipped += 1 self.indicators_completed += 1 continue self.currently_processing.add(processing_key) try: self.current_indicator = target_item self._update_session_state() if self._is_stop_requested(): break # Process correlation task new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth) if success: processed_tasks.add(task_tuple) self.indicators_completed += 1 if task_tuple in correlation_tasks: correlation_tasks.remove(task_tuple) else: # For correlations, don't retry - just mark as completed self.indicators_completed += 1 if task_tuple in correlation_tasks: correlation_tasks.remove(task_tuple) finally: with self.processing_lock: processing_key = (provider_name, target_item) self.currently_processing.discard(processing_key) print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}") def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: """ Manages the entire process for a given target and provider. This version is generalized to handle all relationships dynamically. 
""" if self._is_stop_requested(): return set(), set(), False is_ip = _is_valid_ip(target) target_type = NodeType.IP if is_ip else NodeType.DOMAIN self.graph.add_node(target, target_type) self._initialize_provider_states(target) new_targets = set() provider_successful = True try: provider_result = self._execute_provider_query(provider, target, is_ip) if provider_result is None: provider_successful = False elif not self._is_stop_requested(): # Pass all relationships to be processed discovered, is_large_entity = self._process_provider_result_unified( target, provider, provider_result, depth ) new_targets.update(discovered) except Exception as e: provider_successful = False self._log_provider_error(target, provider.get_name(), str(e)) return new_targets, set(), provider_successful def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]: """ The "worker" function that directly communicates with the provider to fetch data. """ provider_name = provider.get_name() start_time = datetime.now(timezone.utc) if self._is_stop_requested(): return None try: if is_ip: result = provider.query_ip(target) else: result = provider.query_domain(target) if self._is_stop_requested(): return None relationship_count = result.get_relationship_count() if result else 0 self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time) return result except Exception as e: self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time) return None def _create_large_entity_from_result(self, source_node: str, provider_name: str, provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]: """ Creates a large entity node, tags all member nodes, and returns its ID and members. 
""" members = {rel.target_node for rel in provider_result.relationships if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)} if not members: return "", set() large_entity_id = f"le_{provider_name}_{source_node}" self.graph.add_node( node_id=large_entity_id, node_type=NodeType.LARGE_ENTITY, attributes=[ {"name": "count", "value": len(members), "type": "statistic"}, {"name": "source_provider", "value": provider_name, "type": "metadata"}, {"name": "discovery_depth", "value": depth, "type": "metadata"}, {"name": "nodes", "value": list(members), "type": "metadata"} ], description=f"A collection of {len(members)} nodes discovered from {source_node} via {provider_name}." ) for member_id in members: node_type = NodeType.IP if _is_valid_ip(member_id) else NodeType.DOMAIN self.graph.add_node( node_id=member_id, node_type=node_type, metadata={'large_entity_id': large_entity_id} ) return large_entity_id, members def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool: """ Removes a node from a large entity, allowing it to be processed normally. 
""" if not self.graph.graph.has_node(node_id): return False node_data = self.graph.graph.nodes[node_id] metadata = node_data.get('metadata', {}) if metadata.get('large_entity_id') == large_entity_id: # Remove the large entity tag del metadata['large_entity_id'] self.graph.add_node(node_id, NodeType(node_data['type']), metadata=metadata) # Re-enqueue the node for full processing is_ip = _is_valid_ip(node_id) eligible_providers = self._get_eligible_providers(node_id, is_ip, False) for provider in eligible_providers: provider_name = provider.get_name() priority = self._get_priority(provider_name) # Use current depth of the large entity if available, else 0 depth = 0 if self.graph.graph.has_node(large_entity_id): le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', []) depth_attr = next((a for a in le_attrs if a['name'] == 'discovery_depth'), None) if depth_attr: depth = depth_attr['value'] self.task_queue.put((time.time(), priority, (provider_name, node_id, depth))) self.total_tasks_ever_enqueued += 1 return True return False def _process_provider_result_unified(self, target: str, provider: BaseProvider, provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]: """ Process a unified ProviderResult object to update the graph. This version dynamically re-routes edges to a large entity container. 
""" provider_name = provider.get_name() discovered_targets = set() large_entity_id = "" large_entity_members = set() if self._is_stop_requested(): return discovered_targets, False eligible_rel_count = sum( 1 for rel in provider_result.relationships if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node) ) is_large_entity = eligible_rel_count > self.config.large_entity_threshold if is_large_entity: large_entity_id, large_entity_members = self._create_large_entity_from_result( target, provider_name, provider_result, current_depth ) for i, relationship in enumerate(provider_result.relationships): if i % 5 == 0 and self._is_stop_requested(): break source_node_id = relationship.source_node target_node_id = relationship.target_node # Determine visual source and target, substituting with large entity ID if necessary visual_source = large_entity_id if source_node_id in large_entity_members else source_node_id visual_target = large_entity_id if target_node_id in large_entity_members else target_node_id # Prevent self-loops on the large entity node if visual_source == visual_target: continue # Determine node types for the actual nodes source_type = NodeType.IP if _is_valid_ip(source_node_id) else NodeType.DOMAIN if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp': target_type = NodeType.ISP elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer': target_type = NodeType.CA elif provider_name == 'correlation': target_type = NodeType.CORRELATION_OBJECT elif _is_valid_ip(target_node_id): target_type = NodeType.IP else: target_type = NodeType.DOMAIN max_depth_reached = current_depth >= self.max_depth # Add actual nodes to the graph (they might be hidden by the UI) self.graph.add_node(source_node_id, source_type) self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}) # Add the visual edge to the graph self.graph.add_edge( visual_source, visual_target, 
relationship.relationship_type, relationship.confidence, provider_name, relationship.raw_data ) if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached: if target_node_id not in large_entity_members: discovered_targets.add(target_node_id) if large_entity_members: self.logger.logger.info(f"Enqueuing DNS and Correlation for {len(large_entity_members)} members of {large_entity_id}") for member in large_entity_members: for provider_name_to_run in ['dns', 'correlation']: p_instance = next((p for p in self.providers if p.get_name() == provider_name_to_run), None) if p_instance and p_instance.get_eligibility().get('domains' if _is_valid_domain(member) else 'ips'): priority = self._get_priority(provider_name_to_run) self.task_queue.put((time.time(), priority, (provider_name_to_run, member, current_depth))) self.total_tasks_ever_enqueued += 1 attributes_by_node = defaultdict(list) for attribute in provider_result.attributes: attr_dict = { "name": attribute.name, "value": attribute.value, "type": attribute.type, "provider": attribute.provider, "confidence": attribute.confidence, "metadata": attribute.metadata } attributes_by_node[attribute.target_node].append(attr_dict) for node_id, node_attributes_list in attributes_by_node.items(): if not self.graph.graph.has_node(node_id): node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN self.graph.add_node(node_id, node_type, attributes=node_attributes_list) else: existing_attrs = self.graph.graph.nodes[node_id].get('attributes', []) self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list return discovered_targets, is_large_entity

    def stop_scan(self) -> bool:
        """
        Request immediate termination of the running scan.

        Raises the stop signal, marks the scanner STOPPED, forgets all
        in-flight targets, discards every queued task, shuts the worker pool
        down without waiting, and finally pushes the new state to the session
        manager.

        Returns:
            True if the termination sequence ran to completion; False if any
            step raised (the error is logged and the traceback printed).
        """
        try:
            self.logger.logger.info("Scan termination requested by user")
            # Defined elsewhere in this class; presumably sets both the local
            # and the Redis-backed flags that _is_stop_requested() checks.
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            # Clear in-flight bookkeeping under the same lock the readers use.
            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []

            # Discard all pending tasks by swapping in a fresh, empty queue
            # instead of draining the old one.
            self.task_queue = PriorityQueue()

            if self.executor:
                try:
                    # wait=False returns immediately; cancel_futures (Python 3.9+)
                    # cancels futures that have not started. Already-running
                    # tasks are not interrupted by this call.
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    # Best-effort: a shutdown failure must not abort the stop.
                    pass

            self._update_session_state()
            return True
        except Exception as e:
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False

    def _update_session_state(self) -> None:
        """
        Push this scanner object to the session manager so status readers
        (e.g. the GUI) see the latest state.

        Does nothing when no session_id has been assigned. Best-effort: any
        exception, including a failed import of the session manager, is
        silently ignored.
        """
        if self.session_id:
            try:
                # Imported lazily, presumably to avoid a circular import with
                # core.session_manager — TODO confirm.
                from core.session_manager import session_manager
                session_manager.update_session_scanner(self.session_id, self)
            except Exception:
                pass

    def get_scan_status(self) -> Dict[str, Any]:
        """
        Return a snapshot of scan state and progress for status reporting.

        Only the currently-processing set is read under processing_lock; the
        other counters are read without locking, so while workers are running
        the values may be slightly inconsistent with one another.

        Returns:
            Dict of scan state, counters, progress percentage, enabled provider
            names, graph statistics, queue size and at most five in-flight
            targets (taken from a set, so which five is arbitrary). Several keys
            are aliases reported side by side: task_queue_size / tasks_in_queue,
            indicators_completed / tasks_completed, and
            tasks_re_enqueued / tasks_rescheduled. On any error the traceback is
            printed and {'status': 'error', 'message': ...} is returned instead.
        """
        try:
            with self.processing_lock:
                currently_processing_count = len(self.currently_processing)
                currently_processing_list = list(self.currently_processing)

            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'indicators_processed': self.indicators_processed,
                'indicators_completed': self.indicators_completed,
                'tasks_re_enqueued': self.tasks_re_enqueued,
                'progress_percentage': self._calculate_progress(),
                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics(),
                # PriorityQueue.qsize() is only approximate under concurrency.
                'task_queue_size': self.task_queue.qsize(),
                'currently_processing_count': currently_processing_count,
                'currently_processing': currently_processing_list[:5],
                'tasks_in_queue': self.task_queue.qsize(),
                'tasks_completed': self.indicators_completed,
                'tasks_skipped': self.tasks_skipped,
                'tasks_rescheduled': self.tasks_re_enqueued,
            }
        except Exception:
            traceback.print_exc()
            return { 'status': 'error', 'message': 'Failed to get status' }

    def _initialize_provider_states(self, target: str) -> None:
        """
        Ensure the target node carries an (initially empty)
        ``metadata['provider_states']`` dict.

        Does nothing if the target is not in the graph. Existing metadata and
        provider states are left untouched. Errors are logged as warnings and
        not re-raised.
        """
        try:
            if not self.graph.graph.has_node(target):
                return

            node_data = self.graph.graph.nodes[target]
            if 'metadata' not in node_data:
                node_data['metadata'] = {}
            if 'provider_states' not in node_data['metadata']:
                node_data['metadata']['provider_states'] = {}
        except Exception as e:
            self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")

    def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
        """
        Select the providers that should be run against ``target``.

        Args:
            target: Domain or IP being processed.
            is_ip: Selects the 'ips' (True) or 'domains' (False) eligibility key.
            dns_only: If true, return just the provider(s) named 'dns'. This
                shortcut skips the eligibility, availability and
                already-queried checks below.

        Returns:
            Providers that are eligible for this target type, report themselves
            available, and have not already produced results for this target
            (see _already_queried_provider). If the target node belongs to a
            large entity (its metadata has 'large_entity_id'), only the 'dns'
            and 'correlation' providers are considered.
        """
        if dns_only:
            return [p for p in self.providers if p.get_name() == 'dns']

        eligible = []
        target_key = 'ips' if is_ip else 'domains'

        # Large-entity members are restricted to the dns/correlation providers.
        is_in_large_entity = False
        if self.graph.graph.has_node(target):
            metadata = self.graph.graph.nodes[target].get('metadata', {})
            if 'large_entity_id' in metadata:
                is_in_large_entity = True

        for provider in self.providers:
            try:
                # Large-entity restriction.
                if is_in_large_entity and provider.get_name() not in ['dns', 'correlation']:
                    continue

                # Provider must declare support for this target type.
                if not provider.get_eligibility().get(target_key, False):
                    continue

                # Provider must be available/configured.
                if not provider.is_available():
                    continue

                # Skip providers that already returned results for this target.
                if not self._already_queried_provider(target, provider.get_name()):
                    eligible.append(provider)

            except Exception as e:
                # A failing provider check excludes only that provider.
                self.logger.logger.warning(f"Error checking provider eligibility {provider.get_name()}: {e}")
                continue

        return eligible

    def _already_queried_provider(self, target: str, provider_name: str) -> bool:
        """
        Report whether ``provider_name`` already produced results for ``target``.

        Returns True only when the node's recorded provider state has status
        'success' AND a results_count greater than zero. A missing node, a
        missing state, an error status, or a successful run with zero results
        all return False, so the provider stays eligible to run again. Errors
        while reading the state are logged as warnings and yield False.
        """
        try:
            if not self.graph.graph.has_node(target):
                return False

            node_data = self.graph.graph.nodes[target]
            provider_states = node_data.get('metadata', {}).get('provider_states', {})
            provider_state = provider_states.get(provider_name)

            # Success alone is not enough: the run must also have found something.
            return (provider_state is not None and provider_state.get('status') == 'success' and provider_state.get('results_count', 0) > 0)
        except Exception as e:
            self.logger.logger.warning(f"Error checking provider state for {target}:{provider_name}: {e}")
            return False

    def _update_provider_state(self, target: str, provider_name: str, status: str, results_count: int, error: Optional[str], start_time: datetime) -> None:
        """
        Record the outcome of one provider run on the target node.

        Writes ``metadata['provider_states'][provider_name]`` as a dict with
        status, timestamp (start_time in ISO format), results_count (clamped
        to >= 0), error (stringified, or None when falsy) and duration_ms
        (time from start_time until now in UTC). Also stamps
        ``self.last_modified`` with the current UTC time.

        Args:
            target: Graph node the provider ran against; if it is absent a
                warning is logged and nothing is written.
            provider_name: Key under which the state is stored.
            status: Outcome label; _already_queried_provider treats 'success'
                specially.
            results_count: Number of results produced.
            error: Error description, or None.
            start_time: When the run began. Should be timezone-aware; if the
                subtraction from an aware UTC "now" fails (e.g. a naive
                datetime), duration_ms falls back to 0.

        Errors are logged and not re-raised.
        """
        try:
            if not self.graph.graph.has_node(target):
                self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
                return

            node_data = self.graph.graph.nodes[target]
            if 'metadata' not in node_data:
                node_data['metadata'] = {}
            if 'provider_states' not in node_data['metadata']:
                node_data['metadata']['provider_states'] = {}

            # Duration is informational; any failure computing it yields 0.
            try:
                duration_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
            except Exception:
                duration_ms = 0

            node_data['metadata']['provider_states'][provider_name] = {
                'status': status,
                'timestamp': start_time.isoformat(),
                'results_count': max(0, results_count),  # never record a negative count
                'error': str(error) if error else None,
                'duration_ms': duration_ms
            }

            # Record when scanner state was last modified.
            self.last_modified = datetime.now(timezone.utc).isoformat()

        except Exception as e:
            self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")

    def _log_target_processing_error(self, target: str, error: str) -> None:
        """Log, at error level, that processing of ``target`` failed."""
        self.logger.logger.error(f"Target processing failed for {target}: {error}")

    def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
        """Log, at error level, that ``provider_name`` failed for ``target``."""
        self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")

    def _calculate_progress(self) -> float:
        """
        Estimate scan completion as a percentage clamped to [0.0, 100.0].

        The denominator is the larger of total_tasks_ever_enqueued and
        (completed + currently queued + currently in flight). This keeps the
        value from reaching 100% while work is still queued or running.

        Returns:
            0.0 before any task has been enqueued, or if the computation raises
            (logged as a warning); otherwise completed / denominator * 100.
        """
        try:
            if self.total_tasks_ever_enqueued == 0:
                return 0.0

            # Count work that is still outstanding so progress is not overstated.
            queue_size = max(0, self.task_queue.qsize())
            with self.processing_lock:
                active_tasks = len(self.currently_processing)

            adjusted_total = max(self.total_tasks_ever_enqueued,
                                 self.indicators_completed + queue_size + active_tasks)

            # Defensive guard: after the early return above this can only be hit
            # if total_tasks_ever_enqueued were negative.
            if adjusted_total == 0:
                return 100.0

            progress = (self.indicators_completed / adjusted_total) * 100
            return max(0.0, min(100.0, progress))  # clamp to [0, 100]
        except Exception as e:
            self.logger.logger.warning(f"Error calculating progress: {e}")
            return 0.0

    def get_graph_data(self) -> Dict[str, Any]:
        """
        Return the graph manager's graph data, extended with an
        'initial_targets' key listing the targets the scan started from.
        """
        graph_data = self.graph.get_graph_data()
        graph_data['initial_targets'] = list(self.initial_targets)
        return graph_data

    def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
        """
        Describe every provider class found in the providers package.

        Lists ``<this file's directory>/../providers`` and, for each file that
        ends in ``_provider.py`` and does not start with ``base``, imports
        ``providers.<module>``. Every BaseProvider subclass reachable via
        dir(module), excluding BaseProvider itself, is instantiated temporarily
        with the session config. Statistics come from the live provider in
        self.providers with the same name when one exists, otherwise from the
        temporary instance.

        Returns:
            Mapping of provider name -> {'display_name', 'requires_api_key',
            'statistics', 'enabled', 'rate_limit'}.

        Notes:
            dir(module) also yields BaseProvider subclasses that a module merely
            imports; they are instantiated too. Entries are keyed by get_name(),
            so a later class with the same name overwrites an earlier one.
            os.listdir order is arbitrary, so discovery order is not fixed.
            Any exception while handling a module prints its traceback and
            skips the rest of that module.
        """
        info = {}
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
        for filename in os.listdir(provider_dir):
            if filename.endswith('_provider.py') and not filename.startswith('base'):
                module_name = f"providers.{filename[:-3]}"  # strip the ".py" suffix
                try:
                    module = importlib.import_module(module_name)
                    for attribute_name in dir(module):
                        attribute = getattr(module, attribute_name)
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                            provider_class = attribute
                            # Throwaway instance, used only to read its metadata.
                            temp_provider = provider_class(name=attribute_name, session_config=self.config)
                            provider_name = temp_provider.get_name()

                            # Prefer the running provider's statistics when it exists.
                            live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)

                            info[provider_name] = {
                                'display_name': temp_provider.get_display_name(),
                                'requires_api_key': temp_provider.requires_api_key(),
                                'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(),
                                'enabled': self.config.is_provider_enabled(provider_name),
                                'rate_limit': self.config.get_rate_limit(provider_name),
                            }
                except Exception:
                    traceback.print_exc()
        return info