From 469c133f1bc1f0b1bae4dbd70d978dae562ca86e Mon Sep 17 00:00:00 2001 From: overcuriousity Date: Wed, 17 Sep 2025 11:18:06 +0200 Subject: [PATCH] fix session handling --- app.py | 92 ++++++++++-------------- core/scanner.py | 27 ++++--- core/session_manager.py | 152 +++++++++++++++++++++++++++++----------- 3 files changed, 167 insertions(+), 104 deletions(-) diff --git a/app.py b/app.py index 1955e36..450337c 100644 --- a/app.py +++ b/app.py @@ -24,50 +24,28 @@ app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanen def get_user_scanner(): """ - Retrieves the scanner for the current session, or creates a new - session and scanner if one doesn't exist. + Retrieves the scanner for the current session, or creates a new session only if none exists. """ - # Get current Flask session info for debugging current_flask_session_id = session.get('dnsrecon_session_id') - # Try to get existing session + # Try to get existing session first if current_flask_session_id: existing_scanner = session_manager.get_session(current_flask_session_id) if existing_scanner: + print(f"Reusing existing session: {current_flask_session_id}") return current_flask_session_id, existing_scanner else: - print(f"Session {current_flask_session_id} not found in Redis, checking for active sessions...") + print(f"Session {current_flask_session_id} expired, will create new one") - # This prevents creating duplicate sessions when Flask session is lost but Redis session exists - stats = session_manager.get_statistics() - if stats['running_scans'] > 0: - # Get all session keys and find running ones - try: - import redis - redis_client = redis.StrictRedis(db=0, decode_responses=False) - session_keys = redis_client.keys("dnsrecon:session:*") - - for session_key in session_keys: - session_id = session_key.decode('utf-8').split(':')[-1] - scanner = session_manager.get_session(session_id) - if scanner and scanner.status in ['running', 'completed']: - print(f"Reusing active session: {session_id}") - # Update Flask session to match - session['dnsrecon_session_id'] = session_id - session.permanent = True - return session_id, scanner - except Exception as e: - print(f"Error finding active session: {e}") - - # Create new session if none exists - print("Creating new session as none was found...") + # Only create new session if we absolutely don't have one + print("Creating new session (no valid session found)") new_session_id = session_manager.create_session() new_scanner = session_manager.get_session(new_session_id) if not new_scanner: raise Exception("Failed to create new scanner session") - # Store in Flask session with explicit settings + # Store in Flask session session['dnsrecon_session_id'] = new_session_id session.permanent = True @@ -83,8 +61,8 @@ def index(): @app.route('/api/scan/start', methods=['POST']) def start_scan(): """ - Start a new reconnaissance scan. Creates a new isolated scanner if - clear_graph is true, otherwise adds to the existing one. + FIXED: Start a new reconnaissance scan while preserving session configuration. + Only clears graph data, not the entire session with API keys. """ print("=== API: /api/scan/start called ===") @@ -96,7 +74,7 @@ def start_scan(): target = data['target'].strip() max_depth = data.get('max_depth', config.default_recursion_depth) clear_graph = data.get('clear_graph', True) - force_rescan_target = data.get('force_rescan_target', None) # **FIX**: Get the new parameter + force_rescan_target = data.get('force_rescan_target', None) print(f"Parsed - target: '{target}', max_depth: {max_depth}, clear_graph: {clear_graph}, force_rescan: {force_rescan_target}") @@ -108,28 +86,17 @@ def start_scan(): if not isinstance(max_depth, int) or not 1 <= max_depth <= 5: return jsonify({'success': False, 'error': 'Max depth must be an integer between 1 and 5'}), 400 - user_session_id, scanner = None, None - - if clear_graph: - print("Clear graph requested: Creating a new, isolated scanner session.") - old_session_id = session.get('dnsrecon_session_id') - if old_session_id: - session_manager.terminate_session(old_session_id) - - user_session_id = session_manager.create_session() - session['dnsrecon_session_id'] = user_session_id - session.permanent = True - scanner = session_manager.get_session(user_session_id) - else: - print("Adding to existing graph: Reusing the current scanner session.") - user_session_id, scanner = get_user_scanner() - + # FIXED: Always reuse existing session, preserve API keys + user_session_id, scanner = get_user_scanner() + if not scanner: - return jsonify({'success': False, 'error': 'Failed to get or create a scanner instance.'}), 500 + return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500 print(f"Using scanner {id(scanner)} in session {user_session_id}") + print(f"Scanner has {len(scanner.providers)} providers: {[p.get_name() for p in scanner.providers]}") - success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target) # **FIX**: Pass the new parameter + # FIXED: Pass clear_graph flag to scanner, let it handle graph clearing internally + success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target) if success: return jsonify({ @@ -137,6 +104,7 @@ def start_scan(): 'message': 'Scan started successfully', 'scan_id': scanner.logger.session_id, 'user_session_id': user_session_id, + 'available_providers': [p.get_name() for p in scanner.providers] # Show which providers are active }) else: return jsonify({ @@ -148,7 +116,7 @@ def start_scan(): print(f"ERROR: Exception in start_scan endpoint: {e}") traceback.print_exc() return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 - + @app.route('/api/scan/stop', methods=['POST']) def stop_scan(): """Stop the current scan with immediate GUI feedback.""" @@ -237,7 +205,9 @@ def get_scan_status(): status['debug_info'] = { 'scanner_object_id': id(scanner), 'session_id_set': bool(scanner.session_id), - 'has_scan_thread': bool(scanner.scan_thread and scanner.scan_thread.is_alive()) + 'has_scan_thread': bool(scanner.scan_thread and scanner.scan_thread.is_alive()), + 'provider_count': len(scanner.providers), + 'provider_names': [p.get_name() for p in scanner.providers] } return jsonify({ @@ -499,6 +469,7 @@ def get_providers(): completed_tasks = scanner.indicators_completed total_tasks = scanner.total_tasks_ever_enqueued print(f"DEBUG: Task Progress - Completed: {completed_tasks}, Total Enqueued: {total_tasks}") + print(f"DEBUG: Scanner has {len(scanner.providers)} providers: {[p.get_name() for p in scanner.providers]}") else: print("DEBUG: No active scanner session found.") @@ -522,7 +493,6 @@ def get_providers(): @app.route('/api/config/api-keys', methods=['POST']) def set_api_keys(): """ - Set API keys for providers for the user session only. """ try: data = request.get_json() @@ -537,6 +507,8 @@ def set_api_keys(): user_session_id, scanner = get_user_scanner() session_config = scanner.config + print(f"Setting API keys for session {user_session_id}: {list(data.keys())}") + updated_providers = [] # Iterate over the API keys provided in the request data @@ -548,10 +520,17 @@ def set_api_keys(): if success: updated_providers.append(provider_name) + print(f"API key {'set' if api_key_value else 'cleared'} for {provider_name}") if updated_providers: - # Reinitialize scanner providers to apply the new keys + # FIXED: Reinitialize scanner providers to apply the new keys + print("Reinitializing providers with new API keys...") + old_provider_count = len(scanner.providers) scanner._initialize_providers() + new_provider_count = len(scanner.providers) + + print(f"Providers reinitialized: {old_provider_count} -> {new_provider_count}") + print(f"Available providers: {[p.get_name() for p in scanner.providers]}") # Persist the updated scanner object back to the user's session session_manager.update_session_scanner(user_session_id, scanner) @@ -560,7 +539,8 @@ def set_api_keys(): 'success': True, 'message': f'API keys updated for session {user_session_id}: {", ".join(updated_providers)}', 'updated_providers': updated_providers, - 'user_session_id': user_session_id + 'user_session_id': user_session_id, + 'available_providers': [p.get_name() for p in scanner.providers] }) else: return jsonify({ @@ -597,7 +577,7 @@ def internal_error(error): if __name__ == '__main__': - print("Starting DNSRecon Flask application with user session support...") + print("Starting DNSRecon Flask application with streamlined session management...") # Load configuration from environment config.load_from_env() diff --git a/core/scanner.py b/core/scanner.py index a3bdb7b..735ddf3 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -30,7 +30,7 @@ class ScanStatus: class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. - Now provider-agnostic, consuming standardized ProviderResult objects. + FIXED: Now preserves session configuration including API keys when clearing graphs. """ def __init__(self, session_config=None): @@ -220,13 +220,18 @@ class Scanner: print("Session configuration updated") def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool: - """Start a new reconnaissance scan with proper cleanup of previous scans.""" + """ + FIXED: Start a new reconnaissance scan preserving session configuration. + Only clears graph data when requested, never destroys session/API keys. + """ print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") print(f"Session ID: {self.session_id}") print(f"Initial scanner status: {self.status}") + print(f"Clear graph requested: {clear_graph}") + print(f"Current providers: {[p.get_name() for p in self.providers]}") self.total_tasks_ever_enqueued = 0 - # **IMPROVED**: More aggressive cleanup of previous scan + # FIXED: Improved cleanup of previous scan without destroying session config if self.scan_thread and self.scan_thread.is_alive(): print("A previous scan thread is still alive. Forcing termination...") @@ -251,15 +256,14 @@ class Scanner: if self.scan_thread.is_alive(): print("WARNING: Previous scan thread is still alive after 5 seconds") - # Continue anyway, but log the issue self.logger.logger.warning("Previous scan thread failed to terminate cleanly") - # Reset state for new scan with proper forensic logging - print("Resetting scanner state for new scan...") + # FIXED: Reset scan state but preserve session configuration (API keys, etc.) + print("Resetting scanner state for new scan (preserving session config)...") self.status = ScanStatus.IDLE self.stop_event.clear() - # **NEW**: Clear Redis stop signal explicitly + # Clear Redis stop signal explicitly if self.session_id: from core.session_manager import session_manager session_manager.clear_stop_signal(self.session_id) @@ -267,13 +271,14 @@ class Scanner: with self.processing_lock: self.currently_processing.clear() + # Reset scan-specific state but keep providers and config intact self.task_queue = PriorityQueue() self.target_retries.clear() self.scan_failed_due_to_retries = False # Update session state immediately for GUI feedback self._update_session_state() - print("Scanner state reset complete.") + print("Scanner state reset complete (providers preserved).") try: if not hasattr(self, 'providers') or not self.providers: @@ -282,9 +287,12 @@ class Scanner: print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}") + # FIXED: Only clear graph if explicitly requested, don't destroy session if clear_graph: + print("Clearing graph data (preserving session configuration)") self.graph.clear() + # Handle force rescan by clearing provider states for that specific node if force_rescan_target and self.graph.graph.has_node(force_rescan_target): print(f"Forcing rescan of {force_rescan_target}, clearing provider states.") node_data = self.graph.graph.nodes[force_rescan_target] @@ -304,7 +312,7 @@ class Scanner: # Update GUI with scan preparation state self._update_session_state() - # Start new forensic session + # Start new forensic session (but don't reinitialize providers) print(f"Starting new forensic session for scanner {id(self)}...") self.logger = new_session() @@ -318,6 +326,7 @@ class Scanner: self.scan_thread.start() print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") + print(f"Active providers for this scan: {[p.get_name() for p in self.providers]}") return True except Exception as e: diff --git a/core/session_manager.py b/core/session_manager.py index 4662c65..d44c9e8 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -12,7 +12,8 @@ from config import config class SessionManager: """ - Manages multiple scanner instances for concurrent user sessions using Redis. + FIXED: Manages multiple scanner instances for concurrent user sessions using Redis. + Now more conservative about session creation to preserve API keys and configuration. """ def __init__(self, session_timeout_minutes: int = 0): @@ -24,7 +25,10 @@ class SessionManager: self.redis_client = redis.StrictRedis(db=0, decode_responses=False) self.session_timeout = session_timeout_minutes * 60 # Convert to seconds - self.lock = threading.Lock() # Lock for local operations, Redis handles atomic ops + self.lock = threading.Lock() + + # FIXED: Add a creation lock to prevent race conditions + self.creation_lock = threading.Lock() # Start cleanup thread self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) @@ -36,7 +40,7 @@ class SessionManager: """Prepare SessionManager for pickling.""" state = self.__dict__.copy() # Exclude unpickleable attributes - Redis client and threading objects - unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client'] + unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock'] for attr in unpicklable_attrs: if attr in state: del state[attr] @@ -49,6 +53,7 @@ class SessionManager: import redis self.redis_client = redis.StrictRedis(db=0, decode_responses=False) self.lock = threading.Lock() + self.creation_lock = threading.Lock() self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread.start() @@ -62,44 +67,106 @@ class SessionManager: def create_session(self) -> str: """ - Create a new user session and store it in Redis. + FIXED: Create a new user session with thread-safe creation to prevent duplicates. """ - session_id = str(uuid.uuid4()) - print(f"=== CREATING SESSION {session_id} IN REDIS ===") - - try: - from core.session_config import create_session_config - session_config = create_session_config() - scanner_instance = Scanner(session_config=session_config) + # FIXED: Use creation lock to prevent race conditions + with self.creation_lock: + session_id = str(uuid.uuid4()) + print(f"=== CREATING SESSION {session_id} IN REDIS ===") - # Set the session ID on the scanner for cross-process stop signal management - scanner_instance.session_id = session_id + try: + from core.session_config import create_session_config + session_config = create_session_config() + scanner_instance = Scanner(session_config=session_config) + + # Set the session ID on the scanner for cross-process stop signal management + scanner_instance.session_id = session_id + + session_data = { + 'scanner': scanner_instance, + 'config': session_config, + 'created_at': time.time(), + 'last_activity': time.time(), + 'status': 'active' + } + + # Serialize the entire session data dictionary using pickle + serialized_data = pickle.dumps(session_data) + + # Store in Redis + session_key = self._get_session_key(session_id) + self.redis_client.setex(session_key, self.session_timeout, serialized_data) + + # Initialize stop signal as False + stop_key = self._get_stop_signal_key(session_id) + self.redis_client.setex(stop_key, self.session_timeout, b'0') + + print(f"Session {session_id} stored in Redis with stop signal initialized") + print(f"Session has {len(scanner_instance.providers)} providers: {[p.get_name() for p in scanner_instance.providers]}") + return session_id + + except Exception as e: + print(f"ERROR: Failed to create session {session_id}: {e}") + raise + + def clone_session_preserving_config(self, source_session_id: str) -> str: + """ + FIXED: Create a new session that preserves the configuration (including API keys) from an existing session. + This is used when we need a fresh scanner but want to keep user configuration. + """ + with self.creation_lock: + print(f"=== CLONING SESSION {source_session_id} (PRESERVING CONFIG) ===") - session_data = { - 'scanner': scanner_instance, - 'config': session_config, - 'created_at': time.time(), - 'last_activity': time.time(), - 'status': 'active' - } - - # Serialize the entire session data dictionary using pickle - serialized_data = pickle.dumps(session_data) - - # Store in Redis - session_key = self._get_session_key(session_id) - self.redis_client.setex(session_key, self.session_timeout, serialized_data) - - # Initialize stop signal as False - stop_key = self._get_stop_signal_key(session_id) - self.redis_client.setex(stop_key, self.session_timeout, b'0') - - print(f"Session {session_id} stored in Redis with stop signal initialized") - return session_id - - except Exception as e: - print(f"ERROR: Failed to create session {session_id}: {e}") - raise + try: + # Get the source session data + source_session_data = self._get_session_data(source_session_id) + if not source_session_data: + print(f"ERROR: Source session {source_session_id} not found for cloning") + return self.create_session() # Fallback to new session + + # Create new session ID + new_session_id = str(uuid.uuid4()) + + # Get the preserved configuration + preserved_config = source_session_data.get('config') + if not preserved_config: + print(f"WARNING: No config found in source session, creating new") + from core.session_config import create_session_config + preserved_config = create_session_config() + + print(f"Preserving config with API keys: {list(preserved_config.api_keys.keys())}") + + # Create new scanner with preserved config + new_scanner = Scanner(session_config=preserved_config) + new_scanner.session_id = new_session_id + + print(f"New scanner has {len(new_scanner.providers)} providers: {[p.get_name() for p in new_scanner.providers]}") + + new_session_data = { + 'scanner': new_scanner, + 'config': preserved_config, + 'created_at': time.time(), + 'last_activity': time.time(), + 'status': 'active', + 'cloned_from': source_session_id + } + + # Store in Redis + serialized_data = pickle.dumps(new_session_data) + session_key = self._get_session_key(new_session_id) + self.redis_client.setex(session_key, self.session_timeout, serialized_data) + + # Initialize stop signal + stop_key = self._get_stop_signal_key(new_session_id) + self.redis_client.setex(stop_key, self.session_timeout, b'0') + + print(f"Cloned session {new_session_id} with preserved configuration") + return new_session_id + + except Exception as e: + print(f"ERROR: Failed to clone session {source_session_id}: {e}") + # Fallback to creating a new session + return self.create_session() def set_stop_signal(self, session_id: str) -> bool: """ @@ -208,7 +275,14 @@ class SessionManager: # Immediately save to Redis for GUI updates success = self._save_session_data(session_id, session_data) if success: - print(f"Scanner state updated for session {session_id} (status: {scanner.status})") + # Only log occasionally to reduce noise + if hasattr(self, '_last_update_log'): + if time.time() - self._last_update_log > 5: # Log every 5 seconds max + print(f"Scanner state updated for session {session_id} (status: {scanner.status})") + self._last_update_log = time.time() + else: + print(f"Scanner state updated for session {session_id} (status: {scanner.status})") + self._last_update_log = time.time() else: print(f"WARNING: Failed to save scanner state for session {session_id}") return success