From 4378146d0cb76828af16d63d2a87892641ddace9 Mon Sep 17 00:00:00 2001
From: overcuriousity
Date: Sun, 14 Sep 2025 13:14:02 +0200
Subject: [PATCH] it

---
 .gitignore                 |   1 +
 app.py                     | 167 +++++---
 core/__init__.py           |   6 +-
 core/scanner.py            | 813 +++++++++++++++++--------------------
 core/session_config.py     | 282 ++++++++++++-
 core/session_manager.py    | 390 ++++++++++++------
 core/task_manager.py       | 564 +++++++++++++++++++++++++
 providers/base_provider.py | 225 +++++++---
 8 files changed, 1765 insertions(+), 683 deletions(-)
 create mode 100644 core/task_manager.py

diff --git a/.gitignore b/.gitignore
index ef8188a..6afb256 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,3 +169,4 @@ cython_debug/
 #.idea/
 
 dump.rdb
+.vscode
\ No newline at end of file
diff --git a/app.py b/app.py
index e3e89be..cf7bd8b 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,6 @@
 """
 Flask application entry point for DNSRecon web interface.
-Provides REST API endpoints and serves the web interface with user session support.
+Enhanced with user session management and task-based completion model.
 """
 
 import json
@@ -9,7 +9,7 @@ from flask import Flask, render_template, request, jsonify, send_file, session
 from datetime import datetime, timezone, timedelta
 import io
 
-from core.session_manager import session_manager
+from core.session_manager import session_manager, UserIdentifier
 from config import config
 
 
@@ -17,46 +17,73 @@ app = Flask(__name__)
 app.config['SECRET_KEY'] = 'dnsrecon-dev-key-change-in-production'
 app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2)  # 2 hour session lifetime
 
+
 def get_user_scanner():
     """
-    User scanner retrieval with better error handling and debugging.
+    Enhanced user scanner retrieval with user identification and session consolidation.
+    Implements single session per user with seamless consolidation.
""" - # Get current Flask session info for debugging - current_flask_session_id = session.get('dnsrecon_session_id') - client_ip = request.remote_addr - user_agent = request.headers.get('User-Agent', '')[:100] # Truncate for logging + print("=== ENHANCED GET_USER_SCANNER ===") - # Try to get existing session - if current_flask_session_id: - existing_scanner = session_manager.get_session(current_flask_session_id) - if existing_scanner: - # Ensure session ID is set - existing_scanner.session_id = current_flask_session_id - return current_flask_session_id, existing_scanner - else: - print(f"Session {current_flask_session_id} not found in session manager") - - # Create new session - print("Creating new session...") - new_session_id = session_manager.create_session() - new_scanner = session_manager.get_session(new_session_id) - - if not new_scanner: - print(f"ERROR: Failed to retrieve newly created session {new_session_id}") - raise Exception("Failed to create new scanner session") - - # Store in Flask session - session['dnsrecon_session_id'] = new_session_id - session.permanent = True - - # Ensure session ID is set on scanner - new_scanner.session_id = new_session_id - - print(f"Created new session: {new_session_id}") - print(f"New scanner status: {new_scanner.status}") - print("=== END SESSION DEBUG ===") - - return new_session_id, new_scanner + try: + # Extract user identification from request + client_ip, user_agent = UserIdentifier.extract_request_info(request) + user_fingerprint = UserIdentifier.generate_user_fingerprint(client_ip, user_agent) + + print(f"User fingerprint: {user_fingerprint}") + print(f"Client IP: {client_ip}") + print(f"User Agent: {user_agent[:50]}...") + + # Get current Flask session info for debugging + current_flask_session_id = session.get('dnsrecon_session_id') + print(f"Flask session ID: {current_flask_session_id}") + + # Try to get existing session first + if current_flask_session_id: + existing_scanner = session_manager.get_session(current_flask_session_id) + if existing_scanner: + # Verify session belongs to current user + session_info = session_manager.get_session_info(current_flask_session_id) + if session_info.get('user_fingerprint') == user_fingerprint: + print(f"Found valid existing session {current_flask_session_id} for user {user_fingerprint}") + existing_scanner.session_id = current_flask_session_id + return current_flask_session_id, existing_scanner + else: + print(f"Session {current_flask_session_id} belongs to different user, will create new session") + else: + print(f"Session {current_flask_session_id} not found in Redis, will create new session") + + # Create or replace user session (this handles consolidation automatically) + new_session_id = session_manager.create_or_replace_user_session(client_ip, user_agent) + new_scanner = session_manager.get_session(new_session_id) + + if not new_scanner: + print(f"ERROR: Failed to retrieve newly created session {new_session_id}") + raise Exception("Failed to create new scanner session") + + # Store in Flask session for browser persistence + session['dnsrecon_session_id'] = new_session_id + session.permanent = True + + # Ensure session ID is set on scanner + new_scanner.session_id = new_session_id + + # Get session info for user feedback + session_info = session_manager.get_session_info(new_session_id) + + print(f"Session created/consolidated successfully") + print(f" - Session ID: {new_session_id}") + print(f" - User: {user_fingerprint}") + print(f" - Scanner status: {new_scanner.status}") + print(f" - 
Session age: {session_info.get('session_age_minutes', 0)} minutes") + + return new_session_id, new_scanner + + except Exception as e: + print(f"ERROR: Exception in get_user_scanner: {e}") + traceback.print_exc() + raise + @app.route('/') def index(): @@ -67,7 +94,7 @@ def index(): @app.route('/api/scan/start', methods=['POST']) def start_scan(): """ - Start a new reconnaissance scan with immediate GUI feedback. + Start a new reconnaissance scan with enhanced user session management. """ print("=== API: /api/scan/start called ===") @@ -87,7 +114,7 @@ def start_scan(): max_depth = data.get('max_depth', config.default_recursion_depth) clear_graph = data.get('clear_graph', True) - print(f"Parsed - target_domain: '{target_domain}', max_depth: {max_depth}") + print(f"Parsed - target_domain: '{target_domain}', max_depth: {max_depth}, clear_graph: {clear_graph}") # Validation if not target_domain: @@ -106,7 +133,7 @@ def start_scan(): print("Validation passed, getting user scanner...") - # Get user-specific scanner + # Get user-specific scanner with enhanced session management user_session_id, scanner = get_user_scanner() # Ensure session ID is properly set @@ -126,12 +153,21 @@ def start_scan(): if success: scan_session_id = scanner.logger.session_id print(f"Scan started successfully with scan session ID: {scan_session_id}") + + # Get session info for user feedback + session_info = session_manager.get_session_info(user_session_id) + return jsonify({ 'success': True, 'message': 'Scan started successfully', 'scan_id': scan_session_id, 'user_session_id': user_session_id, 'scanner_status': scanner.status, + 'session_info': { + 'user_fingerprint': session_info.get('user_fingerprint', 'unknown'), + 'session_age_minutes': session_info.get('session_age_minutes', 0), + 'consolidated': session_info.get('session_age_minutes', 0) > 0 + }, 'debug_info': { 'scanner_object_id': id(scanner), 'scanner_status': scanner.status @@ -216,7 +252,7 @@ def stop_scan(): @app.route('/api/scan/status', methods=['GET']) def get_scan_status(): - """Get current scan status with error handling.""" + """Get current scan status with enhanced session information.""" try: # Get user-specific scanner user_session_id, scanner = get_user_scanner() @@ -247,6 +283,15 @@ def get_scan_status(): status = scanner.get_scan_status() status['user_session_id'] = user_session_id + # Add enhanced session information + session_info = session_manager.get_session_info(user_session_id) + status['session_info'] = { + 'user_fingerprint': session_info.get('user_fingerprint', 'unknown'), + 'session_age_minutes': session_info.get('session_age_minutes', 0), + 'client_ip': session_info.get('client_ip', 'unknown'), + 'last_activity': session_info.get('last_activity') + } + # Additional debug info status['debug_info'] = { 'scanner_object_id': id(scanner), @@ -275,7 +320,6 @@ def get_scan_status(): }), 500 - @app.route('/api/graph', methods=['GET']) def get_graph_data(): """Get current graph data with error handling.""" @@ -321,7 +365,6 @@ def get_graph_data(): }), 500 - @app.route('/api/export', methods=['GET']) def export_results(): """Export complete scan results as downloadable JSON for the user session.""" @@ -332,17 +375,22 @@ def export_results(): # Get complete results results = scanner.export_results() - # Add session information to export + # Add enhanced session information to export + session_info = session_manager.get_session_info(user_session_id) results['export_metadata'] = { 'user_session_id': user_session_id, + 'user_fingerprint': 
session_info.get('user_fingerprint', 'unknown'), + 'client_ip': session_info.get('client_ip', 'unknown'), + 'session_age_minutes': session_info.get('session_age_minutes', 0), 'export_timestamp': datetime.now(timezone.utc).isoformat(), 'export_type': 'user_session_results' } - # Create filename with timestamp + # Create filename with user fingerprint timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S') target = scanner.current_target or 'unknown' - filename = f"dnsrecon_{target}_{timestamp}_{user_session_id[:8]}.json" + user_fp = session_info.get('user_fingerprint', 'unknown')[:8] + filename = f"dnsrecon_{target}_{timestamp}_{user_fp}.json" # Create in-memory file json_data = json.dumps(results, indent=2, ensure_ascii=False) @@ -450,7 +498,7 @@ def set_api_keys(): @app.route('/api/session/info', methods=['GET']) def get_session_info(): - """Get information about the current user session.""" + """Get enhanced information about the current user session.""" try: user_session_id, scanner = get_user_scanner() session_info = session_manager.get_session_info(user_session_id) @@ -501,7 +549,7 @@ def terminate_session(): @app.route('/api/admin/sessions', methods=['GET']) def list_sessions(): - """Admin endpoint to list all active sessions.""" + """Admin endpoint to list all active sessions with enhanced information.""" try: sessions = session_manager.list_active_sessions() stats = session_manager.get_statistics() @@ -523,7 +571,7 @@ def list_sessions(): @app.route('/api/health', methods=['GET']) def health_check(): - """Health check endpoint.""" + """Health check endpoint with enhanced session statistics.""" try: # Get session stats session_stats = session_manager.get_statistics() @@ -532,8 +580,8 @@ def health_check(): 'success': True, 'status': 'healthy', 'timestamp': datetime.now(timezone.utc).isoformat(), - 'version': '1.0.0-phase2', - 'phase': 2, + 'version': '2.0.0-enhanced', + 'phase': 'enhanced_architecture', 'features': { 'multi_provider': True, 'concurrent_processing': True, @@ -542,9 +590,18 @@ def health_check(): 'visualization': True, 'retry_logic': True, 'user_sessions': True, - 'session_isolation': True + 'session_isolation': True, + 'global_provider_caching': True, + 'single_session_per_user': True, + 'session_consolidation': True, + 'task_completion_model': True }, - 'session_statistics': session_stats + 'session_statistics': session_stats, + 'cache_info': { + 'global_provider_cache': True, + 'cache_location': '.cache//', + 'cache_expiry_hours': 12 + } }) except Exception as e: print(f"ERROR: Exception in health_check endpoint: {e}") @@ -575,7 +632,7 @@ def internal_error(error): if __name__ == '__main__': - print("Starting DNSRecon Flask application with user session support...") + print("Starting DNSRecon Flask application with enhanced user session support...") # Load configuration from environment config.load_from_env() diff --git a/core/__init__.py b/core/__init__.py index 2c23f1d..691dde5 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -8,6 +8,7 @@ from .scanner import Scanner, ScanStatus from .logger import ForensicLogger, get_forensic_logger, new_session from .session_manager import session_manager from .session_config import SessionConfig, create_session_config +from .task_manager import TaskManager, TaskType, ReconTask __all__ = [ 'GraphManager', @@ -19,7 +20,10 @@ __all__ = [ 'new_session', 'session_manager', 'SessionConfig', - 'create_session_config' + 'create_session_config', + 'TaskManager', + 'TaskType', + 'ReconTask' ] __version__ = "1.0.0-phase2" 
\ No newline at end of file diff --git a/core/scanner.py b/core/scanner.py index 23406e3..3a81e81 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -12,6 +12,7 @@ from datetime import datetime, timezone from core.graph_manager import GraphManager, NodeType from core.logger import get_forensic_logger, new_session +from core.task_manager import TaskManager, TaskType, ReconTask from utils.helpers import _is_valid_ip, _is_valid_domain from providers.base_provider import BaseProvider @@ -27,12 +28,13 @@ class ScanStatus: class Scanner: """ - Main scanning orchestrator for DNSRecon passive reconnaissance. + Enhanced scanning orchestrator for DNSRecon passive reconnaissance. + Now uses task-based completion model with comprehensive retry logic. """ def __init__(self, session_config=None): - """Initialize scanner with session-specific configuration.""" - print("Initializing Scanner instance...") + """Initialize scanner with session-specific configuration and task management.""" + print("Initializing Enhanced Scanner instance...") try: # Use provided session config or create default @@ -50,16 +52,18 @@ class Scanner: self.stop_event = threading.Event() self.scan_thread = None self.session_id = None # Will be set by session manager - self.current_scan_id = None # NEW: Track current scan ID + self.current_scan_id = None # Track current scan ID - # Scanning progress tracking + # Task-based execution components + self.task_manager = None # Will be initialized when needed + self.max_workers = self.config.max_concurrent_requests + + # Enhanced progress tracking self.total_indicators_found = 0 self.indicators_processed = 0 self.current_indicator = "" - - # Concurrent processing configuration - self.max_workers = self.config.max_concurrent_requests - self.executor = None + self.scan_start_time = None + self.scan_end_time = None # Initialize providers with session config print("Calling _initialize_providers with session config...") @@ -69,17 +73,55 @@ class Scanner: print("Initializing forensic logger...") self.logger = get_forensic_logger() - print("Scanner initialization complete") + print("Enhanced Scanner initialization complete") except Exception as e: - print(f"ERROR: Scanner initialization failed: {e}") + print(f"ERROR: Enhanced Scanner initialization failed: {e}") traceback.print_exc() raise + def __getstate__(self): + """Prepare object for pickling by excluding unpicklable attributes.""" + state = self.__dict__.copy() + + # Remove unpicklable threading objects + unpicklable_attrs = [ + 'stop_event', + 'scan_thread', + 'task_manager' + ] + + for attr in unpicklable_attrs: + if attr in state: + del state[attr] + + # Handle providers separately to ensure they're picklable + if 'providers' in state: + # The providers should be picklable now, but let's ensure clean state + for provider in state['providers']: + if hasattr(provider, '_stop_event'): + provider._stop_event = None + + return state + + def __setstate__(self, state): + """Restore object after unpickling by reconstructing threading objects.""" + self.__dict__.update(state) + + # Reconstruct threading objects + self.stop_event = threading.Event() + self.scan_thread = None + self.task_manager = None + + # Re-set stop events for providers + if hasattr(self, 'providers'): + for provider in self.providers: + if hasattr(provider, 'set_stop_event'): + provider.set_stop_event(self.stop_event) + def _is_stop_requested(self) -> bool: """ - Check if stop is requested using both local and Redis-based signals. 
- This ensures reliable termination across process boundaries. + Enhanced stop signal checking that handles both local and Redis-based signals. """ # Check local threading event first (fastest) if self.stop_event.is_set(): @@ -112,86 +154,86 @@ class Scanner: except Exception as e: print(f"Error setting Redis stop signal: {e}") - def __getstate__(self): - """Prepare object for pickling by excluding unpicklable attributes.""" - state = self.__dict__.copy() - - # Remove unpicklable threading objects - unpicklable_attrs = [ - 'stop_event', - 'scan_thread', - 'executor' - ] - - for attr in unpicklable_attrs: - if attr in state: - del state[attr] - - # Handle providers separately to ensure they're picklable - if 'providers' in state: - # The providers should be picklable now, but let's ensure clean state - for provider in state['providers']: - if hasattr(provider, '_stop_event'): - provider._stop_event = None - - return state - - def __setstate__(self, state): - """Restore object after unpickling by reconstructing threading objects.""" - self.__dict__.update(state) - - # Reconstruct threading objects - self.stop_event = threading.Event() - self.scan_thread = None - self.executor = None - - # Re-set stop events for providers - if hasattr(self, 'providers'): - for provider in self.providers: - if hasattr(provider, 'set_stop_event'): - provider.set_stop_event(self.stop_event) - def _initialize_providers(self) -> None: """Initialize all available providers based on session configuration.""" self.providers = [] print("Initializing providers with session config...") provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') - for filename in os.listdir(provider_dir): - if filename.endswith('_provider.py') and not filename.startswith('base'): - module_name = f"providers.{filename[:-3]}" - try: - module = importlib.import_module(module_name) - for attribute_name in dir(module): - attribute = getattr(module, attribute_name) - if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: - provider_class = attribute - provider_name = provider_class(session_config=self.config).get_name() - if self.config.is_provider_enabled(provider_name): - provider = provider_class(session_config=self.config) - if provider.is_available(): - provider.set_stop_event(self.stop_event) - self.providers.append(provider) - print(f"✓ {provider.get_display_name()} provider initialized successfully for session") - else: - print(f"✗ {provider.get_display_name()} provider is not available") - except Exception as e: - print(f"✗ Failed to initialize provider from {filename}: {e}") - traceback.print_exc() + print(f"Looking for providers in: {provider_dir}") + + if not os.path.exists(provider_dir): + print(f"ERROR: Provider directory does not exist: {provider_dir}") + return + + provider_files = [f for f in os.listdir(provider_dir) if f.endswith('_provider.py') and not f.startswith('base')] + print(f"Found provider files: {provider_files}") + + for filename in provider_files: + module_name = f"providers.{filename[:-3]}" + print(f"Attempting to load module: {module_name}") + + try: + module = importlib.import_module(module_name) + print(f" ✓ Module {module_name} loaded successfully") + + # Find provider classes in the module + provider_classes_found = [] + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: + 
provider_classes_found.append((attribute_name, attribute)) + + print(f" Found provider classes: {[name for name, _ in provider_classes_found]}") + + for class_name, provider_class in provider_classes_found: + try: + # Create temporary instance to get provider name + temp_provider = provider_class(session_config=self.config) + provider_name = temp_provider.get_name() + print(f" Provider {class_name} -> name: {provider_name}") + + # Check if enabled in config + is_enabled = self.config.is_provider_enabled(provider_name) + print(f" Provider {provider_name} enabled: {is_enabled}") + + if is_enabled: + # Check if available (has API keys, etc.) + is_available = temp_provider.is_available() + print(f" Provider {provider_name} available: {is_available}") + + if is_available: + # Set stop event and add to providers list + temp_provider.set_stop_event(self.stop_event) + self.providers.append(temp_provider) + print(f" ✓ {temp_provider.get_display_name()} provider initialized successfully") + else: + print(f" - {temp_provider.get_display_name()} provider is not available (missing API key or other requirement)") + else: + print(f" - {temp_provider.get_display_name()} provider is disabled in config") + + except Exception as e: + print(f" ✗ Failed to initialize provider class {class_name}: {e}") + import traceback + traceback.print_exc() + + except Exception as e: + print(f" ✗ Failed to load module {module_name}: {e}") + import traceback + traceback.print_exc() - print(f"Initialized {len(self.providers)} providers for session") - - def update_session_config(self, new_config) -> None: - """Update session configuration and reinitialize providers.""" - print("Updating session configuration...") - self.config = new_config - self.max_workers = self.config.max_concurrent_requests - self._initialize_providers() - print("Session configuration updated") + print(f"Total providers initialized: {len(self.providers)}") + for provider in self.providers: + print(f" - {provider.get_display_name()} ({provider.get_name()})") + + if len(self.providers) == 0: + print("WARNING: No providers were initialized!") + elif len(self.providers) == 1 and self.providers[0].get_name() == 'dns': + print("WARNING: Only DNS provider initialized - other providers may have failed to load") def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool: - """Start a new reconnaissance scan with immediate GUI feedback.""" - print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") + """Start a new reconnaissance scan with task-based completion model.""" + print(f"=== STARTING ENHANCED SCAN IN SCANNER {id(self)} ===") print(f"Session ID: {self.session_id}") print(f"Initial scanner status: {self.status}") print(f"Clear graph: {clear_graph}") @@ -209,8 +251,8 @@ class Scanner: print("Terminating previous scan thread...") self._set_stop_signal() - if self.executor: - self.executor.shutdown(wait=False, cancel_futures=True) + if self.task_manager: + self.task_manager.stop_execution() self.scan_thread.join(timeout=8.0) if self.scan_thread.is_alive(): @@ -256,6 +298,8 @@ class Scanner: self.total_indicators_found = 0 self.indicators_processed = 0 self.current_indicator = self.current_target + self.scan_start_time = datetime.now(timezone.utc) + self.scan_end_time = None self._update_session_state() @@ -263,33 +307,29 @@ class Scanner: if clear_graph: self.logger = new_session() - # Start scan thread (original behavior allows concurrent threads for "Add to Graph") - print(f"Starting scan thread with scan ID 
{self.current_scan_id}...") + # Start task-based scan thread + print(f"Starting task-based scan thread with scan ID {self.current_scan_id}...") self.scan_thread = threading.Thread( - target=self._execute_scan, + target=self._execute_task_based_scan, args=(self.current_target, max_depth, self.current_scan_id), daemon=True ) self.scan_thread.start() - print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") + print(f"=== ENHANCED SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") return True except Exception as e: print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") traceback.print_exc() self.status = ScanStatus.FAILED + self.scan_end_time = datetime.now(timezone.utc) self._update_session_state() return False - def _execute_scan(self, target_domain: str, max_depth: int, scan_id: str) -> None: - """Execute the reconnaissance scan using a task queue-based approach.""" - print(f"_execute_scan started for {target_domain} with depth {max_depth}, scan ID {scan_id}") - - self.executor = ThreadPoolExecutor(max_workers=self.max_workers) - processed_targets = set() - - task_queue = deque([(target_domain, 0, False)]) # target, depth, is_large_entity_member + def _execute_task_based_scan(self, target_domain: str, max_depth: int, scan_id: str) -> None: + """Execute the reconnaissance scan using the task-based completion model.""" + print(f"_execute_task_based_scan started for {target_domain} with depth {max_depth}, scan ID {scan_id}") try: self.status = ScanStatus.RUNNING @@ -297,20 +337,38 @@ class Scanner: enabled_providers = [provider.get_name() for provider in self.providers] self.logger.log_scan_start(target_domain, max_depth, enabled_providers) + + # Initialize task manager + self.task_manager = TaskManager( + self.providers, + self.graph, + self.logger, + max_concurrent_tasks=self.max_workers + ) + + # Add initial target to graph self.graph.add_node(target_domain, NodeType.DOMAIN) - self._initialize_provider_states(target_domain) - - while task_queue: + + # Start task execution + self.task_manager.start_execution(max_workers=self.max_workers) + + # Track processed targets to avoid duplicates + processed_targets = set() + + # Task queue for breadth-first processing + target_queue = deque([(target_domain, 0)]) # (target, depth) + + while target_queue: # Abort if scan ID changed (new scan started) if self.current_scan_id != scan_id: print(f"Scan aborted - ID mismatch (current: {self.current_scan_id}, expected: {scan_id})") break if self._is_stop_requested(): - print("Stop requested, terminating scan.") + print("Stop requested, terminating task-based scan.") break - target, depth, is_large_entity_member = task_queue.popleft() + target, depth = target_queue.popleft() if target in processed_targets or depth > max_depth: continue @@ -319,98 +377,133 @@ class Scanner: self.current_indicator = target self._update_session_state() - new_targets, large_entity_members = self._query_providers_for_target(target, depth, is_large_entity_member) + print(f"Processing target: {target} at depth {depth}") + + # Create tasks for all eligible providers + task_ids = self.task_manager.create_provider_tasks(target, depth, self.providers) + + if task_ids: + print(f"Created {len(task_ids)} tasks for target {target}") + self.total_indicators_found += len(task_ids) + self._update_session_state() + processed_targets.add(target) - # Only add new targets if scan ID still matches (prevents stale updates) - if self.current_scan_id == scan_id: + # Wait for current batch of tasks to complete before 
processing next depth + # This ensures we get all relationships before expanding further + timeout_per_batch = 60 # 60 seconds per batch + batch_start = time.time() + + while time.time() - batch_start < timeout_per_batch: + if self._is_stop_requested() or self.current_scan_id != scan_id: + break + + progress_report = self.task_manager.get_progress_report() + stats = progress_report['statistics'] + + # Update progress tracking + self.indicators_processed = stats['succeeded'] + stats['failed_permanent'] + self._update_session_state() + + # Check if current batch is complete + current_batch_complete = ( + stats['pending'] == 0 and + stats['running'] == 0 and + stats['failed_retrying'] == 0 + ) + + if current_batch_complete: + print(f"Batch complete for {target}: {stats['succeeded']} succeeded, {stats['failed_permanent']} failed") + break + + time.sleep(1.0) # Check every second + + # Collect new targets from completed successful tasks + if depth < max_depth: + new_targets = self._collect_new_targets_from_completed_tasks() for new_target in new_targets: if new_target not in processed_targets: - task_queue.append((new_target, depth + 1, False)) - - for member in large_entity_members: - if member not in processed_targets: - task_queue.append((member, depth, True)) - - except Exception as e: - print(f"ERROR: Scan execution failed with error: {e}") - traceback.print_exc() - self.status = ScanStatus.FAILED - self.logger.logger.error(f"Scan failed: {e}") - finally: - # Only update final status if scan ID still matches (prevents stale status updates) + target_queue.append((new_target, depth + 1)) + print(f"Added new target for next depth: {new_target}") + + # Wait for all remaining tasks to complete + print("Waiting for all tasks to complete...") + final_completion = self.task_manager.wait_for_completion(timeout_seconds=300) + + if not final_completion: + print("WARNING: Some tasks did not complete within timeout") + + # Final progress update + final_report = self.task_manager.get_progress_report() + final_stats = final_report['statistics'] + + print(f"Final task statistics:") + print(f" - Total tasks: {final_stats['total_tasks']}") + print(f" - Succeeded: {final_stats['succeeded']}") + print(f" - Failed permanently: {final_stats['failed_permanent']}") + print(f" - Completion rate: {final_stats['completion_rate']:.1f}%") + + # Determine final scan status if self.current_scan_id == scan_id: if self._is_stop_requested(): self.status = ScanStatus.STOPPED + elif final_stats['failed_permanent'] > 0 and final_stats['succeeded'] == 0: + self.status = ScanStatus.FAILED + elif final_stats['completion_rate'] < 50.0: # Less than 50% success rate + self.status = ScanStatus.FAILED else: self.status = ScanStatus.COMPLETED - + + self.scan_end_time = datetime.now(timezone.utc) self._update_session_state() self.logger.log_scan_complete() else: print(f"Scan completed but ID mismatch - not updating final status") - - if self.executor: - self.executor.shutdown(wait=False, cancel_futures=True) - stats = self.graph.get_statistics() + + except Exception as e: + print(f"ERROR: Task-based scan execution failed: {e}") + traceback.print_exc() + self.status = ScanStatus.FAILED + self.scan_end_time = datetime.now(timezone.utc) + self.logger.logger.error(f"Task-based scan failed: {e}") + finally: + # Clean up task manager + if self.task_manager: + self.task_manager.stop_execution() + + # Final statistics + graph_stats = self.graph.get_statistics() print("Final scan statistics:") - print(f" - Total nodes: 
{stats['basic_metrics']['total_nodes']}") - print(f" - Total edges: {stats['basic_metrics']['total_edges']}") + print(f" - Total nodes: {graph_stats['basic_metrics']['total_nodes']}") + print(f" - Total edges: {graph_stats['basic_metrics']['total_edges']}") print(f" - Targets processed: {len(processed_targets)}") - def _query_providers_for_target(self, target: str, depth: int, dns_only: bool = False) -> Tuple[Set[str], Set[str]]: - """Helper method to query providers for a single target.""" - is_ip = _is_valid_ip(target) - target_type = NodeType.IP if is_ip else NodeType.DOMAIN - print(f"Querying providers for {target_type.value}: {target} at depth {depth}") - - if self._is_stop_requested(): - print(f"Stop requested before querying providers for {target}") - return set(), set() - - self.graph.add_node(target, target_type) - self._initialize_provider_states(target) - + def _collect_new_targets_from_completed_tasks(self) -> Set[str]: + """Collect new targets from successfully completed tasks.""" new_targets = set() - large_entity_members = set() - node_attributes = defaultdict(lambda: defaultdict(list)) - - eligible_providers = self._get_eligible_providers(target, is_ip, dns_only) - if not eligible_providers: - self._log_no_eligible_providers(target, is_ip) - return new_targets, large_entity_members - - for provider in eligible_providers: - if self._is_stop_requested(): - print(f"Stop requested while querying providers for {target}") - break - - try: - provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth) - if provider_results and not self._is_stop_requested(): - discovered, is_large_entity = self._process_provider_results_forensic( - target, provider, provider_results, node_attributes, depth - ) - if is_large_entity: - large_entity_members.update(discovered) - else: - new_targets.update(discovered) - except Exception as e: - self._log_provider_error(target, provider.get_name(), str(e)) - - for node_id, attributes in node_attributes.items(): - if self.graph.graph.has_node(node_id): - node_is_ip = _is_valid_ip(node_id) - node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN - self.graph.add_node(node_id, node_type_to_add, attributes=attributes) - - return new_targets, large_entity_members + if not self.task_manager: + return new_targets + + # Get task summaries to find successful tasks + task_summaries = self.task_manager.task_queue.get_task_summaries() + + for task_summary in task_summaries: + if task_summary['status'] == 'succeeded': + task_id = task_summary['task_id'] + task = self.task_manager.task_queue.tasks.get(task_id) + + if task and task.result and task.result.data: + task_new_targets = task.result.data.get('new_targets', []) + for target in task_new_targets: + if _is_valid_domain(target) or _is_valid_ip(target): + new_targets.add(target) + + return new_targets def _update_session_state(self) -> None: """ Update the scanner state in Redis for GUI updates. - This ensures the web interface sees real-time updates. 
""" if self.session_id: try: @@ -421,254 +514,11 @@ class Scanner: except Exception as e: print(f"ERROR: Failed to update session state: {e}") - def _initialize_provider_states(self, target: str) -> None: - """Initialize provider states for forensic tracking.""" - if not self.graph.graph.has_node(target): - return - - node_data = self.graph.graph.nodes[target] - if 'metadata' not in node_data: - node_data['metadata'] = {} - if 'provider_states' not in node_data['metadata']: - node_data['metadata']['provider_states'] = {} - - def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List: - """Get providers eligible for querying this target.""" - if dns_only: - return [p for p in self.providers if p.get_name() == 'dns'] - - eligible = [] - target_key = 'ips' if is_ip else 'domains' - - for provider in self.providers: - if provider.get_eligibility().get(target_key): - if not self._already_queried_provider(target, provider.get_name()): - eligible.append(provider) - else: - print(f"Skipping {provider.get_name()} for {target} - already queried") - - return eligible - - def _already_queried_provider(self, target: str, provider_name: str) -> bool: - """Check if we already queried a provider for a target.""" - if not self.graph.graph.has_node(target): - return False - - node_data = self.graph.graph.nodes[target] - provider_states = node_data.get('metadata', {}).get('provider_states', {}) - return provider_name in provider_states - - def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> List: - """Query a single provider with stop signal checking.""" - provider_name = provider.get_name() - start_time = datetime.now(timezone.utc) - - if self._is_stop_requested(): - print(f"Stop requested before querying {provider_name} for {target}") - return [] - - print(f"Querying {provider_name} for {target}") - - self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}") - - try: - if is_ip: - results = provider.query_ip(target) - else: - results = provider.query_domain(target) - - if self._is_stop_requested(): - print(f"Stop requested after querying {provider_name} for {target}") - return [] - - self._update_provider_state(target, provider_name, 'success', len(results), None, start_time) - - print(f"✓ {provider_name} returned {len(results)} results for {target}") - return results - - except Exception as e: - self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time) - print(f"✗ {provider_name} failed for {target}: {e}") - return [] - - def _update_provider_state(self, target: str, provider_name: str, status: str, - results_count: int, error: str, start_time: datetime) -> None: - """Update provider state in node metadata for forensic tracking.""" - if not self.graph.graph.has_node(target): - return - - node_data = self.graph.graph.nodes[target] - if 'metadata' not in node_data: - node_data['metadata'] = {} - if 'provider_states' not in node_data['metadata']: - node_data['metadata']['provider_states'] = {} - - node_data['metadata']['provider_states'][provider_name] = { - 'status': status, - 'timestamp': start_time.isoformat(), - 'results_count': results_count, - 'error': error, - 'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000 - } - - self.logger.logger.info(f"Provider state updated: {target} -> {provider_name} -> {status} ({results_count} results)") - - def _process_provider_results_forensic(self, target: str, provider, results: List, - 
node_attributes: Dict, current_depth: int) -> Tuple[Set[str], bool]: - """Process provider results, returns (discovered_targets, is_large_entity).""" - provider_name = provider.get_name() - discovered_targets = set() - - if self._is_stop_requested(): - print(f"Stop requested before processing results from {provider_name} for {target}") - return discovered_targets, False - - if len(results) > self.config.large_entity_threshold: - print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}") - members = self._create_large_entity(target, provider_name, results, current_depth) - return members, True - - for i, (source, rel_target, rel_type, confidence, raw_data) in enumerate(results): - if i % 10 == 0 and self._is_stop_requested(): - print(f"Stop requested while processing results from {provider_name} for {target}") - break - - self.logger.log_relationship_discovery( - source_node=source, - target_node=rel_target, - relationship_type=rel_type, - confidence_score=confidence, - provider=provider_name, - raw_data=raw_data, - discovery_method=f"{provider_name}_query_depth_{current_depth}" - ) - - self._collect_node_attributes(source, provider_name, rel_type, rel_target, raw_data, node_attributes[source]) - - if _is_valid_ip(rel_target): - self.graph.add_node(rel_target, NodeType.IP) - if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): - print(f"Added IP relationship: {source} -> {rel_target} ({rel_type})") - discovered_targets.add(rel_target) - - elif rel_target.startswith('AS') and rel_target[2:].isdigit(): - self.graph.add_node(rel_target, NodeType.ASN) - if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): - print(f"Added ASN relationship: {source} -> {rel_target} ({rel_type})") - - elif _is_valid_domain(rel_target): - self.graph.add_node(rel_target, NodeType.DOMAIN) - if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): - print(f"Added domain relationship: {source} -> {rel_target} ({rel_type})") - discovered_targets.add(rel_target) - self._collect_node_attributes(rel_target, provider_name, rel_type, source, raw_data, node_attributes[rel_target]) - - else: - self._collect_node_attributes(source, provider_name, rel_type, rel_target, raw_data, node_attributes[source]) - - return discovered_targets, False - - def _create_large_entity(self, source: str, provider_name: str, results: List, current_depth: int) -> Set[str]: - """Create a large entity node and returns the members for DNS processing.""" - entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}" - - targets = [rel[1] for rel in results if len(rel) > 1] - node_type = 'unknown' - - if targets: - if _is_valid_domain(targets[0]): - node_type = 'domain' - elif _is_valid_ip(targets[0]): - node_type = 'ip' - - for target in targets: - self.graph.add_node(target, NodeType.DOMAIN if node_type == 'domain' else NodeType.IP) - - attributes = { - 'count': len(targets), - 'nodes': targets, - 'node_type': node_type, - 'source_provider': provider_name, - 'discovery_depth': current_depth, - 'threshold_exceeded': self.config.large_entity_threshold, - } - description = f'Large entity created due to {len(targets)} results from {provider_name}' - - self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes, description=description) - - if results: - rel_type = results[0][2] - self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, - {'large_entity_info': f'Contains 
{len(targets)} {node_type}s'}) - - self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}") - print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}") - - return set(targets) - - def _collect_node_attributes(self, node_id: str, provider_name: str, rel_type: str, - target: str, raw_data: Dict[str, Any], attributes: Dict[str, Any]) -> None: - """Collect and organize attributes for a node.""" - self.logger.logger.debug(f"Collecting attributes for {node_id} from {provider_name}: {rel_type}") - - if provider_name == 'dns': - record_type = raw_data.get('query_type', 'UNKNOWN') - value = raw_data.get('value', target) - dns_entry = f"{record_type}: {value}" - if dns_entry not in attributes.get('dns_records', []): - attributes.setdefault('dns_records', []).append(dns_entry) - - elif provider_name == 'crtsh': - if rel_type == "san_certificate": - domain_certs = raw_data.get('domain_certificates', {}) - if node_id in domain_certs: - cert_summary = domain_certs[node_id] - attributes['certificates'] = cert_summary - if target not in attributes.get('related_domains_san', []): - attributes.setdefault('related_domains_san', []).append(target) - - elif provider_name == 'shodan': - shodan_attributes = attributes.setdefault('shodan', {}) - for key, value in raw_data.items(): - if key not in shodan_attributes or not shodan_attributes.get(key): - shodan_attributes[key] = value - - if rel_type == "asn_membership": - attributes['asn'] = { - 'id': target, - 'description': raw_data.get('org', ''), - 'isp': raw_data.get('isp', ''), - 'country': raw_data.get('country', '') - } - - record_type_name = rel_type - if record_type_name not in attributes: - attributes[record_type_name] = [] - - if isinstance(target, list): - attributes[record_type_name].extend(target) - else: - if target not in attributes[record_type_name]: - attributes[record_type_name].append(target) - - def _log_target_processing_error(self, target: str, error: str) -> None: - """Log target processing errors for forensic trail.""" - self.logger.logger.error(f"Target processing failed for {target}: {error}") - - def _log_provider_error(self, target: str, provider_name: str, error: str) -> None: - """Log provider query errors for forensic trail.""" - self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}") - - def _log_no_eligible_providers(self, target: str, is_ip: bool) -> None: - """Log when no providers are eligible for a target.""" - target_type = 'IP' if is_ip else 'domain' - self.logger.logger.warning(f"No eligible providers for {target_type}: {target}") - def stop_scan(self) -> bool: - """Request immediate scan termination with immediate GUI feedback.""" + """Request immediate scan termination with task manager cleanup.""" try: - print("=== INITIATING IMMEDIATE SCAN TERMINATION ===") - self.logger.logger.info("Scan termination requested by user") + print("=== INITIATING ENHANCED SCAN TERMINATION ===") + self.logger.logger.info("Enhanced scan termination requested by user") # Invalidate current scan ID to prevent stale updates old_scan_id = self.current_scan_id @@ -678,28 +528,29 @@ class Scanner: # Set stop signals self._set_stop_signal() self.status = ScanStatus.STOPPED + self.scan_end_time = datetime.now(timezone.utc) # Immediately update GUI with stopped status self._update_session_state() - # Cancel executor futures if running - if self.executor: - print("Shutting down executor with immediate cancellation...") - 
self.executor.shutdown(wait=False, cancel_futures=True) + # Stop task manager if running + if self.task_manager: + print("Stopping task manager...") + self.task_manager.stop_execution() - print("Termination signals sent. The scan will stop as soon as possible.") + print("Enhanced termination signals sent. The scan will stop as soon as possible.") return True except Exception as e: - print(f"ERROR: Exception in stop_scan: {e}") - self.logger.logger.error(f"Error during scan termination: {e}") + print(f"ERROR: Exception in enhanced stop_scan: {e}") + self.logger.logger.error(f"Error during enhanced scan termination: {e}") traceback.print_exc() return False def get_scan_status(self) -> Dict[str, Any]: - """Get current scan status with forensic information.""" + """Get current scan status with enhanced task-based information.""" try: - return { + status = { 'status': self.status, 'target_domain': self.current_target, 'current_depth': self.current_depth, @@ -709,8 +560,33 @@ class Scanner: 'indicators_processed': self.indicators_processed, 'progress_percentage': self._calculate_progress(), 'enabled_providers': [provider.get_name() for provider in self.providers], - 'graph_statistics': self.graph.get_statistics() + 'graph_statistics': self.graph.get_statistics(), + 'scan_duration_seconds': self._calculate_scan_duration(), + 'scan_start_time': self.scan_start_time.isoformat() if self.scan_start_time else None, + 'scan_end_time': self.scan_end_time.isoformat() if self.scan_end_time else None } + + # Add task manager statistics if available + if self.task_manager: + progress_report = self.task_manager.get_progress_report() + status['task_statistics'] = progress_report['statistics'] + status['task_details'] = { + 'is_running': progress_report['is_running'], + 'worker_count': progress_report['worker_count'], + 'failed_tasks_count': len(progress_report['failed_tasks']) + } + + # Update indicators processed from task statistics + task_stats = progress_report['statistics'] + status['indicators_processed'] = task_stats['succeeded'] + task_stats['failed_permanent'] + + # Recalculate progress based on task completion + if task_stats['total_tasks'] > 0: + task_completion_rate = (task_stats['succeeded'] + task_stats['failed_permanent']) / task_stats['total_tasks'] + status['progress_percentage'] = min(100.0, task_completion_rate * 100.0) + + return status + except Exception as e: print(f"ERROR: Exception in get_scan_status: {e}") traceback.print_exc() @@ -724,7 +600,9 @@ class Scanner: 'indicators_processed': 0, 'progress_percentage': 0.0, 'enabled_providers': [], - 'graph_statistics': {} + 'graph_statistics': {}, + 'scan_duration_seconds': 0, + 'error': str(e) } def _calculate_progress(self) -> float: @@ -733,12 +611,21 @@ class Scanner: return 0.0 return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100) + def _calculate_scan_duration(self) -> float: + """Calculate scan duration in seconds.""" + if not self.scan_start_time: + return 0.0 + + end_time = self.scan_end_time or datetime.now(timezone.utc) + duration = (end_time - self.scan_start_time).total_seconds() + return round(duration, 2) + def get_graph_data(self) -> Dict[str, Any]: """Get current graph data for visualization.""" return self.graph.get_graph_data() def export_results(self) -> Dict[str, Any]: - """Export complete scan results with forensic audit trail.""" + """Export complete scan results with enhanced task-based audit trail.""" graph_data = self.graph.export_json() audit_trail = self.logger.export_audit_trail() 
provider_stats = {} @@ -752,24 +639,66 @@ class Scanner: 'final_status': self.status, 'total_indicators_processed': self.indicators_processed, 'enabled_providers': list(provider_stats.keys()), - 'session_id': self.session_id + 'session_id': self.session_id, + 'scan_id': self.current_scan_id, + 'scan_duration_seconds': self._calculate_scan_duration(), + 'scan_start_time': self.scan_start_time.isoformat() if self.scan_start_time else None, + 'scan_end_time': self.scan_end_time.isoformat() if self.scan_end_time else None }, 'graph_data': graph_data, 'forensic_audit': audit_trail, 'provider_statistics': provider_stats, 'scan_summary': self.logger.get_forensic_summary() } + + # Add task execution details if available + if self.task_manager: + progress_report = self.task_manager.get_progress_report() + export_data['task_execution'] = { + 'statistics': progress_report['statistics'], + 'failed_tasks': progress_report['failed_tasks'], + 'execution_summary': { + 'total_tasks_created': progress_report['statistics']['total_tasks'], + 'success_rate': progress_report['statistics']['completion_rate'], + 'average_retries': self._calculate_average_retries(progress_report) + } + } + return export_data + def _calculate_average_retries(self, progress_report: Dict[str, Any]) -> float: + """Calculate average retry attempts across all tasks.""" + if not self.task_manager or not hasattr(self.task_manager.task_queue, 'tasks'): + return 0.0 + + total_attempts = 0 + task_count = 0 + + for task in self.task_manager.task_queue.tasks.values(): + if hasattr(task, 'execution_history'): + total_attempts += len(task.execution_history) + task_count += 1 + + return round(total_attempts / task_count, 2) if task_count > 0 else 0.0 + def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]: - """Get statistics for all providers with forensic information.""" + """Get statistics for all providers with enhanced cache information.""" stats = {} for provider in self.providers: - stats[provider.get_name()] = provider.get_statistics() + provider_stats = provider.get_statistics() + # Add cache performance metrics + if hasattr(provider, 'cache'): + cache_performance = { + 'cache_enabled': True, + 'cache_directory': provider.cache.cache_dir, + 'cache_expiry_hours': provider.cache.cache_expiry / 3600 + } + provider_stats.update(cache_performance) + stats[provider.get_name()] = provider_stats return stats def get_provider_info(self) -> Dict[str, Dict[str, Any]]: - """Get information about all available providers.""" + """Get information about all available providers with enhanced details.""" info = {} provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') for filename in os.listdir(provider_dir): @@ -788,13 +717,25 @@ class Scanner: # Find the actual provider instance if it exists, to get live stats live_provider = next((p for p in self.providers if p.get_name() == provider_name), None) - info[provider_name] = { + provider_info = { 'display_name': temp_provider.get_display_name(), 'requires_api_key': temp_provider.requires_api_key(), 'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(), 'enabled': self.config.is_provider_enabled(provider_name), 'rate_limit': self.config.get_rate_limit(provider_name), + 'eligibility': temp_provider.get_eligibility() } + + # Add cache information if provider has caching + if live_provider and hasattr(live_provider, 'cache'): + provider_info['cache_info'] = { + 'cache_enabled': True, + 'cache_directory': live_provider.cache.cache_dir, + 
'cache_expiry_hours': live_provider.cache.cache_expiry / 3600 + } + + info[provider_name] = provider_info + except Exception as e: print(f"✗ Failed to get info for provider from {filename}: {e}") traceback.print_exc() diff --git a/core/session_config.py b/core/session_config.py index 3545b14..ed17099 100644 --- a/core/session_config.py +++ b/core/session_config.py @@ -1,6 +1,6 @@ """ -Per-session configuration management for DNSRecon. -Provides isolated configuration instances for each user session. +Enhanced per-session configuration management for DNSRecon. +Provides isolated configuration instances for each user session while supporting global caching. """ import os @@ -9,12 +9,12 @@ from typing import Dict, Optional class SessionConfig: """ - Session-specific configuration that inherits from global config - but maintains isolated API keys and provider settings. + Enhanced session-specific configuration that inherits from global config + but maintains isolated API keys and provider settings while supporting global caching. """ def __init__(self): - """Initialize session config with global defaults.""" + """Initialize enhanced session config with global cache support.""" # Copy all attributes from global config self.api_keys: Dict[str, Optional[str]] = { 'shodan': None @@ -26,20 +26,39 @@ class SessionConfig: self.max_concurrent_requests = 5 self.large_entity_threshold = 100 - # Rate limiting settings (per session) + # Enhanced rate limiting settings (per session) self.rate_limits = { 'crtsh': 60, 'shodan': 60, 'dns': 100 } - # Provider settings (per session) + # Enhanced provider settings (per session) self.enabled_providers = { 'crtsh': True, 'dns': True, 'shodan': False } + # Task-based execution settings + self.task_retry_settings = { + 'max_retries': 3, + 'base_backoff_seconds': 1.0, + 'max_backoff_seconds': 60.0, + 'retry_on_rate_limit': True, + 'retry_on_connection_error': True, + 'retry_on_timeout': True + } + + # Cache settings (global across all sessions) + self.cache_settings = { + 'enabled': True, + 'expiry_hours': 12, + 'cache_base_dir': '.cache', + 'per_provider_directories': True, + 'thread_safe_operations': True + } + # Logging configuration self.log_level = 'INFO' self.log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' @@ -48,6 +67,22 @@ class SessionConfig: self.flask_host = '127.0.0.1' self.flask_port = 5000 self.flask_debug = True + + # Session isolation settings + self.session_isolation = { + 'enforce_single_session_per_user': True, + 'consolidate_session_data_on_replacement': True, + 'user_fingerprinting_enabled': True, + 'session_timeout_minutes': 60 + } + + # Circuit breaker settings for provider reliability + self.circuit_breaker = { + 'enabled': True, + 'failure_threshold': 5, # Failures before opening circuit + 'recovery_timeout_seconds': 300, # 5 minutes before trying again + 'half_open_max_calls': 3 # Test calls when recovering + } def set_api_key(self, provider: str, api_key: str) -> bool: """ @@ -55,14 +90,19 @@ class SessionConfig: Args: provider: Provider name (shodan, etc) - api_key: API key string + api_key: API key string (empty string to clear) Returns: bool: True if key was set successfully """ if provider in self.api_keys: - self.api_keys[provider] = api_key - self.enabled_providers[provider] = True if api_key else False + # Handle clearing of API keys + if api_key and api_key.strip(): + self.api_keys[provider] = api_key.strip() + self.enabled_providers[provider] = True + else: + self.api_keys[provider] = None + 
self.enabled_providers[provider] = False return True return False @@ -102,19 +142,231 @@ class SessionConfig: """ return self.rate_limits.get(provider, 60) + def get_task_retry_config(self) -> Dict[str, any]: + """ + Get task retry configuration for this session. + + Returns: + Dictionary with retry settings + """ + return self.task_retry_settings.copy() + + def get_cache_config(self) -> Dict[str, any]: + """ + Get cache configuration (global settings). + + Returns: + Dictionary with cache settings + """ + return self.cache_settings.copy() + + def is_circuit_breaker_enabled(self) -> bool: + """Check if circuit breaker is enabled for provider reliability.""" + return self.circuit_breaker.get('enabled', True) + + def get_circuit_breaker_config(self) -> Dict[str, any]: + """Get circuit breaker configuration.""" + return self.circuit_breaker.copy() + + def update_provider_settings(self, provider_updates: Dict[str, Dict[str, any]]) -> bool: + """ + Update provider-specific settings in bulk. + + Args: + provider_updates: Dictionary of provider -> settings updates + + Returns: + bool: True if updates were applied successfully + """ + try: + for provider_name, updates in provider_updates.items(): + # Update rate limits + if 'rate_limit' in updates: + self.rate_limits[provider_name] = updates['rate_limit'] + + # Update enabled status + if 'enabled' in updates: + self.enabled_providers[provider_name] = updates['enabled'] + + # Update API key + if 'api_key' in updates: + self.set_api_key(provider_name, updates['api_key']) + + return True + except Exception as e: + print(f"Error updating provider settings: {e}") + return False + + def validate_configuration(self) -> Dict[str, any]: + """ + Validate the current configuration and return validation results. + + Returns: + Dictionary with validation results and any issues found + """ + validation_result = { + 'valid': True, + 'warnings': [], + 'errors': [], + 'provider_status': {} + } + + # Validate provider configurations + for provider_name, enabled in self.enabled_providers.items(): + provider_status = { + 'enabled': enabled, + 'has_api_key': bool(self.api_keys.get(provider_name)), + 'rate_limit': self.rate_limits.get(provider_name, 60) + } + + # Check for potential issues + if enabled and provider_name in ['shodan'] and not provider_status['has_api_key']: + validation_result['warnings'].append( + f"Provider '{provider_name}' is enabled but missing API key" + ) + + validation_result['provider_status'][provider_name] = provider_status + + # Validate task settings + if self.task_retry_settings['max_retries'] > 10: + validation_result['warnings'].append( + f"High retry count ({self.task_retry_settings['max_retries']}) may cause long delays" + ) + + # Validate concurrent settings + if self.max_concurrent_requests > 10: + validation_result['warnings'].append( + f"High concurrency ({self.max_concurrent_requests}) may overwhelm providers" + ) + + # Validate cache settings + if not os.path.exists(self.cache_settings['cache_base_dir']): + try: + os.makedirs(self.cache_settings['cache_base_dir'], exist_ok=True) + except Exception as e: + validation_result['errors'].append(f"Cannot create cache directory: {e}") + validation_result['valid'] = False + + return validation_result + def load_from_env(self): - """Load configuration from environment variables (only if not already set).""" + """Load configuration from environment variables with enhanced validation.""" + # Load API keys from environment if os.getenv('SHODAN_API_KEY') and not self.api_keys['shodan']: 
self.set_api_key('shodan', os.getenv('SHODAN_API_KEY')) + print("Loaded Shodan API key from environment") # Override default settings from environment self.default_recursion_depth = int(os.getenv('DEFAULT_RECURSION_DEPTH', '2')) - self.default_timeout = 30 - self.max_concurrent_requests = 5 + self.default_timeout = int(os.getenv('DEFAULT_TIMEOUT', '30')) + self.max_concurrent_requests = int(os.getenv('MAX_CONCURRENT_REQUESTS', '5')) + + # Load task retry settings from environment + if os.getenv('TASK_MAX_RETRIES'): + self.task_retry_settings['max_retries'] = int(os.getenv('TASK_MAX_RETRIES')) + + if os.getenv('TASK_BASE_BACKOFF'): + self.task_retry_settings['base_backoff_seconds'] = float(os.getenv('TASK_BASE_BACKOFF')) + + # Load cache settings from environment + if os.getenv('CACHE_EXPIRY_HOURS'): + self.cache_settings['expiry_hours'] = int(os.getenv('CACHE_EXPIRY_HOURS')) + + if os.getenv('CACHE_DISABLED'): + self.cache_settings['enabled'] = os.getenv('CACHE_DISABLED').lower() != 'true' + + # Load circuit breaker settings + if os.getenv('CIRCUIT_BREAKER_DISABLED'): + self.circuit_breaker['enabled'] = os.getenv('CIRCUIT_BREAKER_DISABLED').lower() != 'true' + + # Flask settings + self.flask_debug = os.getenv('FLASK_DEBUG', 'True').lower() == 'true' + + print("Enhanced configuration loaded from environment") + + def export_config_summary(self) -> Dict[str, any]: + """ + Export a summary of the current configuration for debugging/logging. + + Returns: + Dictionary with configuration summary (API keys redacted) + """ + return { + 'providers': { + provider: { + 'enabled': self.enabled_providers.get(provider, False), + 'has_api_key': bool(self.api_keys.get(provider)), + 'rate_limit': self.rate_limits.get(provider, 60) + } + for provider in self.enabled_providers.keys() + }, + 'task_settings': { + 'max_retries': self.task_retry_settings['max_retries'], + 'max_concurrent_requests': self.max_concurrent_requests, + 'large_entity_threshold': self.large_entity_threshold + }, + 'cache_settings': { + 'enabled': self.cache_settings['enabled'], + 'expiry_hours': self.cache_settings['expiry_hours'], + 'base_directory': self.cache_settings['cache_base_dir'] + }, + 'session_settings': { + 'isolation_enabled': self.session_isolation['enforce_single_session_per_user'], + 'consolidation_enabled': self.session_isolation['consolidate_session_data_on_replacement'], + 'timeout_minutes': self.session_isolation['session_timeout_minutes'] + }, + 'circuit_breaker': { + 'enabled': self.circuit_breaker['enabled'], + 'failure_threshold': self.circuit_breaker['failure_threshold'], + 'recovery_timeout': self.circuit_breaker['recovery_timeout_seconds'] + } + } def create_session_config() -> SessionConfig: - """Create a new session configuration instance.""" + """ + Create a new enhanced session configuration instance. 
+ + Returns: + Configured SessionConfig instance + """ session_config = SessionConfig() session_config.load_from_env() - return session_config \ No newline at end of file + + # Validate configuration and log any issues + validation = session_config.validate_configuration() + if validation['warnings']: + print("Configuration warnings:") + for warning in validation['warnings']: + print(f" WARNING: {warning}") + + if validation['errors']: + print("Configuration errors:") + for error in validation['errors']: + print(f" ERROR: {error}") + + if not validation['valid']: + raise ValueError("Configuration validation failed - see errors above") + + print(f"Enhanced session configuration created successfully") + return session_config + + +def create_test_config() -> SessionConfig: + """ + Create a test configuration with safe defaults for testing. + + Returns: + Test-safe SessionConfig instance + """ + test_config = SessionConfig() + + # Override settings for testing + test_config.max_concurrent_requests = 2 + test_config.task_retry_settings['max_retries'] = 1 + test_config.task_retry_settings['base_backoff_seconds'] = 0.1 + test_config.cache_settings['expiry_hours'] = 1 + test_config.session_isolation['session_timeout_minutes'] = 10 + + print("Test configuration created") + return test_config \ No newline at end of file diff --git a/core/session_manager.py b/core/session_manager.py index 06bc683..a27c981 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -5,37 +5,153 @@ import time import uuid import redis import pickle -from typing import Dict, Optional, Any, List +import hashlib +from typing import Dict, Optional, Any, List, Tuple from core.scanner import Scanner -# WARNING: Using pickle can be a security risk if the data source is not trusted. -# In this case, we are only serializing/deserializing our own trusted Scanner objects, -# which is generally safe. Do not unpickle data from untrusted sources. + +class UserIdentifier: + """Handles user identification for session management.""" + + @staticmethod + def generate_user_fingerprint(client_ip: str, user_agent: str) -> str: + """ + Generate a unique fingerprint for a user based on IP and User-Agent. + + Args: + client_ip: Client IP address + user_agent: User-Agent header value + + Returns: + Unique user fingerprint hash + """ + # Create deterministic user identifier + user_data = f"{client_ip}:{user_agent[:100]}" # Limit UA to 100 chars + fingerprint = hashlib.sha256(user_data.encode()).hexdigest()[:16] # 16 char fingerprint + return f"user_{fingerprint}" + + @staticmethod + def extract_request_info(request) -> Tuple[str, str]: + """ + Extract client IP and User-Agent from Flask request. + + Args: + request: Flask request object + + Returns: + Tuple of (client_ip, user_agent) + """ + # Handle proxy headers for real IP + client_ip = request.headers.get('X-Forwarded-For', '').split(',')[0].strip() + if not client_ip: + client_ip = request.headers.get('X-Real-IP', '') + if not client_ip: + client_ip = request.remote_addr or 'unknown' + + user_agent = request.headers.get('User-Agent', 'unknown') + + return client_ip, user_agent + + +class SessionConsolidator: + """Handles consolidation of session data when replacing sessions.""" + + @staticmethod + def consolidate_scanner_data(old_scanner: 'Scanner', new_scanner: 'Scanner') -> 'Scanner': + """ + Consolidate useful data from old scanner into new scanner. 
+ + Args: + old_scanner: Scanner from terminated session + new_scanner: New scanner instance + + Returns: + Enhanced new scanner with consolidated data + """ + try: + # Consolidate graph data if old scanner has valuable data + if old_scanner and hasattr(old_scanner, 'graph') and old_scanner.graph: + old_stats = old_scanner.graph.get_statistics() + if old_stats['basic_metrics']['total_nodes'] > 0: + print(f"Consolidating graph data: {old_stats['basic_metrics']['total_nodes']} nodes, {old_stats['basic_metrics']['total_edges']} edges") + + # Transfer nodes and edges to new scanner's graph + for node_id, node_data in old_scanner.graph.graph.nodes(data=True): + # Add node to new graph with all attributes + new_scanner.graph.graph.add_node(node_id, **node_data) + + for source, target, edge_data in old_scanner.graph.graph.edges(data=True): + # Add edge to new graph with all attributes + new_scanner.graph.graph.add_edge(source, target, **edge_data) + + # Update correlation index + if hasattr(old_scanner.graph, 'correlation_index'): + new_scanner.graph.correlation_index = old_scanner.graph.correlation_index.copy() + + # Update timestamps + new_scanner.graph.creation_time = old_scanner.graph.creation_time + new_scanner.graph.last_modified = old_scanner.graph.last_modified + + # Consolidate provider statistics + if old_scanner and hasattr(old_scanner, 'providers') and old_scanner.providers: + for old_provider in old_scanner.providers: + # Find matching provider in new scanner + matching_new_provider = None + for new_provider in new_scanner.providers: + if new_provider.get_name() == old_provider.get_name(): + matching_new_provider = new_provider + break + + if matching_new_provider: + # Transfer cumulative statistics + matching_new_provider.total_requests += old_provider.total_requests + matching_new_provider.successful_requests += old_provider.successful_requests + matching_new_provider.failed_requests += old_provider.failed_requests + matching_new_provider.total_relationships_found += old_provider.total_relationships_found + + # Transfer cache statistics if available + if hasattr(old_provider, 'cache_hits'): + matching_new_provider.cache_hits += getattr(old_provider, 'cache_hits', 0) + matching_new_provider.cache_misses += getattr(old_provider, 'cache_misses', 0) + + print(f"Consolidated {old_provider.get_name()} provider stats: {old_provider.total_requests} requests") + + return new_scanner + + except Exception as e: + print(f"Warning: Error during session consolidation: {e}") + return new_scanner + class SessionManager: """ - Manages multiple scanner instances for concurrent user sessions using Redis. + Manages single scanner session per user using Redis with user identification. + Enforces one active session per user for consistent state management. """ def __init__(self, session_timeout_minutes: int = 60): """ - Initialize session manager with a Redis backend. + Initialize session manager with Redis backend and user tracking. 
""" self.redis_client = redis.StrictRedis(db=0, decode_responses=False) self.session_timeout = session_timeout_minutes * 60 # Convert to seconds - self.lock = threading.Lock() # Lock for local operations, Redis handles atomic ops + self.lock = threading.Lock() + + # User identification helper + self.user_identifier = UserIdentifier() + self.consolidator = SessionConsolidator() # Start cleanup thread self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread.start() - print(f"SessionManager initialized with Redis backend and {session_timeout_minutes}min timeout") + print(f"SessionManager initialized with Redis backend, user tracking, and {session_timeout_minutes}min timeout") def __getstate__(self): """Prepare SessionManager for pickling.""" state = self.__dict__.copy() - # Exclude unpickleable attributes - Redis client and threading objects + # Exclude unpickleable attributes unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client'] for attr in unpicklable_attrs: if attr in state: @@ -53,67 +169,108 @@ class SessionManager: self.cleanup_thread.start() def _get_session_key(self, session_id: str) -> str: - """Generates the Redis key for a session.""" + """Generate Redis key for a session.""" return f"dnsrecon:session:{session_id}" + def _get_user_session_key(self, user_fingerprint: str) -> str: + """Generate Redis key for user -> session mapping.""" + return f"dnsrecon:user:{user_fingerprint}" + def _get_stop_signal_key(self, session_id: str) -> str: - """Generates the Redis key for a session's stop signal.""" + """Generate Redis key for session stop signal.""" return f"dnsrecon:stop:{session_id}" - def create_session(self) -> str: + def create_or_replace_user_session(self, client_ip: str, user_agent: str) -> str: """ - Create a new user session and store it in Redis. + Create new session for user, replacing any existing session. + Consolidates data from previous session if it exists. 
+ + Args: + client_ip: Client IP address + user_agent: User-Agent header + + Returns: + New session ID """ - session_id = str(uuid.uuid4()) - print(f"=== CREATING SESSION {session_id} IN REDIS ===") + user_fingerprint = self.user_identifier.generate_user_fingerprint(client_ip, user_agent) + new_session_id = str(uuid.uuid4()) + + print(f"=== CREATING/REPLACING SESSION FOR USER {user_fingerprint} ===") try: + # Check for existing user session + existing_session_id = self._get_user_current_session(user_fingerprint) + old_scanner = None + + if existing_session_id: + print(f"Found existing session {existing_session_id} for user {user_fingerprint}") + # Get old scanner data for consolidation + old_scanner = self.get_session(existing_session_id) + # Terminate old session + self._terminate_session_internal(existing_session_id, cleanup_user_mapping=False) + print(f"Terminated old session {existing_session_id}") + + # Create new session config and scanner from core.session_config import create_session_config session_config = create_session_config() - scanner_instance = Scanner(session_config=session_config) + new_scanner = Scanner(session_config=session_config) - # Set the session ID on the scanner for cross-process stop signal management - scanner_instance.session_id = session_id + # Set session ID on scanner for cross-process operations + new_scanner.session_id = new_session_id + # Consolidate data from old session if available + if old_scanner: + new_scanner = self.consolidator.consolidate_scanner_data(old_scanner, new_scanner) + print(f"Consolidated data from previous session") + + # Create session data session_data = { - 'scanner': scanner_instance, + 'scanner': new_scanner, 'config': session_config, 'created_at': time.time(), 'last_activity': time.time(), - 'status': 'active' + 'status': 'active', + 'user_fingerprint': user_fingerprint, + 'client_ip': client_ip, + 'user_agent': user_agent[:200] # Truncate for storage } - # Serialize the entire session data dictionary using pickle + # Store session in Redis + session_key = self._get_session_key(new_session_id) serialized_data = pickle.dumps(session_data) - - # Store in Redis - session_key = self._get_session_key(session_id) self.redis_client.setex(session_key, self.session_timeout, serialized_data) - # Initialize stop signal as False - stop_key = self._get_stop_signal_key(session_id) + # Update user -> session mapping + user_session_key = self._get_user_session_key(user_fingerprint) + self.redis_client.setex(user_session_key, self.session_timeout, new_session_id.encode('utf-8')) + + # Initialize stop signal + stop_key = self._get_stop_signal_key(new_session_id) self.redis_client.setex(stop_key, self.session_timeout, b'0') - print(f"Session {session_id} stored in Redis with stop signal initialized") - return session_id + print(f"Created new session {new_session_id} for user {user_fingerprint}") + return new_session_id except Exception as e: - print(f"ERROR: Failed to create session {session_id}: {e}") + print(f"ERROR: Failed to create session for user {user_fingerprint}: {e}") raise + def _get_user_current_session(self, user_fingerprint: str) -> Optional[str]: + """Get current session ID for a user.""" + try: + user_session_key = self._get_user_session_key(user_fingerprint) + session_id_bytes = self.redis_client.get(user_session_key) + if session_id_bytes: + return session_id_bytes.decode('utf-8') + return None + except Exception as e: + print(f"Error getting user session: {e}") + return None + def set_stop_signal(self, session_id: str) -> bool: 
- """ - Set the stop signal for a session (cross-process safe). - - Args: - session_id: Session identifier - - Returns: - bool: True if signal was set successfully - """ + """Set stop signal for session (cross-process safe).""" try: stop_key = self._get_stop_signal_key(session_id) - # Set stop signal to '1' with the same TTL as the session self.redis_client.setex(stop_key, self.session_timeout, b'1') print(f"Stop signal set for session {session_id}") return True @@ -122,15 +279,7 @@ class SessionManager: return False def is_stop_requested(self, session_id: str) -> bool: - """ - Check if stop is requested for a session (cross-process safe). - - Args: - session_id: Session identifier - - Returns: - bool: True if stop is requested - """ + """Check if stop is requested for session (cross-process safe).""" try: stop_key = self._get_stop_signal_key(session_id) value = self.redis_client.get(stop_key) @@ -140,15 +289,7 @@ class SessionManager: return False def clear_stop_signal(self, session_id: str) -> bool: - """ - Clear the stop signal for a session. - - Args: - session_id: Session identifier - - Returns: - bool: True if signal was cleared successfully - """ + """Clear stop signal for session.""" try: stop_key = self._get_stop_signal_key(session_id) self.redis_client.setex(stop_key, self.session_timeout, b'0') @@ -159,13 +300,13 @@ class SessionManager: return False def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]: - """Retrieves and deserializes session data from Redis.""" + """Retrieve and deserialize session data from Redis.""" try: session_key = self._get_session_key(session_id) serialized_data = self.redis_client.get(session_key) if serialized_data: session_data = pickle.loads(serialized_data) - # Ensure the scanner has the correct session ID for stop signal checking + # Ensure scanner has correct session ID if 'scanner' in session_data and session_data['scanner']: session_data['scanner'].session_id = session_id return session_data @@ -175,37 +316,32 @@ class SessionManager: return None def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool: - """ - Serializes and saves session data back to Redis with updated TTL. - - Returns: - bool: True if save was successful - """ + """Serialize and save session data to Redis with updated TTL.""" try: session_key = self._get_session_key(session_id) serialized_data = pickle.dumps(session_data) result = self.redis_client.setex(session_key, self.session_timeout, serialized_data) + + # Also refresh user mapping TTL if available + if 'user_fingerprint' in session_data: + user_session_key = self._get_user_session_key(session_data['user_fingerprint']) + self.redis_client.setex(user_session_key, self.session_timeout, session_id.encode('utf-8')) + return result except Exception as e: print(f"ERROR: Failed to save session data for {session_id}: {e}") return False def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool: - """ - Updates just the scanner object in a session with immediate persistence. 
- - Returns: - bool: True if update was successful - """ + """Update scanner object in session with immediate persistence.""" try: session_data = self._get_session_data(session_id) if session_data: - # Ensure scanner has the session ID + # Ensure scanner has session ID scanner.session_id = session_id session_data['scanner'] = scanner session_data['last_activity'] = time.time() - # Immediately save to Redis for GUI updates success = self._save_session_data(session_id, session_data) if success: print(f"Scanner state updated for session {session_id} (status: {scanner.status})") @@ -220,16 +356,7 @@ class SessionManager: return False def update_scanner_status(self, session_id: str, status: str) -> bool: - """ - Quickly update just the scanner status for immediate GUI feedback. - - Args: - session_id: Session identifier - status: New scanner status - - Returns: - bool: True if update was successful - """ + """Quickly update scanner status for immediate GUI feedback.""" try: session_data = self._get_session_data(session_id) if session_data and 'scanner' in session_data: @@ -248,9 +375,7 @@ class SessionManager: return False def get_session(self, session_id: str) -> Optional[Scanner]: - """ - Get scanner instance for a session from Redis with session ID management. - """ + """Get scanner instance for session with session ID management.""" if not session_id: return None @@ -265,21 +390,13 @@ class SessionManager: scanner = session_data.get('scanner') if scanner: - # Ensure the scanner can check the Redis-based stop signal + # Ensure scanner can check Redis-based stop signal scanner.session_id = session_id return scanner def get_session_status_only(self, session_id: str) -> Optional[str]: - """ - Get just the scanner status without full session retrieval (for performance). - - Args: - session_id: Session identifier - - Returns: - Scanner status string or None if not found - """ + """Get scanner status without full session retrieval (for performance).""" try: session_data = self._get_session_data(session_id) if session_data and 'scanner' in session_data: @@ -290,16 +407,18 @@ class SessionManager: return None def terminate_session(self, session_id: str) -> bool: - """ - Terminate a specific session in Redis with reliable stop signal and immediate status update. 
- """ + """Terminate specific session with reliable stop signal and immediate status update.""" + return self._terminate_session_internal(session_id, cleanup_user_mapping=True) + + def _terminate_session_internal(self, session_id: str, cleanup_user_mapping: bool = True) -> bool: + """Internal session termination with configurable user mapping cleanup.""" print(f"=== TERMINATING SESSION {session_id} ===") try: - # First, set the stop signal + # Set stop signal first self.set_stop_signal(session_id) - # Update scanner status to stopped immediately for GUI feedback + # Update scanner status immediately for GUI feedback self.update_scanner_status(session_id, 'stopped') session_data = self._get_session_data(session_id) @@ -310,16 +429,19 @@ class SessionManager: scanner = session_data.get('scanner') if scanner and scanner.status == 'running': print(f"Stopping scan for session: {session_id}") - # The scanner will check the Redis stop signal scanner.stop_scan() - - # Update the scanner state immediately self.update_session_scanner(session_id, scanner) - # Wait a moment for graceful shutdown + # Wait for graceful shutdown time.sleep(0.5) - # Delete session data and stop signal from Redis + # Clean up user mapping if requested + if cleanup_user_mapping and 'user_fingerprint' in session_data: + user_session_key = self._get_user_session_key(session_data['user_fingerprint']) + self.redis_client.delete(user_session_key) + print(f"Cleaned up user mapping for {session_data['user_fingerprint']}") + + # Delete session data and stop signal session_key = self._get_session_key(session_id) stop_key = self._get_stop_signal_key(session_id) self.redis_client.delete(session_key) @@ -333,22 +455,30 @@ class SessionManager: return False def _cleanup_loop(self) -> None: - """ - Background thread to cleanup inactive sessions and orphaned stop signals. 
- """ + """Background thread to cleanup inactive sessions and orphaned signals.""" while True: try: # Clean up orphaned stop signals stop_keys = self.redis_client.keys("dnsrecon:stop:*") for stop_key in stop_keys: - # Extract session ID from stop key session_id = stop_key.decode('utf-8').split(':')[-1] session_key = self._get_session_key(session_id) - # If session doesn't exist but stop signal does, clean it up if not self.redis_client.exists(session_key): self.redis_client.delete(stop_key) print(f"Cleaned up orphaned stop signal for session {session_id}") + + # Clean up orphaned user mappings + user_keys = self.redis_client.keys("dnsrecon:user:*") + for user_key in user_keys: + session_id_bytes = self.redis_client.get(user_key) + if session_id_bytes: + session_id = session_id_bytes.decode('utf-8') + session_key = self._get_session_key(session_id) + + if not self.redis_client.exists(session_key): + self.redis_client.delete(user_key) + print(f"Cleaned up orphaned user mapping for session {session_id}") except Exception as e: print(f"Error in cleanup loop: {e}") @@ -369,6 +499,8 @@ class SessionManager: scanner = session_data.get('scanner') sessions.append({ 'session_id': session_id, + 'user_fingerprint': session_data.get('user_fingerprint', 'unknown'), + 'client_ip': session_data.get('client_ip', 'unknown'), 'created_at': session_data.get('created_at'), 'last_activity': session_data.get('last_activity'), 'scanner_status': scanner.status if scanner else 'unknown', @@ -384,9 +516,11 @@ class SessionManager: """Get session manager statistics.""" try: session_keys = self.redis_client.keys("dnsrecon:session:*") + user_keys = self.redis_client.keys("dnsrecon:user:*") stop_keys = self.redis_client.keys("dnsrecon:stop:*") active_sessions = len(session_keys) + unique_users = len(user_keys) running_scans = 0 for session_key in session_keys: @@ -397,16 +531,46 @@ class SessionManager: return { 'total_active_sessions': active_sessions, + 'unique_users': unique_users, 'running_scans': running_scans, - 'total_stop_signals': len(stop_keys) + 'total_stop_signals': len(stop_keys), + 'average_sessions_per_user': round(active_sessions / unique_users, 2) if unique_users > 0 else 0 } except Exception as e: print(f"ERROR: Failed to get statistics: {e}") return { 'total_active_sessions': 0, + 'unique_users': 0, 'running_scans': 0, - 'total_stop_signals': 0 + 'total_stop_signals': 0, + 'average_sessions_per_user': 0 } + def get_session_info(self, session_id: str) -> Dict[str, Any]: + """Get detailed information about a specific session.""" + try: + session_data = self._get_session_data(session_id) + if not session_data: + return {'error': 'Session not found'} + + scanner = session_data.get('scanner') + + return { + 'session_id': session_id, + 'user_fingerprint': session_data.get('user_fingerprint', 'unknown'), + 'client_ip': session_data.get('client_ip', 'unknown'), + 'user_agent': session_data.get('user_agent', 'unknown'), + 'created_at': session_data.get('created_at'), + 'last_activity': session_data.get('last_activity'), + 'status': session_data.get('status'), + 'scanner_status': scanner.status if scanner else 'unknown', + 'current_target': scanner.current_target if scanner else None, + 'session_age_minutes': round((time.time() - session_data.get('created_at', time.time())) / 60, 1) + } + except Exception as e: + print(f"ERROR: Failed to get session info for {session_id}: {e}") + return {'error': f'Failed to get session info: {str(e)}'} + + # Global session manager instance session_manager = 
SessionManager(session_timeout_minutes=60) \ No newline at end of file diff --git a/core/task_manager.py b/core/task_manager.py new file mode 100644 index 0000000..b9c0d7f --- /dev/null +++ b/core/task_manager.py @@ -0,0 +1,564 @@ +# dnsrecon/core/task_manager.py + +import threading +import time +import uuid +from enum import Enum +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Any, Set +from datetime import datetime, timezone, timedelta +from collections import deque + +from utils.helpers import _is_valid_ip, _is_valid_domain + + +class TaskStatus(Enum): + """Enumeration of task execution statuses.""" + PENDING = "pending" + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED_RETRYING = "failed_retrying" + FAILED_PERMANENT = "failed_permanent" + CANCELLED = "cancelled" + + +class TaskType(Enum): + """Enumeration of task types for provider queries.""" + DOMAIN_QUERY = "domain_query" + IP_QUERY = "ip_query" + GRAPH_UPDATE = "graph_update" + + +@dataclass +class TaskResult: + """Result of a task execution.""" + success: bool + data: Optional[Any] = None + error: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ReconTask: + """Represents a single reconnaissance task with retry logic.""" + task_id: str + task_type: TaskType + target: str + provider_name: str + depth: int + status: TaskStatus = TaskStatus.PENDING + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + # Retry configuration + max_retries: int = 3 + current_retry: int = 0 + base_backoff_seconds: float = 1.0 + max_backoff_seconds: float = 60.0 + + # Execution tracking + last_attempt_at: Optional[datetime] = None + next_retry_at: Optional[datetime] = None + execution_history: List[Dict[str, Any]] = field(default_factory=list) + + # Results + result: Optional[TaskResult] = None + + def __post_init__(self): + """Initialize additional fields after creation.""" + if not self.task_id: + self.task_id = str(uuid.uuid4())[:8] + + def calculate_next_retry_time(self) -> datetime: + """Calculate next retry time with exponential backoff and jitter.""" + if self.current_retry >= self.max_retries: + return None + + # Exponential backoff with jitter + backoff_time = min( + self.max_backoff_seconds, + self.base_backoff_seconds * (2 ** self.current_retry) + ) + + # Add jitter (±25%) + jitter = backoff_time * 0.25 * (0.5 - hash(self.task_id) % 1000 / 1000.0) + final_backoff = max(self.base_backoff_seconds, backoff_time + jitter) + + return datetime.now(timezone.utc) + timedelta(seconds=final_backoff) + + def should_retry(self) -> bool: + """Determine if task should be retried based on status and retry count.""" + if self.status != TaskStatus.FAILED_RETRYING: + return False + if self.current_retry >= self.max_retries: + return False + if self.next_retry_at and datetime.now(timezone.utc) < self.next_retry_at: + return False + return True + + def mark_failed(self, error: str, metadata: Dict[str, Any] = None): + """Mark task as failed and prepare for retry or permanent failure.""" + self.current_retry += 1 + self.last_attempt_at = datetime.now(timezone.utc) + + # Record execution history + execution_record = { + 'attempt': self.current_retry, + 'timestamp': self.last_attempt_at.isoformat(), + 'error': error, + 'metadata': metadata or {} + } + self.execution_history.append(execution_record) + + if self.current_retry >= self.max_retries: + self.status = TaskStatus.FAILED_PERMANENT + self.result = TaskResult(success=False, 
error=f"Permanent failure after {self.max_retries} attempts: {error}") + else: + self.status = TaskStatus.FAILED_RETRYING + self.next_retry_at = self.calculate_next_retry_time() + + def mark_succeeded(self, data: Any = None, metadata: Dict[str, Any] = None): + """Mark task as successfully completed.""" + self.status = TaskStatus.SUCCEEDED + self.last_attempt_at = datetime.now(timezone.utc) + self.result = TaskResult(success=True, data=data, metadata=metadata or {}) + + # Record successful execution + execution_record = { + 'attempt': self.current_retry + 1, + 'timestamp': self.last_attempt_at.isoformat(), + 'success': True, + 'metadata': metadata or {} + } + self.execution_history.append(execution_record) + + def get_summary(self) -> Dict[str, Any]: + """Get task summary for progress reporting.""" + return { + 'task_id': self.task_id, + 'task_type': self.task_type.value, + 'target': self.target, + 'provider': self.provider_name, + 'status': self.status.value, + 'current_retry': self.current_retry, + 'max_retries': self.max_retries, + 'created_at': self.created_at.isoformat(), + 'last_attempt_at': self.last_attempt_at.isoformat() if self.last_attempt_at else None, + 'next_retry_at': self.next_retry_at.isoformat() if self.next_retry_at else None, + 'total_attempts': len(self.execution_history), + 'has_result': self.result is not None + } + + +class TaskQueue: + """Thread-safe task queue with retry logic and priority handling.""" + + def __init__(self, max_concurrent_tasks: int = 5): + """Initialize task queue.""" + self.max_concurrent_tasks = max_concurrent_tasks + self.tasks: Dict[str, ReconTask] = {} + self.pending_queue = deque() + self.retry_queue = deque() + self.running_tasks: Set[str] = set() + + self._lock = threading.Lock() + self._stop_event = threading.Event() + + def __getstate__(self): + """Prepare TaskQueue for pickling by excluding unpicklable objects.""" + state = self.__dict__.copy() + # Exclude the unpickleable '_lock' and '_stop_event' attributes + if '_lock' in state: + del state['_lock'] + if '_stop_event' in state: + del state['_stop_event'] + return state + + def __setstate__(self, state): + """Restore TaskQueue after unpickling by reconstructing threading objects.""" + self.__dict__.update(state) + # Re-initialize the '_lock' and '_stop_event' attributes + self._lock = threading.Lock() + self._stop_event = threading.Event() + + def add_task(self, task: ReconTask) -> str: + """Add task to queue.""" + with self._lock: + self.tasks[task.task_id] = task + self.pending_queue.append(task.task_id) + print(f"Added task {task.task_id}: {task.provider_name} query for {task.target}") + return task.task_id + + def get_next_ready_task(self) -> Optional[ReconTask]: + """Get next task ready for execution.""" + with self._lock: + # Check if we have room for more concurrent tasks + if len(self.running_tasks) >= self.max_concurrent_tasks: + return None + + # First priority: retry queue (tasks ready for retry) + while self.retry_queue: + task_id = self.retry_queue.popleft() + if task_id in self.tasks: + task = self.tasks[task_id] + if task.should_retry(): + task.status = TaskStatus.RUNNING + self.running_tasks.add(task_id) + print(f"Retrying task {task_id} (attempt {task.current_retry + 1})") + return task + + # Second priority: pending queue (new tasks) + while self.pending_queue: + task_id = self.pending_queue.popleft() + if task_id in self.tasks: + task = self.tasks[task_id] + if task.status == TaskStatus.PENDING: + task.status = TaskStatus.RUNNING + self.running_tasks.add(task_id) + 
print(f"Starting task {task_id}") + return task + + return None + + def complete_task(self, task_id: str, success: bool, data: Any = None, + error: str = None, metadata: Dict[str, Any] = None): + """Mark task as completed (success or failure).""" + with self._lock: + if task_id not in self.tasks: + return + + task = self.tasks[task_id] + self.running_tasks.discard(task_id) + + if success: + task.mark_succeeded(data=data, metadata=metadata) + print(f"Task {task_id} succeeded") + else: + task.mark_failed(error or "Unknown error", metadata=metadata) + if task.status == TaskStatus.FAILED_RETRYING: + self.retry_queue.append(task_id) + print(f"Task {task_id} failed, scheduled for retry at {task.next_retry_at}") + else: + print(f"Task {task_id} permanently failed after {task.current_retry} attempts") + + def cancel_all_tasks(self): + """Cancel all pending and running tasks.""" + with self._lock: + self._stop_event.set() + for task in self.tasks.values(): + if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.FAILED_RETRYING]: + task.status = TaskStatus.CANCELLED + self.pending_queue.clear() + self.retry_queue.clear() + self.running_tasks.clear() + print("All tasks cancelled") + + def is_complete(self) -> bool: + """Check if all tasks are complete (succeeded, permanently failed, or cancelled).""" + with self._lock: + for task in self.tasks.values(): + if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.FAILED_RETRYING]: + return False + return True + + def get_statistics(self) -> Dict[str, Any]: + """Get queue statistics.""" + with self._lock: + stats = { + 'total_tasks': len(self.tasks), + 'pending': len(self.pending_queue), + 'running': len(self.running_tasks), + 'retry_queue': len(self.retry_queue), + 'succeeded': 0, + 'failed_permanent': 0, + 'cancelled': 0, + 'failed_retrying': 0 + } + + for task in self.tasks.values(): + if task.status == TaskStatus.SUCCEEDED: + stats['succeeded'] += 1 + elif task.status == TaskStatus.FAILED_PERMANENT: + stats['failed_permanent'] += 1 + elif task.status == TaskStatus.CANCELLED: + stats['cancelled'] += 1 + elif task.status == TaskStatus.FAILED_RETRYING: + stats['failed_retrying'] += 1 + + stats['completion_rate'] = (stats['succeeded'] / stats['total_tasks'] * 100) if stats['total_tasks'] > 0 else 0 + stats['is_complete'] = self.is_complete() + + return stats + + def get_task_summaries(self) -> List[Dict[str, Any]]: + """Get summaries of all tasks for detailed progress reporting.""" + with self._lock: + return [task.get_summary() for task in self.tasks.values()] + + def get_failed_tasks(self) -> List[ReconTask]: + """Get all permanently failed tasks for analysis.""" + with self._lock: + return [task for task in self.tasks.values() if task.status == TaskStatus.FAILED_PERMANENT] + + +class TaskExecutor: + """Executes reconnaissance tasks using providers.""" + + def __init__(self, providers: List, graph_manager, logger): + """Initialize task executor.""" + self.providers = {provider.get_name(): provider for provider in providers} + self.graph = graph_manager + self.logger = logger + + def execute_task(self, task: ReconTask) -> TaskResult: + """ + Execute a single reconnaissance task. 
+ + Args: + task: Task to execute + + Returns: + TaskResult with success/failure information + """ + try: + print(f"Executing task {task.task_id}: {task.provider_name} query for {task.target}") + + provider = self.providers.get(task.provider_name) + if not provider: + return TaskResult( + success=False, + error=f"Provider {task.provider_name} not available" + ) + + if not provider.is_available(): + return TaskResult( + success=False, + error=f"Provider {task.provider_name} is not available (missing API key or configuration)" + ) + + # Execute provider query based on task type + if task.task_type == TaskType.DOMAIN_QUERY: + if not _is_valid_domain(task.target): + return TaskResult(success=False, error=f"Invalid domain: {task.target}") + + relationships = provider.query_domain(task.target) + + elif task.task_type == TaskType.IP_QUERY: + if not _is_valid_ip(task.target): + return TaskResult(success=False, error=f"Invalid IP: {task.target}") + + relationships = provider.query_ip(task.target) + + else: + return TaskResult(success=False, error=f"Unsupported task type: {task.task_type}") + + # Process results and update graph + new_targets = set() + relationships_added = 0 + + for source, target, rel_type, confidence, raw_data in relationships: + # Add nodes to graph + from core.graph_manager import NodeType + + if _is_valid_ip(target): + self.graph.add_node(target, NodeType.IP) + new_targets.add(target) + elif target.startswith('AS') and target[2:].isdigit(): + self.graph.add_node(target, NodeType.ASN) + elif _is_valid_domain(target): + self.graph.add_node(target, NodeType.DOMAIN) + new_targets.add(target) + + # Add edge to graph + if self.graph.add_edge(source, target, rel_type, confidence, task.provider_name, raw_data): + relationships_added += 1 + + # Log forensic information + self.logger.logger.info( + f"Task {task.task_id} completed: {len(relationships)} relationships found, " + f"{relationships_added} added to graph, {len(new_targets)} new targets" + ) + + return TaskResult( + success=True, + data={ + 'relationships': relationships, + 'new_targets': list(new_targets), + 'relationships_added': relationships_added + }, + metadata={ + 'provider': task.provider_name, + 'target': task.target, + 'depth': task.depth, + 'execution_time': datetime.now(timezone.utc).isoformat() + } + ) + + except Exception as e: + error_msg = f"Task execution failed: {str(e)}" + print(f"ERROR: {error_msg} for task {task.task_id}") + self.logger.logger.error(error_msg) + + return TaskResult( + success=False, + error=error_msg, + metadata={ + 'provider': task.provider_name, + 'target': task.target, + 'exception_type': type(e).__name__ + } + ) + + +class TaskManager: + """High-level task management for reconnaissance scans.""" + + def __init__(self, providers: List, graph_manager, logger, max_concurrent_tasks: int = 5): + """Initialize task manager.""" + self.task_queue = TaskQueue(max_concurrent_tasks) + self.task_executor = TaskExecutor(providers, graph_manager, logger) + self.logger = logger + + # Execution control + self._stop_event = threading.Event() + self._execution_threads: List[threading.Thread] = [] + self._is_running = False + + def create_provider_tasks(self, target: str, depth: int, providers: List) -> List[str]: + """ + Create tasks for querying all eligible providers for a target. 
+ + Args: + target: Domain or IP to query + depth: Current recursion depth + providers: List of available providers + + Returns: + List of created task IDs + """ + task_ids = [] + is_ip = _is_valid_ip(target) + target_key = 'ips' if is_ip else 'domains' + task_type = TaskType.IP_QUERY if is_ip else TaskType.DOMAIN_QUERY + + for provider in providers: + if provider.get_eligibility().get(target_key) and provider.is_available(): + task = ReconTask( + task_id=str(uuid.uuid4())[:8], + task_type=task_type, + target=target, + provider_name=provider.get_name(), + depth=depth, + max_retries=3 # Configure retries per task type/provider + ) + + task_id = self.task_queue.add_task(task) + task_ids.append(task_id) + + return task_ids + + def start_execution(self, max_workers: int = 3): + """Start task execution with specified number of worker threads.""" + if self._is_running: + print("Task execution already running") + return + + self._is_running = True + self._stop_event.clear() + + print(f"Starting task execution with {max_workers} workers") + + for i in range(max_workers): + worker_thread = threading.Thread( + target=self._worker_loop, + name=f"TaskWorker-{i+1}", + daemon=True + ) + worker_thread.start() + self._execution_threads.append(worker_thread) + + def stop_execution(self): + """Stop task execution and cancel all tasks.""" + print("Stopping task execution") + self._stop_event.set() + self.task_queue.cancel_all_tasks() + self._is_running = False + + # Wait for worker threads to finish + for thread in self._execution_threads: + thread.join(timeout=5.0) + + self._execution_threads.clear() + print("Task execution stopped") + + def _worker_loop(self): + """Worker thread loop for executing tasks.""" + thread_name = threading.current_thread().name + print(f"{thread_name} started") + + while not self._stop_event.is_set(): + try: + # Get next task to execute + task = self.task_queue.get_next_ready_task() + + if task is None: + # No tasks ready, check if we should exit + if self.task_queue.is_complete() or self._stop_event.is_set(): + break + time.sleep(0.1) # Brief sleep before checking again + continue + + # Execute the task + result = self.task_executor.execute_task(task) + + # Complete the task in queue + self.task_queue.complete_task( + task.task_id, + success=result.success, + data=result.data, + error=result.error, + metadata=result.metadata + ) + + except Exception as e: + print(f"ERROR: Worker {thread_name} encountered error: {e}") + # Continue running even if individual task fails + continue + + print(f"{thread_name} finished") + + def wait_for_completion(self, timeout_seconds: int = 300) -> bool: + """ + Wait for all tasks to complete. 
+ + Args: + timeout_seconds: Maximum time to wait + + Returns: + True if all tasks completed, False if timeout + """ + start_time = time.time() + + while time.time() - start_time < timeout_seconds: + if self.task_queue.is_complete(): + return True + + if self._stop_event.is_set(): + return False + + time.sleep(1.0) # Check every second + + print(f"Timeout waiting for task completion after {timeout_seconds} seconds") + return False + + def get_progress_report(self) -> Dict[str, Any]: + """Get detailed progress report for UI updates.""" + stats = self.task_queue.get_statistics() + failed_tasks = self.task_queue.get_failed_tasks() + + return { + 'statistics': stats, + 'failed_tasks': [task.get_summary() for task in failed_tasks], + 'is_running': self._is_running, + 'worker_count': len(self._execution_threads), + 'detailed_tasks': self.task_queue.get_task_summaries() if stats['total_tasks'] < 50 else [] # Limit detail for performance + } \ No newline at end of file diff --git a/providers/base_provider.py b/providers/base_provider.py index 4d0f8e1..3627af3 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -5,14 +5,16 @@ import requests import threading import os import json +import hashlib from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional, Tuple +from datetime import datetime, timezone from core.logger import get_forensic_logger class RateLimiter: - """Simple rate limiter for API calls.""" + """Thread-safe rate limiter for API calls.""" def __init__(self, requests_per_minute: int): """ @@ -24,36 +26,152 @@ class RateLimiter: self.requests_per_minute = requests_per_minute self.min_interval = 60.0 / requests_per_minute self.last_request_time = 0 + self._lock = threading.Lock() def __getstate__(self): """RateLimiter is fully picklable, return full state.""" - return self.__dict__.copy() + state = self.__dict__.copy() + # Exclude unpickleable lock + if '_lock' in state: + del state['_lock'] + return state def __setstate__(self, state): """Restore RateLimiter state.""" self.__dict__.update(state) + self._lock = threading.Lock() def wait_if_needed(self) -> None: """Wait if necessary to respect rate limits.""" - current_time = time.time() - time_since_last = current_time - self.last_request_time + with self._lock: + current_time = time.time() + time_since_last = current_time - self.last_request_time - if time_since_last < self.min_interval: - sleep_time = self.min_interval - time_since_last - time.sleep(sleep_time) + if time_since_last < self.min_interval: + sleep_time = self.min_interval - time_since_last + time.sleep(sleep_time) - self.last_request_time = time.time() + self.last_request_time = time.time() + + +class ProviderCache: + """Thread-safe global cache for provider queries.""" + + def __init__(self, provider_name: str, cache_expiry_hours: int = 12): + """ + Initialize provider-specific cache. 
+ + Args: + provider_name: Name of the provider for cache directory + cache_expiry_hours: Cache expiry time in hours + """ + self.provider_name = provider_name + self.cache_expiry = cache_expiry_hours * 3600 # Convert to seconds + self.cache_dir = os.path.join('.cache', provider_name) + self._lock = threading.Lock() + + # Ensure cache directory exists with thread-safe creation + os.makedirs(self.cache_dir, exist_ok=True) + + def _generate_cache_key(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> str: + """Generate unique cache key for request.""" + cache_data = f"{method}:{url}:{json.dumps(params or {}, sort_keys=True)}" + return hashlib.md5(cache_data.encode()).hexdigest() + ".json" + + def get_cached_response(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> Optional[requests.Response]: + """ + Retrieve cached response if available and not expired. + + Returns: + Cached Response object or None if cache miss/expired + """ + cache_key = self._generate_cache_key(method, url, params) + cache_path = os.path.join(self.cache_dir, cache_key) + + with self._lock: + if not os.path.exists(cache_path): + return None + + # Check if cache is expired + cache_age = time.time() - os.path.getmtime(cache_path) + if cache_age >= self.cache_expiry: + try: + os.remove(cache_path) + except OSError: + pass # File might have been removed by another thread + return None + + try: + with open(cache_path, 'r', encoding='utf-8') as f: + cached_data = json.load(f) + + # Reconstruct Response object + response = requests.Response() + response.status_code = cached_data['status_code'] + response._content = cached_data['content'].encode('utf-8') + response.headers.update(cached_data['headers']) + + return response + + except (json.JSONDecodeError, KeyError, IOError) as e: + # Cache file corrupted, remove it + try: + os.remove(cache_path) + except OSError: + pass + return None + + def cache_response(self, method: str, url: str, params: Optional[Dict[str, Any]], + response: requests.Response) -> bool: + """ + Cache successful response to disk. + + Returns: + True if cached successfully, False otherwise + """ + if response.status_code != 200: + return False + + cache_key = self._generate_cache_key(method, url, params) + cache_path = os.path.join(self.cache_dir, cache_key) + + with self._lock: + try: + cache_data = { + 'status_code': response.status_code, + 'content': response.text, + 'headers': dict(response.headers), + 'cached_at': datetime.now(timezone.utc).isoformat() + } + + # Write to temporary file first, then rename for atomic operation + temp_path = cache_path + '.tmp' + with open(temp_path, 'w', encoding='utf-8') as f: + json.dump(cache_data, f) + + # Atomic rename to prevent partial cache files + os.rename(temp_path, cache_path) + return True + + except (IOError, OSError) as e: + # Clean up temp file if it exists + try: + if os.path.exists(temp_path): + os.remove(temp_path) + except OSError: + pass + return False class BaseProvider(ABC): """ Abstract base class for all DNSRecon data providers. - Now supports session-specific configuration. + Now supports global provider-specific caching and session-specific configuration. """ def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None): """ - Initialize base provider with session-specific configuration. + Initialize base provider with global caching and session-specific configuration. 
Args: name: Provider name for logging @@ -80,28 +198,25 @@ class BaseProvider(ABC): self.logger = get_forensic_logger() self._stop_event = None - # Caching configuration (per session) - self.cache_dir = f'.cache/{id(self.config)}' # Unique cache per session config - self.cache_expiry = 12 * 3600 # 12 hours in seconds - if not os.path.exists(self.cache_dir): - os.makedirs(self.cache_dir) + # GLOBAL provider-specific caching (not session-based) + self.cache = ProviderCache(name, cache_expiry_hours=12) # Statistics (per provider instance) self.total_requests = 0 self.successful_requests = 0 self.failed_requests = 0 self.total_relationships_found = 0 + self.cache_hits = 0 + self.cache_misses = 0 - print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)") + print(f"Initialized {name} provider with global cache and session config (rate: {actual_rate_limit}/min)") def __getstate__(self): """Prepare BaseProvider for pickling by excluding unpicklable objects.""" state = self.__dict__.copy() # Exclude the unpickleable '_local' attribute and stop event - unpicklable_attrs = ['_local', '_stop_event'] - for attr in unpicklable_attrs: - if attr in state: - del state[attr] + state['_local'] = None + state['_stop_event'] = None return state def __setstate__(self, state): @@ -116,7 +231,7 @@ class BaseProvider(ABC): if not hasattr(self._local, 'session'): self._local.session = requests.Session() self._local.session.headers.update({ - 'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)' + 'User-Agent': 'DNSRecon/2.0 (Passive Reconnaissance Tool)' }) return self._local.session @@ -177,37 +292,28 @@ class BaseProvider(ABC): target_indicator: str = "", max_retries: int = 3) -> Optional[requests.Response]: """ - Make a rate-limited HTTP request with aggressive stop signal handling. - Terminates immediately when stop is requested, including during retries. + Make a rate-limited HTTP request with global caching and aggressive stop signal handling. 
""" # Check for cancellation before starting if self._is_stop_requested(): print(f"Request cancelled before start: {url}") return None - # Create a unique cache key - cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json" - cache_path = os.path.join(self.cache_dir, cache_key) - - # Check cache - if os.path.exists(cache_path): - cache_age = time.time() - os.path.getmtime(cache_path) - if cache_age < self.cache_expiry: - print(f"Returning cached response for: {url}") - with open(cache_path, 'r') as f: - cached_data = json.load(f) - response = requests.Response() - response.status_code = cached_data['status_code'] - response._content = cached_data['content'].encode('utf-8') - response.headers = cached_data['headers'] - return response + # Check global cache first + cached_response = self.cache.get_cached_response(method, url, params) + if cached_response is not None: + print(f"Cache hit for {self.name}: {url}") + self.cache_hits += 1 + return cached_response + + self.cache_misses += 1 # Determine effective max_retries based on stop signal effective_max_retries = 0 if self._is_stop_requested() else max_retries last_exception = None for attempt in range(effective_max_retries + 1): - # AGGRESSIVE: Check for cancellation before each attempt + # Check for cancellation before each attempt if self._is_stop_requested(): print(f"Request cancelled during attempt {attempt + 1}: {url}") return None @@ -217,7 +323,7 @@ class BaseProvider(ABC): print(f"Request cancelled during rate limiting: {url}") return None - # AGGRESSIVE: Final check before making HTTP request + # Final check before making HTTP request if self._is_stop_requested(): print(f"Request cancelled before HTTP call: {url}") return None @@ -236,11 +342,8 @@ class BaseProvider(ABC): print(f"Making {method} request to: {url} (attempt {attempt + 1})") - # AGGRESSIVE: Use much shorter timeout if termination is requested - request_timeout = self.timeout - if self._is_stop_requested(): - request_timeout = 2 # Max 2 seconds if termination requested - print(f"Stop requested - using short timeout: {request_timeout}s") + # Use shorter timeout if termination is requested + request_timeout = 2 if self._is_stop_requested() else self.timeout # Make request if method.upper() == "GET": @@ -276,13 +379,9 @@ class BaseProvider(ABC): error=None, target_indicator=target_indicator ) - # Cache the successful response to disk - with open(cache_path, 'w') as f: - json.dump({ - 'status_code': response.status_code, - 'content': response.text, - 'headers': dict(response.headers) - }, f) + + # Cache the successful response globally + self.cache.cache_response(method, url, params, response) return response except requests.exceptions.RequestException as e: @@ -291,23 +390,21 @@ class BaseProvider(ABC): print(f"Request failed (attempt {attempt + 1}): {error}") last_exception = e - # AGGRESSIVE: Immediately abort retries if stop requested + # Immediately abort retries if stop requested if self._is_stop_requested(): print(f"Stop requested - aborting retries for: {url}") break - # Check if we should retry (but only if stop not requested) + # Check if we should retry if attempt < effective_max_retries and self._should_retry(e): - # Use a longer, more respectful backoff for 429 errors + # Exponential backoff with jitter for 429 errors if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429: - # Start with a 10-second backoff and increase exponentially - backoff_time = 10 * (2 ** attempt) + 
backoff_time = min(60, 10 * (2 ** attempt)) print(f"Rate limit hit. Retrying in {backoff_time} seconds...") else: - backoff_time = min(1.0, (2 ** attempt) * 0.5) # Shorter backoff for other errors + backoff_time = min(2.0, (2 ** attempt) * 0.5) print(f"Retrying in {backoff_time} seconds...") - # AGGRESSIVE: Much shorter backoff and more frequent checking if not self._sleep_with_cancellation_check(backoff_time): print(f"Stop requested during backoff - aborting: {url}") return None @@ -348,7 +445,6 @@ class BaseProvider(ABC): return True return False - def _wait_with_cancellation_check(self) -> bool: """ Wait for rate limiting while aggressively checking for cancellation. @@ -447,7 +543,7 @@ class BaseProvider(ABC): def get_statistics(self) -> Dict[str, Any]: """ - Get provider statistics. + Get provider statistics including cache performance. Returns: Dictionary containing provider performance metrics @@ -459,5 +555,8 @@ class BaseProvider(ABC): 'failed_requests': self.failed_requests, 'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0, 'relationships_found': self.total_relationships_found, - 'rate_limit': self.rate_limiter.requests_per_minute + 'rate_limit': self.rate_limiter.requests_per_minute, + 'cache_hits': self.cache_hits, + 'cache_misses': self.cache_misses, + 'cache_hit_rate': (self.cache_hits / (self.cache_hits + self.cache_misses) * 100) if (self.cache_hits + self.cache_misses) > 0 else 0 } \ No newline at end of file
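For reference, a minimal standalone sketch of the user-fingerprinting scheme used by core/session_manager.py above to enforce one session per user; the IP address and User-Agent shown are placeholders.

import hashlib

def user_fingerprint(client_ip: str, user_agent: str) -> str:
    # Same recipe as UserIdentifier.generate_user_fingerprint:
    # hash "ip:first-100-UA-chars", keep 16 hex chars, prefix with "user_".
    user_data = f"{client_ip}:{user_agent[:100]}"
    return "user_" + hashlib.sha256(user_data.encode()).hexdigest()[:16]

print(user_fingerprint("203.0.113.7", "Mozilla/5.0 (X11; Linux x86_64)"))

Any change to either value yields a different fingerprint, which is why extract_request_info prefers X-Forwarded-For and X-Real-IP over request.remote_addr when the app sits behind a proxy.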
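The retry schedule used by ReconTask in core/task_manager.py can be tabulated with the same arithmetic as calculate_next_retry_time; this sketch assumes the default base of 1.0 s and cap of 60.0 s, and the task id is a placeholder. Because str hashing is seeded per process, the jitter varies slightly between runs.

def next_retry_delay(current_retry: int, task_id: str,
                     base: float = 1.0, cap: float = 60.0) -> float:
    # Exponential growth capped at `cap`, as in ReconTask.calculate_next_retry_time.
    backoff = min(cap, base * (2 ** current_retry))
    # Deterministic jitter derived from the task id (about +/-12.5% of the backoff).
    jitter = backoff * 0.25 * (0.5 - hash(task_id) % 1000 / 1000.0)
    return max(base, backoff + jitter)

# mark_failed() increments current_retry before scheduling, so with max_retries=3
# only counts 1 and 2 ever reach this point; count 3 means permanent failure.
for current_retry in (1, 2):
    print(f"retry after failure {current_retry}: ~{next_retry_delay(current_retry, 'abc12345'):.1f}s")

In other words a task is attempted at most three times: the initial run plus retries roughly 2 s and 4 s after the first and second failures.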
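A small, hypothetical driver for the new task queue, assuming it runs from the repository root so core.task_manager is importable; the target and provider name are placeholders and the failure is simulated rather than produced by a real provider query.

from core.task_manager import ReconTask, TaskQueue, TaskType

queue = TaskQueue(max_concurrent_tasks=2)
queue.add_task(ReconTask(
    task_id="",                          # __post_init__ assigns a short uuid when empty
    task_type=TaskType.DOMAIN_QUERY,
    target="example.com",                # placeholder target
    provider_name="crtsh",
    depth=0,
))

task = queue.get_next_ready_task()       # marks the task RUNNING
if task is not None:
    # A real run would hand the task to TaskExecutor.execute_task(); simulate a
    # failure so it lands in the retry queue with a next_retry_at timestamp.
    queue.complete_task(task.task_id, success=False, error="connection reset")

print(queue.get_statistics())            # pending/running/failed_retrying/... counts

complete_task() routes the failed task into the retry queue with a next_retry_at computed by the backoff arithmetic shown above.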
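The global provider cache keys requests purely by method, URL and parameters, so identical queries from different sessions share a single file under .cache/<provider>/. This sketch mirrors ProviderCache._generate_cache_key; the crt.sh URL and parameters are only illustrative.

import hashlib
import json
import os
from typing import Optional

def cache_path(provider: str, method: str, url: str, params: Optional[dict]) -> str:
    # Mirrors ProviderCache._generate_cache_key: MD5 over "method:url:sorted-params-json".
    raw = f"{method}:{url}:{json.dumps(params or {}, sort_keys=True)}"
    return os.path.join(".cache", provider, hashlib.md5(raw.encode()).hexdigest() + ".json")

print(cache_path("crtsh", "GET", "https://crt.sh/", {"q": "%.example.com", "output": "json"}))

Only HTTP 200 responses are written, and writes go through a temporary file plus an atomic rename, so concurrent sessions never observe a half-written cache entry.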
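For the HTTP retry path in providers/base_provider.py, the pause between attempts follows the two capped expressions above; this loop just prints the resulting schedule for rate-limit (429) responses versus other retryable errors.

for attempt in range(4):
    rate_limited = min(60, 10 * (2 ** attempt))    # HTTP 429: 10s, 20s, 40s, 60s
    other_error = min(2.0, (2 ** attempt) * 0.5)   # timeouts etc.: 0.5s, 1.0s, 2.0s, 2.0s
    print(f"attempt {attempt + 1}: 429 -> {rate_limited}s, other -> {other_error}s")

The wait itself goes through _sleep_with_cancellation_check, so a stop request interrupts the backoff instead of letting it run to completion.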