# dnscope/core/scanner.py
import threading
import traceback
import os
import importlib
import redis
import time
import random # Imported for jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from queue import PriorityQueue, Empty
from datetime import datetime, timezone
from core.graph_manager import GraphManager, NodeType
from core.logger import get_forensic_logger, new_session
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_ip, _is_valid_domain
from utils.export_manager import export_manager
from providers.base_provider import BaseProvider
from providers.correlation_provider import CorrelationProvider
from core.rate_limiter import GlobalRateLimiter
class ScanStatus:
"""Enumeration of scan statuses."""
IDLE = "idle"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
STOPPED = "stopped"
class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
UNIFIED: Combines comprehensive features with improved display formatting.
FIXED: Enhanced threading object initialization to prevent None references.
"""
def __init__(self, session_config=None, socketio=None):
"""Initialize scanner with session-specific configuration."""
try:
# Use provided session config or create default
if session_config is None:
from core.session_config import create_session_config
session_config = create_session_config()
# FIXED: Initialize all threading objects first
self._initialize_threading_objects()
# Set socketio (but will be set to None for storage)
self.socketio = socketio
self.config = session_config
self.graph = GraphManager()
self.providers = []
self.status = ScanStatus.IDLE
self.current_target = None
self.current_depth = 0
self.max_depth = 2
self.scan_thread = None
self.session_id: Optional[str] = None # Will be set by session manager
self.initial_targets = set()
# Thread-safe processing tracking (from Document 1)
self.currently_processing = set()
# Display-friendly processing list (from Document 2)
self.currently_processing_display = []
# Scanning progress tracking
self.total_indicators_found = 0
self.indicators_processed = 0
self.indicators_completed = 0
self.tasks_re_enqueued = 0
self.tasks_skipped = 0 # BUGFIX: Initialize tasks_skipped
self.total_tasks_ever_enqueued = 0
self.current_indicator = ""
self.last_task_from_queue = None
# Concurrent processing configuration
self.max_workers = self.config.max_concurrent_requests
self.executor = None
# Initialize collections that will be recreated during unpickling
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
# Initialize providers with session config
self._initialize_providers()
# Initialize logger
self.logger = get_forensic_logger()
# Initialize global rate limiter
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
except Exception as e:
print(f"ERROR: Scanner initialization failed: {e}")
traceback.print_exc()
raise
def _initialize_threading_objects(self):
"""
FIXED: Initialize all threading objects with proper error handling.
This method can be called during both __init__ and __setstate__.
"""
self.stop_event = threading.Event()
self.processing_lock = threading.Lock()
self.status_logger_stop_event = threading.Event()
self.status_logger_thread = None
def _is_stop_requested(self) -> bool:
"""
Check if stop is requested using both local and Redis-based signals.
This ensures reliable termination across process boundaries.
FIXED: Added None check for stop_event.
"""
# FIXED: Ensure stop_event exists before checking
if hasattr(self, 'stop_event') and self.stop_event and self.stop_event.is_set():
return True
if self.session_id:
try:
from core.session_manager import session_manager
return session_manager.is_stop_requested(self.session_id)
            except Exception:
# Fall back to local event if it exists
if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
# Final fallback
if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
def _set_stop_signal(self) -> None:
"""
Set stop signal both locally and in Redis.
FIXED: Added None check for stop_event.
"""
# FIXED: Ensure stop_event exists before setting
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.set()
if self.session_id:
try:
from core.session_manager import session_manager
session_manager.set_stop_signal(self.session_id)
            except Exception:
pass
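    # Pickle support: the session manager serializes Scanner instances into
    # Redis, so anything unpicklable (threads, locks, the executor, the Redis
    # client, the socketio handle) is stripped on dump and rebuilt on load.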
def __getstate__(self):
"""Prepare object for pickling by excluding unpicklable attributes."""
state = self.__dict__.copy()
unpicklable_attrs = [
'stop_event',
'scan_thread',
'executor',
'processing_lock',
'task_queue',
'rate_limiter',
'logger',
'status_logger_thread',
'status_logger_stop_event',
'socketio'
]
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
if 'providers' in state:
for provider in state['providers']:
if hasattr(provider, '_stop_event'):
provider._stop_event = None
return state
def __setstate__(self, state):
"""Restore object after unpickling by reconstructing threading objects."""
self.__dict__.update(state)
# FIXED: Ensure all threading objects are properly initialized
self._initialize_threading_objects()
# Re-initialize other objects
self.scan_thread = None
self.executor = None
self.task_queue = PriorityQueue()
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
self.logger = get_forensic_logger()
# FIXED: Initialize socketio as None but preserve ability to set it
if not hasattr(self, 'socketio'):
self.socketio = None
# Initialize missing attributes with defaults
if not hasattr(self, 'providers') or not self.providers:
self._initialize_providers()
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if not hasattr(self, 'currently_processing_display'):
self.currently_processing_display = []
if not hasattr(self, 'target_retries'):
self.target_retries = defaultdict(int)
if not hasattr(self, 'scan_failed_due_to_retries'):
self.scan_failed_due_to_retries = False
if not hasattr(self, 'initial_targets'):
self.initial_targets = set()
# Ensure providers have stop events
if hasattr(self, 'providers'):
for provider in self.providers:
if hasattr(provider, 'set_stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event)
def _ensure_threading_objects_exist(self):
"""
FIXED: Utility method to ensure threading objects exist before use.
Call this before any method that might use threading objects.
"""
if not hasattr(self, 'stop_event') or self.stop_event is None:
print("WARNING: Threading objects not initialized, recreating...")
self._initialize_threading_objects()
if not hasattr(self, 'processing_lock') or self.processing_lock is None:
self.processing_lock = threading.Lock()
if not hasattr(self, 'task_queue') or self.task_queue is None:
self.task_queue = PriorityQueue()
def _initialize_providers(self) -> None:
"""Initialize all available providers based on session configuration."""
self.providers = []
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}"
try:
print(f"Loading provider module: {module_name}")
module = importlib.import_module(module_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
provider_class = attribute
# FIXED: Pass the 'name' argument during initialization
provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = provider.get_name()
print(f" Provider: {provider_name}")
print(f" Class: {provider_class.__name__}")
print(f" Config enabled: {self.config.is_provider_enabled(provider_name)}")
print(f" Requires API key: {provider.requires_api_key()}")
if provider.requires_api_key():
api_key = self.config.get_api_key(provider_name)
print(f" API key present: {'Yes' if api_key else 'No'}")
if api_key:
print(f" API key preview: {api_key[:8]}...")
if self.config.is_provider_enabled(provider_name):
is_available = provider.is_available()
print(f" Available: {is_available}")
if is_available:
# FIXED: Ensure stop_event exists before setting it
if hasattr(self, 'stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event)
if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph)
self.providers.append(provider)
print(f" ✓ Added to scanner")
else:
print(f" ✗ Not available - skipped")
else:
print(f" ✗ Disabled in config - skipped")
except Exception as e:
print(f" ERROR loading {module_name}: {e}")
traceback.print_exc()
print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
print(f"Active providers: {[p.get_name() for p in self.providers]}")
print(f"Provider count: {len(self.providers)}")
print("=" * 50)
def _status_logger_thread(self):
"""Periodically prints a clean, formatted scan status to the terminal."""
HEADER = "\033[95m"
CYAN = "\033[96m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
ENDC = "\033[0m"
BOLD = "\033[1m"
last_status_str = ""
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
while not (hasattr(self, 'status_logger_stop_event') and
self.status_logger_stop_event and
self.status_logger_stop_event.is_set()):
try:
# FIXED: Check if processing_lock exists before using
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
self.currently_processing_display = in_flight_tasks.copy()
else:
in_flight_tasks = list(getattr(self, 'currently_processing', []))
status_str = (
f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
f"{CYAN}Queued: {self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0}{ENDC} | "
f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
f"Skipped: {self.tasks_skipped} | "
f"Rescheduled: {self.tasks_re_enqueued}"
)
if status_str != last_status_str:
print(f"\n{'-'*80}")
print(status_str)
if self.last_task_from_queue:
# Unpack the new time-based queue item
_, p, (pn, ti, d) = self.last_task_from_queue
print(f"{BLUE}Last task dequeued -> Prio:{p} | Provider:{pn} | Target:'{ti}' | Depth:{d}{ENDC}")
if in_flight_tasks:
print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}")
display_tasks = [f" - {p}: {t}" for p, t in in_flight_tasks[:3]]
print("\n".join(display_tasks))
if len(in_flight_tasks) > 3:
print(f" ... and {len(in_flight_tasks) - 3} more")
print(f"{'-'*80}")
last_status_str = status_str
except Exception:
pass
time.sleep(2)
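    # Example status line as printed above (ANSI colors omitted):
    #   Scan Status: RUNNING | Queued: 12 | In-Flight: 3 | Completed: 41 | Skipped: 2 | Rescheduled: 1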
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
"""
FIXED: Enhanced start_scan with proper threading object initialization and socketio management.
"""
# FIXED: Ensure threading objects exist before proceeding
self._ensure_threading_objects_exist()
if self.scan_thread and self.scan_thread.is_alive():
self.logger.logger.info("Stopping existing scan before starting new one")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# Clean up processing state
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
# Clear task queue
if hasattr(self, 'task_queue') and self.task_queue:
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
                    except Empty:
break
# Shutdown executor
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
pass
finally:
self.executor = None
# Wait for scan thread to finish (with timeout)
self.scan_thread.join(timeout=5.0)
if self.scan_thread.is_alive():
self.logger.logger.warning("Previous scan thread did not terminate cleanly")
self.status = ScanStatus.IDLE
# FIXED: Ensure stop_event exists before clearing
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.clear()
if self.session_id:
from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id)
# FIXED: Restore socketio connection if missing
if not hasattr(self, 'socketio') or not self.socketio:
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"✓ Restored socketio connection for scan start")
# FIXED: Safe cleanup with existence checks
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
self.target_retries.clear()
self.scan_failed_due_to_retries = False
self.tasks_skipped = 0
self.last_task_from_queue = None
self._update_session_state()
try:
if not hasattr(self, 'providers') or not self.providers:
self.logger.logger.error("No providers available for scanning")
return False
available_providers = [p for p in self.providers if p.is_available()]
if not available_providers:
self.logger.logger.error("No providers are currently available/configured")
return False
if clear_graph:
self.graph.clear()
self.initial_targets.clear()
if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
try:
node_data = self.graph.graph.nodes[force_rescan_target]
if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
self.logger.logger.info(f"Cleared provider states for forced rescan of {force_rescan_target}")
except Exception as e:
self.logger.logger.warning(f"Error clearing provider states for {force_rescan_target}: {e}")
target = target.lower().strip()
if not target:
self.logger.logger.error("Empty target provided")
return False
from utils.helpers import is_valid_target
if not is_valid_target(target):
self.logger.logger.error(f"Invalid target format: {target}")
return False
self.current_target = target
self.initial_targets.add(self.current_target)
self.max_depth = max(1, min(5, max_depth)) # Clamp depth between 1-5
self.current_depth = 0
self.total_indicators_found = 0
self.indicators_processed = 0
self.indicators_completed = 0
self.tasks_re_enqueued = 0
self.total_tasks_ever_enqueued = 0
self.current_indicator = self.current_target
self._update_session_state()
self.logger = new_session()
try:
self.scan_thread = threading.Thread(
target=self._execute_scan,
args=(self.current_target, self.max_depth),
daemon=True,
name=f"ScanThread-{self.session_id or 'default'}"
)
self.scan_thread.start()
# FIXED: Ensure status_logger_stop_event exists before clearing
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.clear()
self.status_logger_thread = threading.Thread(
target=self._status_logger_thread,
daemon=True,
name=f"StatusLogger-{self.session_id or 'default'}"
)
self.status_logger_thread.start()
self.logger.logger.info(f"Scan started successfully for {target} with depth {self.max_depth}")
return True
except Exception as e:
self.logger.logger.error(f"Error starting scan threads: {e}")
self.status = ScanStatus.FAILED
self._update_session_state()
return False
except Exception as e:
self.logger.logger.error(f"Error in scan startup: {e}")
traceback.print_exc()
self.status = ScanStatus.FAILED
self._update_session_state()
return False
def _get_priority(self, provider_name):
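        """Map a provider to a queue priority (lower = sooner among tasks due at the same time)."""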
if provider_name == 'correlation':
return 100 # Highest priority number = lowest priority (runs last)
rate_limit = self.config.get_rate_limit(provider_name)
# Handle edge cases
if rate_limit <= 0:
return 90 # Very low priority for invalid/disabled providers
if provider_name == 'dns':
return 1 # DNS is fastest, should run first
elif provider_name == 'shodan':
return 3 # Shodan is medium speed, good priority
elif provider_name == 'crtsh':
return 5 # crt.sh is slower, lower priority
else:
# For any other providers, use rate limit as a guide
if rate_limit >= 100:
return 2 # High rate limit = high priority
elif rate_limit >= 50:
return 4 # Medium-high rate limit = medium-high priority
elif rate_limit >= 20:
return 6 # Medium rate limit = medium priority
elif rate_limit >= 5:
return 8 # Low rate limit = low priority
else:
return 10 # Very low rate limit = very low priority
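    # Queue items are (run_at, priority, (provider_name, target, depth))
    # tuples, so the PriorityQueue orders primarily by the run_at timestamp
    # and breaks ties with the priority above. A task deferred for rate
    # limiting, e.g. (time.time() + 60, 3, ('shodan', 'example.com', 1)),
    # sorts behind everything that is already due.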
def _execute_scan(self, target: str, max_depth: int) -> None:
"""
FIXED: Enhanced execute_scan with proper threading object handling.
"""
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
update_counter = 0 # Track updates for throttling
last_update_time = time.time()
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping
is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False)
# FIXED: Filter out correlation provider from initial providers
initial_providers = [p for p in initial_providers if not isinstance(p, CorrelationProvider)]
for provider in initial_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
self.task_queue.put((time.time(), priority, (provider_name, target, 0)))
self.total_tasks_ever_enqueued += 1
try:
self.status = ScanStatus.RUNNING
self._update_session_state()
enabled_providers = [provider.get_name() for provider in self.providers]
self.logger.log_scan_start(target, max_depth, enabled_providers)
node_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, node_type)
self._initialize_provider_states(target)
consecutive_empty_iterations = 0
max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion
# PHASE 1: Run all non-correlation providers
print(f"\n=== PHASE 1: Running non-correlation providers ===")
while not self._is_stop_requested():
queue_empty = self.task_queue.empty()
# FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
break # Phase 1 complete
time.sleep(0.1)
continue
else:
consecutive_empty_iterations = 0
# Process tasks (same logic as before, but correlations are filtered out)
try:
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
# Skip correlation tasks during Phase 1
if provider_name == 'correlation':
continue
# Check if task is ready to run
current_time = time.time()
if run_at > current_time:
self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
time.sleep(min(0.5, run_at - current_time))
continue
                except Empty:  # queue empty or get() timed out
time.sleep(0.1)
continue
self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
# Skip if already processed
task_tuple = (provider_name, target_item, depth)
if task_tuple in processed_tasks:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
# Skip if depth exceeded
if depth > max_depth:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
# Rate limiting with proper time-based deferral
if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
defer_until = time.time() + 60
self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1
continue
# Thread-safe processing state management
processing_key = (provider_name, target_item)
# FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested():
break
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
try:
self.current_depth = depth
self.current_indicator = target_item
self._update_session_state()
if self._is_stop_requested():
break
provider = next((p for p in self.providers if p.get_name() == provider_name), None)
if provider and not isinstance(provider, CorrelationProvider):
new_targets, _, success = self._process_provider_task(provider, target_item, depth)
update_counter += 1
current_time = time.time()
if (update_counter % 5 == 0) or (current_time - last_update_time > 3.0):
self._update_session_state()
last_update_time = current_time
update_counter = 0
if self._is_stop_requested():
break
if not success:
retry_key = (provider_name, target_item, depth)
self.target_retries[retry_key] += 1
if self.target_retries[retry_key] <= self.config.max_retries_per_target:
retry_count = self.target_retries[retry_key]
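                                # Exponential backoff with jitter, capped at
                                # five minutes: 2**attempt seconds plus up to 1s.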
backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
retry_at = time.time() + backoff_delay
self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1
self.logger.logger.debug(f"Retrying {provider_name}:{target_item} in {backoff_delay:.1f}s (attempt {retry_count})")
else:
self.scan_failed_due_to_retries = True
self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded")
else:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
# Enqueue new targets with proper depth tracking
if not self._is_stop_requested():
for new_target in new_targets:
is_ip_new = _is_valid_ip(new_target)
eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
# FIXED: Filter out correlation providers in Phase 1
eligible_providers_new = [p for p in eligible_providers_new if not isinstance(p, CorrelationProvider)]
for p_new in eligible_providers_new:
p_name_new = p_new.get_name()
new_depth = depth + 1
new_task_tuple = (p_name_new, new_target, new_depth)
if new_task_tuple not in processed_tasks and new_depth <= max_depth:
new_priority = self._get_priority(p_name_new)
self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
self.total_tasks_ever_enqueued += 1
else:
self.logger.logger.warning(f"Provider {provider_name} not found in active providers")
self.tasks_skipped += 1
self.indicators_completed += 1
finally:
# FIXED: Safe processing lock usage for cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
# PHASE 2: Run correlations on all discovered nodes
if not self._is_stop_requested():
print(f"\n=== PHASE 2: Running correlation analysis ===")
self._run_correlation_phase(max_depth, processed_tasks)
except Exception as e:
traceback.print_exc()
self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}")
finally:
# Comprehensive cleanup (same as before)
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
                except Empty:
break
if self._is_stop_requested():
self.status = ScanStatus.STOPPED
elif self.scan_failed_due_to_retries:
self.status = ScanStatus.FAILED
else:
self.status = ScanStatus.COMPLETED
# FIXED: Safe stop event handling
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0)
self._update_session_state()
self.logger.log_scan_complete()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception as e:
self.logger.logger.warning(f"Error shutting down executor: {e}")
finally:
self.executor = None
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
"""
PHASE 2: Run correlation analysis on all discovered nodes.
This ensures correlations run after all other providers have completed.
"""
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
if not correlation_provider:
print("No correlation provider found - skipping correlation phase")
return
# Get all nodes from the graph for correlation analysis
all_nodes = list(self.graph.graph.nodes())
correlation_tasks = []
print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
for node_id in all_nodes:
if self._is_stop_requested():
break
# Determine appropriate depth for correlation (use 0 for simplicity)
correlation_depth = 0
task_tuple = ('correlation', node_id, correlation_depth)
# Don't re-process already processed correlation tasks
if task_tuple not in processed_tasks:
priority = self._get_priority('correlation')
self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
correlation_tasks.append(task_tuple)
self.total_tasks_ever_enqueued += 1
print(f"Enqueued {len(correlation_tasks)} correlation tasks")
# Process correlation tasks
consecutive_empty_iterations = 0
max_empty_iterations = 20 # Shorter timeout for correlation phase
while not self._is_stop_requested() and correlation_tasks:
queue_empty = self.task_queue.empty()
# FIXED: Safe processing check
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
break
time.sleep(0.1)
continue
else:
consecutive_empty_iterations = 0
try:
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
# Only process correlation tasks in this phase
if provider_name != 'correlation':
continue
            except Empty:
time.sleep(0.1)
continue
task_tuple = (provider_name, target_item, depth)
# Skip if already processed
if task_tuple in processed_tasks:
self.tasks_skipped += 1
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
continue
processing_key = (provider_name, target_item)
# FIXED: Safe processing management
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested():
break
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
try:
self.current_indicator = target_item
self._update_session_state()
if self._is_stop_requested():
break
# Process correlation task
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
if success:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
else:
# For correlations, don't retry - just mark as completed
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
finally:
# FIXED: Safe cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# FIXED: Safe cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
self._update_session_state()
return True
except Exception as e:
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive graph data for real-time updates."""
try:
# FIXED: Safe processing state access
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
else:
currently_processing_count = len(getattr(self, 'currently_processing', []))
currently_processing_list = list(getattr(self, 'currently_processing', []))
# FIXED: Always include complete graph data for real-time updates
graph_data = self.get_graph_data()
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph': graph_data, # FIXED: Always include complete graph data
'task_queue_size': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception as e:
traceback.print_exc()
return {
'status': 'error',
'message': 'Failed to get status',
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
}
def _update_session_state(self) -> None:
"""
FIXED: Update the scanner state in Redis and emit real-time WebSocket updates.
Enhanced with better error handling and socketio management.
"""
if self.session_id:
try:
# Get current status for WebSocket emission
current_status = self.get_scan_status()
# FIXED: Emit real-time update via WebSocket with better error handling
socketio_available = False
if hasattr(self, 'socketio') and self.socketio:
try:
print(f"📡 EMITTING WebSocket update: {current_status.get('status', 'unknown')} - "
f"Nodes: {len(current_status.get('graph', {}).get('nodes', []))}, "
f"Edges: {len(current_status.get('graph', {}).get('edges', []))}")
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted successfully")
socketio_available = True
except Exception as ws_error:
print(f"⚠️ WebSocket emission failed: {ws_error}")
# Try to get socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print("🔄 Attempting to use registered socketio connection...")
registered_socketio.emit('scan_update', current_status)
self.socketio = registered_socketio # Update our reference
print("✅ WebSocket update emitted via registered connection")
socketio_available = True
else:
print("⚠️ No registered socketio connection found")
except Exception as fallback_error:
print(f"⚠️ Fallback socketio emission also failed: {fallback_error}")
else:
# Try to restore socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print(f"🔄 Restoring socketio connection for session {self.session_id}")
self.socketio = registered_socketio
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted via restored connection")
socketio_available = True
else:
print(f"⚠️ No socketio connection available for session {self.session_id}")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio connection: {restore_error}")
if not socketio_available:
print(f"⚠️ Real-time updates unavailable for session {self.session_id}")
# Update session in Redis for persistence (always do this)
try:
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception as redis_error:
print(f"⚠️ Failed to update session in Redis: {redis_error}")
except Exception as e:
print(f"⚠️ Error updating session state: {e}")
                traceback.print_exc()
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
"""
FIXED: Manages the entire process for a given target and provider with enhanced real-time updates.
"""
if self._is_stop_requested():
return set(), set(), False
is_ip = _is_valid_ip(target)
target_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, target_type)
self._initialize_provider_states(target)
new_targets = set()
provider_successful = True
try:
provider_result = self._execute_provider_query(provider, target, is_ip)
if provider_result is None:
provider_successful = False
elif not self._is_stop_requested():
discovered, is_large_entity = self._process_provider_result_unified(
target, provider, provider_result, depth
)
new_targets.update(discovered)
# FIXED: Emit real-time update after processing provider result
if discovered or provider_result.get_relationship_count() > 0:
# Ensure we have socketio connection for real-time updates
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio connection during provider processing")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio during provider processing: {restore_error}")
self._update_session_state()
except Exception as e:
provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e))
return new_targets, set(), provider_successful
def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
"""The "worker" function that directly communicates with the provider to fetch data."""
provider_name = provider.get_name()
start_time = datetime.now(timezone.utc)
if self._is_stop_requested():
return None
try:
if is_ip:
result = provider.query_ip(target)
else:
result = provider.query_domain(target)
if self._is_stop_requested():
return None
relationship_count = result.get_relationship_count() if result else 0
self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time)
return result
except Exception as e:
self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
return None
def _create_large_entity_from_result(self, source_node: str, provider_name: str,
provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
"""Creates a large entity node, tags all member nodes, and returns its ID and members."""
members = {rel.target_node for rel in provider_result.relationships
if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}
if not members:
return "", set()
large_entity_id = f"le_{provider_name}_{source_node}"
self.graph.add_node(
node_id=large_entity_id,
node_type=NodeType.LARGE_ENTITY,
attributes=[
{"name": "count", "value": len(members), "type": "statistic"},
{"name": "source_provider", "value": provider_name, "type": "metadata"},
{"name": "discovery_depth", "value": depth, "type": "metadata"},
{"name": "nodes", "value": list(members), "type": "metadata"}
],
description=f"A collection of {len(members)} nodes discovered from {source_node} via {provider_name}."
)
for member_id in members:
node_type = NodeType.IP if _is_valid_ip(member_id) else NodeType.DOMAIN
self.graph.add_node(
node_id=member_id,
node_type=node_type,
metadata={'large_entity_id': large_entity_id}
)
return large_entity_id, members
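    # Large entities collapse high fan-out results (e.g. a certificate with
    # hundreds of SANs) into one graph node; members keep a 'large_entity_id'
    # metadata tag so extract_node_from_large_entity() can pull them back out
    # for full processing later.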
def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
"""
FIXED: Removes a node from a large entity with immediate real-time update.
"""
if not self.graph.graph.has_node(node_id):
return False
node_data = self.graph.graph.nodes[node_id]
metadata = node_data.get('metadata', {})
if metadata.get('large_entity_id') == large_entity_id:
# Remove the large entity tag
del metadata['large_entity_id']
self.graph.add_node(node_id, NodeType(node_data['type']), metadata=metadata)
# Re-enqueue the node for full processing
is_ip = _is_valid_ip(node_id)
eligible_providers = self._get_eligible_providers(node_id, is_ip, False)
for provider in eligible_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
depth = 0
if self.graph.graph.has_node(large_entity_id):
le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
depth_attr = next((a for a in le_attrs if a['name'] == 'discovery_depth'), None)
if depth_attr:
depth = depth_attr['value']
self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
self.total_tasks_ever_enqueued += 1
# FIXED: Emit real-time update after extraction with socketio management
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for node extraction update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for extraction: {restore_error}")
self._update_session_state()
return True
return False
def _process_provider_result_unified(self, target: str, provider: BaseProvider,
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
"""
FIXED: Process a unified ProviderResult object to update the graph with enhanced real-time updates.
"""
provider_name = provider.get_name()
discovered_targets = set()
large_entity_id = ""
large_entity_members = set()
if self._is_stop_requested():
return discovered_targets, False
eligible_rel_count = sum(
1 for rel in provider_result.relationships if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)
)
is_large_entity = eligible_rel_count > self.config.large_entity_threshold
if is_large_entity:
large_entity_id, large_entity_members = self._create_large_entity_from_result(
target, provider_name, provider_result, current_depth
)
# Track if we added anything significant
nodes_added = 0
edges_added = 0
for i, relationship in enumerate(provider_result.relationships):
if i % 5 == 0 and self._is_stop_requested():
break
source_node_id = relationship.source_node
target_node_id = relationship.target_node
# Determine visual source and target, substituting with large entity ID if necessary
visual_source = large_entity_id if source_node_id in large_entity_members else source_node_id
visual_target = large_entity_id if target_node_id in large_entity_members else target_node_id
# Prevent self-loops on the large entity node
if visual_source == visual_target:
continue
# Determine node types for the actual nodes
source_type = NodeType.IP if _is_valid_ip(source_node_id) else NodeType.DOMAIN
if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
target_type = NodeType.ISP
elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
target_type = NodeType.CA
elif provider_name == 'correlation':
target_type = NodeType.CORRELATION_OBJECT
elif _is_valid_ip(target_node_id):
target_type = NodeType.IP
else:
target_type = NodeType.DOMAIN
max_depth_reached = current_depth >= self.max_depth
# Add actual nodes to the graph (they might be hidden by the UI)
if self.graph.add_node(source_node_id, source_type):
nodes_added += 1
if self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}):
nodes_added += 1
# Add the visual edge to the graph
if self.graph.add_edge(
visual_source, visual_target,
relationship.relationship_type,
relationship.confidence,
provider_name,
relationship.raw_data
):
edges_added += 1
if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
if target_node_id not in large_entity_members:
discovered_targets.add(target_node_id)
if large_entity_members:
self.logger.logger.info(f"Enqueuing DNS and Correlation for {len(large_entity_members)} members of {large_entity_id}")
for member in large_entity_members:
for provider_name_to_run in ['dns', 'correlation']:
p_instance = next((p for p in self.providers if p.get_name() == provider_name_to_run), None)
if p_instance and p_instance.get_eligibility().get('domains' if _is_valid_domain(member) else 'ips'):
priority = self._get_priority(provider_name_to_run)
self.task_queue.put((time.time(), priority, (provider_name_to_run, member, current_depth)))
self.total_tasks_ever_enqueued += 1
attributes_by_node = defaultdict(list)
for attribute in provider_result.attributes:
attr_dict = {
"name": attribute.name, "value": attribute.value, "type": attribute.type,
"provider": attribute.provider, "confidence": attribute.confidence, "metadata": attribute.metadata
}
attributes_by_node[attribute.target_node].append(attr_dict)
for node_id, node_attributes_list in attributes_by_node.items():
if not self.graph.graph.has_node(node_id):
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
nodes_added += 1
else:
existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list
# FIXED: Emit real-time update if we added anything significant
if nodes_added > 0 or edges_added > 0:
print(f"🔄 Added {nodes_added} nodes, {edges_added} edges - triggering real-time update")
# Ensure we have socketio connection for immediate update
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for immediate update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for immediate update: {restore_error}")
self._update_session_state()
return discovered_targets, is_large_entity
def _initialize_provider_states(self, target: str) -> None:
"""FIXED: Safer provider state initialization with error handling."""
try:
if not self.graph.graph.has_node(target):
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
except Exception as e:
self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")
def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
"""FIXED: Improved provider eligibility checking with better filtering."""
if dns_only:
return [p for p in self.providers if p.get_name() == 'dns']
eligible = []
target_key = 'ips' if is_ip else 'domains'
# Check if the target is part of a large entity
is_in_large_entity = False
if self.graph.graph.has_node(target):
metadata = self.graph.graph.nodes[target].get('metadata', {})
if 'large_entity_id' in metadata:
is_in_large_entity = True
for provider in self.providers:
try:
# If in large entity, only allow dns and correlation providers
if is_in_large_entity and provider.get_name() not in ['dns', 'correlation']:
continue
# Check if provider supports this target type
if not provider.get_eligibility().get(target_key, False):
continue
# Check if provider is available/configured
if not provider.is_available():
continue
# Check if we already successfully queried this provider
if not self._already_queried_provider(target, provider.get_name()):
eligible.append(provider)
except Exception as e:
self.logger.logger.warning(f"Error checking provider eligibility {provider.get_name()}: {e}")
continue
return eligible
def _already_queried_provider(self, target: str, provider_name: str) -> bool:
"""FIXED: More robust check for already queried providers with proper error handling."""
try:
if not self.graph.graph.has_node(target):
return False
node_data = self.graph.graph.nodes[target]
provider_states = node_data.get('metadata', {}).get('provider_states', {})
provider_state = provider_states.get(provider_name)
# Only consider it already queried if it was successful
return (provider_state is not None and
provider_state.get('status') == 'success' and
provider_state.get('results_count', 0) > 0)
except Exception as e:
self.logger.logger.warning(f"Error checking provider state for {target}:{provider_name}: {e}")
return False
def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: Optional[str], start_time: datetime) -> None:
"""FIXED: More robust provider state updates with validation."""
try:
if not self.graph.graph.has_node(target):
self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
# Calculate duration safely
try:
duration_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
except Exception:
duration_ms = 0
node_data['metadata']['provider_states'][provider_name] = {
'status': status,
'timestamp': start_time.isoformat(),
'results_count': max(0, results_count), # Ensure non-negative
'error': str(error) if error else None,
'duration_ms': duration_ms
}
# Update last modified time for forensic integrity
if hasattr(self, 'last_modified'):
self.last_modified = datetime.now(timezone.utc).isoformat()
except Exception as e:
self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
def _log_target_processing_error(self, target: str, error: str) -> None:
self.logger.logger.error(f"Target processing failed for {target}: {error}")
def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
def _calculate_progress(self) -> float:
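        """Estimate completion as a percentage of all tasks seen so far."""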
try:
if self.total_tasks_ever_enqueued == 0:
return 0.0
# Add small buffer for tasks still in queue to avoid showing 100% too early
queue_size = max(0, self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0)
# FIXED: Safe processing count
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
active_tasks = len(self.currently_processing)
else:
active_tasks = len(getattr(self, 'currently_processing', []))
# Adjust total to account for remaining work
adjusted_total = max(self.total_tasks_ever_enqueued,
self.indicators_completed + queue_size + active_tasks)
if adjusted_total == 0:
return 100.0
progress = (self.indicators_completed / adjusted_total) * 100
return max(0.0, min(100.0, progress)) # Clamp between 0 and 100
except Exception as e:
self.logger.logger.warning(f"Error calculating progress: {e}")
return 0.0
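    # Worked example: 40 completed out of 100 ever enqueued, with 10 tasks
    # queued and 2 in flight -> adjusted_total = max(100, 52) = 100, so the
    # reported progress is 40.0%.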
def get_graph_data(self) -> Dict[str, Any]:
"""Get current graph data formatted for frontend visualization."""
graph_data = self.graph.get_graph_data()
graph_data['initial_targets'] = list(self.initial_targets)
return graph_data
def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
"""Get comprehensive information about all available providers."""
info = {}
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}"
try:
module = importlib.import_module(module_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
provider_class = attribute
temp_provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = temp_provider.get_name()
live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)
info[provider_name] = {
'display_name': temp_provider.get_display_name(),
'requires_api_key': temp_provider.requires_api_key(),
'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(),
'enabled': self.config.is_provider_enabled(provider_name),
'rate_limit': self.config.get_rate_limit(provider_name),
}
except Exception:
traceback.print_exc()
return info