# dnsrecon/core/session_manager.py

import threading
import time
import uuid
import redis
import pickle
from typing import Dict, Optional, Any
import copy

from core.scanner import Scanner
from config import config


class SessionManager:
    """
    Manages scanner instances for concurrent user sessions, persisted in Redis.

    Each session is stored as a pickled dict under ``dnsrecon:session:<id>``,
    with a cross-process stop flag under ``dnsrecon:stop:<id>``. Socket.IO
    connection objects cannot be pickled, so they live in an in-process
    registry (``active_socketio_connections``) and are re-attached to the
    scanner whenever a session is loaded from Redis.
    """

    def __init__(self, session_timeout_minutes: Optional[int] = None):
        """
        Initialize the session manager with a Redis backend.

        Args:
            session_timeout_minutes: Session TTL in minutes. ``None`` (the
                default) falls back to ``config.session_timeout_minutes``.
        """
        # The default used to be 0, which skipped this fallback and produced a
        # zero TTL that Redis rejects in SETEX.
        if session_timeout_minutes is None:
            session_timeout_minutes = config.session_timeout_minutes

        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        # Guards active_socketio_connections.
        self.lock = threading.Lock()
        # Serializes create_session calls.
        self.creation_lock = threading.Lock()
        # session_id -> socketio object; in-process only, never pickled.
        self.active_socketio_connections = {}

        # Start cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

        print(f"SessionManager initialized with Redis backend and {session_timeout_minutes}min timeout")

    def __getstate__(self):
        """Prepare SessionManager for pickling by dropping runtime-only attributes."""
        state = self.__dict__.copy()
        # Redis client, threading objects and live socketio connections cannot be pickled
        unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
        return state

    def __setstate__(self, state):
        """Restore SessionManager after unpickling, recreating runtime-only attributes."""
        self.__dict__.update(state)
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.lock = threading.Lock()
        self.creation_lock = threading.Lock()
        self.active_socketio_connections = {}
        # NOTE: every unpickled copy starts its own cleanup thread.
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def _get_session_key(self, session_id: str) -> str:
        """Generates the Redis key for a session."""
        return f"dnsrecon:session:{session_id}"

    def _get_stop_signal_key(self, session_id: str) -> str:
        """Generates the Redis key for a session's stop signal."""
        return f"dnsrecon:stop:{session_id}"

    def register_socketio_connection(self, session_id: str, socketio) -> None:
        """
        Register the socketio connection for a session in the in-process registry.

        The registry is the only place the connection survives, because it is
        stripped from the scanner before the scanner is pickled into Redis.
        """
        with self.lock:
            self.active_socketio_connections[session_id] = socketio
        print(f"Registered socketio connection for session {session_id}")

    def get_socketio_connection(self, session_id: str):
        """Return the registered socketio connection for a session, or None."""
        with self.lock:
            return self.active_socketio_connections.get(session_id)

    def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner:
        """
        Strip unpicklable runtime members from ``scanner`` IN PLACE and tag it
        with ``session_id``.

        Because this mutates its argument, callers holding a scanner that may be
        live must pass a copy (``_save_session_data`` does). The socketio
        connection is re-attached from the registry when the session is loaded.

        Returns:
            The same scanner object, for call chaining.
        """
        # Set the session ID on the scanner for cross-process stop signal management
        scanner.session_id = session_id
        scanner.socketio = None

        # Threading objects cannot be pickled
        for attr in ('stop_event', 'scan_thread', 'executor',
                     'status_logger_thread', 'status_logger_stop_event'):
            if hasattr(scanner, attr):
                setattr(scanner, attr, None)

        return scanner

    def create_session(self, socketio=None) -> str:
        """
        Create a new session, store it in Redis and return its ID.

        Args:
            socketio: Optional socketio connection to associate with the session.

        Returns:
            The new session ID (a UUID4 string).

        Raises:
            Exception: Any failure while building, serializing or storing the
                session is printed and re-raised.
        """
        with self.creation_lock:
            session_id = str(uuid.uuid4())
            print(f"=== CREATING SESSION {session_id} IN REDIS ===")

            try:
                from core.session_config import create_session_config
                session_config = create_session_config()

                # Create scanner WITHOUT socketio to avoid weakref issues
                scanner_instance = Scanner(session_config=session_config, socketio=None)

                # Prepare scanner for storage (removes problematic objects)
                scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id)

                session_data = {
                    'scanner': scanner_instance,
                    'config': session_config,
                    'created_at': time.time(),
                    'last_activity': time.time(),
                    'status': 'active'
                }

                # Test serialization before storing to catch issues early
                try:
                    test_serialization = pickle.dumps(session_data)
                    print(f"Session serialization test successful ({len(test_serialization)} bytes)")
                except Exception as pickle_error:
                    print(f"PICKLE TEST FAILED: {pickle_error}")
                    # Try to identify the problematic object
                    for key, value in session_data.items():
                        try:
                            pickle.dumps(value)
                            print(f"  {key}: OK")
                        except Exception as item_error:
                            print(f"  {key}: FAILED - {item_error}")
                    raise

                # Store in Redis
                session_key = self._get_session_key(session_id)
                self.redis_client.setex(session_key, self.session_timeout, test_serialization)

                # Initialize stop signal as False
                stop_key = self._get_stop_signal_key(session_id)
                self.redis_client.setex(stop_key, self.session_timeout, b'0')

                # Register only once the session exists in Redis: a failed
                # creation then leaves no registry entry behind, and the cleanup
                # loop never sees an entry whose session key is missing.
                if socketio:
                    self.register_socketio_connection(session_id, socketio)

                print(f"Session {session_id} stored in Redis with stop signal initialized")
                print(f"Session has {len(scanner_instance.providers)} providers: {[p.get_name() for p in scanner_instance.providers]}")

                return session_id

            except Exception as e:
                print(f"ERROR: Failed to create session {session_id}: {e}")
                import traceback
                traceback.print_exc()
                raise

    def set_stop_signal(self, session_id: str) -> bool:
        """
        Set the stop signal for a session (cross-process safe).

        Args:
            session_id: Session identifier

        Returns:
            bool: True if signal was set successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            # Set stop signal to '1' with the same TTL as the session
            self.redis_client.setex(stop_key, self.session_timeout, b'1')
            print(f"Stop signal set for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
            return False

    def is_stop_requested(self, session_id: str) -> bool:
        """
        Check if stop is requested for a session (cross-process safe).

        Args:
            session_id: Session identifier

        Returns:
            bool: True if the stop key holds b'1'; False if it is unset,
            holds anything else, or Redis cannot be reached.
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            return self.redis_client.get(stop_key) == b'1'
        except Exception as e:
            print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
            return False

    def clear_stop_signal(self, session_id: str) -> bool:
        """
        Clear the stop signal for a session.

        Args:
            session_id: Session identifier

        Returns:
            bool: True if signal was cleared successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')
            print(f"Stop signal cleared for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
            return False

    def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Load and unpickle a session from Redis, re-attaching its socketio connection.

        Returns:
            The session dict, or None if the key is missing or loading fails.
        """
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = self.redis_client.get(session_key)
            if serialized_data:
                # NOTE(security): pickle.loads is only safe because this key is
                # written exclusively by this class; Redis must not be writable
                # by untrusted clients.
                session_data = pickle.loads(serialized_data)
                # Ensure the scanner has the correct session ID for stop signal checking
                if 'scanner' in session_data and session_data['scanner']:
                    session_data['scanner'].session_id = session_id
                    # Restore the socketio connection from the in-process registry
                    socketio_conn = self.get_socketio_connection(session_id)
                    if socketio_conn:
                        session_data['scanner'].socketio = socketio_conn
                        print(f"Restored socketio connection for session {session_id}")
                    else:
                        print(f"No socketio connection found for session {session_id}")
                        session_data['scanner'].socketio = None
                return session_data
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Serialize and save session data back to Redis, refreshing its TTL.

        Neither ``session_data`` nor the scanner inside it is modified: unpicklable
        members are stripped from a shallow copy of the scanner. A shallow copy is
        used instead of ``copy.deepcopy`` so the live socketio server, locks and
        threads are never copied before they are removed.

        Returns:
            bool: True if save was successful
        """
        try:
            session_key = self._get_session_key(session_id)

            # Copy the mapping so the caller's dict keeps its own scanner reference.
            session_data_to_save = dict(session_data)

            scanner = session_data_to_save.get('scanner')
            if scanner:
                # Keep the connection reachable via the registry before it is
                # stripped for storage.
                original_socketio = getattr(scanner, 'socketio', None)
                if original_socketio and session_id not in self.active_socketio_connections:
                    self.register_socketio_connection(session_id, original_socketio)

                session_data_to_save['scanner'] = self._prepare_scanner_for_storage(
                    copy.copy(scanner), session_id
                )

            serialized_data = pickle.dumps(session_data_to_save)
            return bool(self.redis_client.setex(session_key, self.session_timeout, serialized_data))
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """
        Replace the stored scanner of a session and persist it immediately.

        The passed scanner may be a live, running instance. Apart from setting
        its ``session_id``, it is not modified: its stop_event, executor and
        threads stay intact, because stripping happens on a copy inside
        ``_save_session_data``.

        Returns:
            bool: True if update was successful
        """
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"WARNING: Session {session_id} not found for scanner update")
                return False

            # Needed for Redis stop-signal checks
            scanner.session_id = session_id
            original_socketio = getattr(scanner, 'socketio', None)
            if original_socketio:
                self.register_socketio_connection(session_id, original_socketio)

            session_data['scanner'] = scanner
            session_data['last_activity'] = time.time()

            # Immediately save to Redis for GUI updates
            success = self._save_session_data(session_id, session_data)
            if not success:
                print(f"WARNING: Failed to save scanner state for session {session_id}")
            return success
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
        """
        Quickly update just the scanner status for immediate GUI feedback.

        Args:
            session_id: Session identifier
            status: New scanner status

        Returns:
            bool: True if update was successful
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                session_data['scanner'].status = status
                session_data['last_activity'] = time.time()

                success = self._save_session_data(session_id, session_data)
                if success:
                    print(f"Scanner status updated to '{status}' for session {session_id}")
                else:
                    print(f"WARNING: Failed to save status update for session {session_id}")
                return success
            return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
            return False

    def get_session(self, session_id: str) -> Optional[Scanner]:
        """
        Return the scanner for an active session, refreshing its last-activity
        time and TTL; None if the ID is empty, unknown or not 'active'.
        """
        if not session_id:
            return None

        session_data = self._get_session_data(session_id)

        if not session_data or session_data.get('status') != 'active':
            return None

        # Update last activity and save back to Redis (also refreshes the TTL)
        session_data['last_activity'] = time.time()
        self._save_session_data(session_id, session_data)

        scanner = session_data.get('scanner')
        if scanner:
            # Ensure the scanner can check the Redis-based stop signal
            scanner.session_id = session_id

            # Restore socketio connection from our registry
            socketio_conn = self.get_socketio_connection(session_id)
            if socketio_conn:
                scanner.socketio = socketio_conn
                print(f"✓ Restored socketio connection for session {session_id}")
            else:
                scanner.socketio = None
                print(f"⚠️ No socketio connection found for session {session_id}")

        return scanner

    def get_session_status_only(self, session_id: str) -> Optional[str]:
        """
        Get just the scanner status.

        NOTE(review): this still loads and unpickles the whole session; it is
        not cheaper than get_session apart from skipping the save.

        Args:
            session_id: Session identifier

        Returns:
            Scanner status string or None if not found
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                return session_data['scanner'].status
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session status for {session_id}: {e}")
            return None

    def terminate_session(self, session_id: str) -> bool:
        """
        Stop and remove a session: set the stop signal, mark the scanner
        'stopped' for GUI feedback, then delete the session and stop keys from
        Redis and drop the socketio registry entry.

        Returns:
            bool: True if the session existed and was removed.
        """
        print(f"=== TERMINATING SESSION {session_id} ===")

        try:
            # Signal first so a scan running in any process stops as soon as possible
            self.set_stop_signal(session_id)

            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"Session {session_id} not found")
                return False

            scanner = session_data.get('scanner')
            # Must be read BEFORE the status is overwritten with 'stopped';
            # checking afterwards made this branch unreachable.
            was_running = bool(scanner) and scanner.status == 'running'
            if was_running:
                print(f"Stopping scan for session: {session_id}")
                try:
                    scanner.stop_scan()
                except Exception as stop_error:
                    # This is a deserialized copy; the Redis stop signal set
                    # above is what reaches the live scanner, so keep going.
                    print(f"WARNING: stop_scan failed for session {session_id}: {stop_error}")

            # Update scanner status to stopped immediately for GUI feedback
            self.update_scanner_status(session_id, 'stopped')

            if was_running:
                # Wait a moment for graceful shutdown
                time.sleep(0.5)

            # Clean up socketio connection
            with self.lock:
                if session_id in self.active_socketio_connections:
                    del self.active_socketio_connections[session_id]
                    print(f"Cleaned up socketio connection for session {session_id}")

            # Delete session data and stop signal from Redis
            session_key = self._get_session_key(session_id)
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.delete(session_key)
            self.redis_client.delete(stop_key)

            print(f"Terminated and removed session from Redis: {session_id}")
            return True

        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _cleanup_loop(self) -> None:
        """
        Background loop (every 5 minutes) that removes orphaned stop signals and
        socketio registry entries whose session key no longer exists in Redis.
        """
        while True:
            try:
                # SCAN instead of KEYS so Redis is not blocked on large keyspaces
                for stop_key in self.redis_client.scan_iter(match="dnsrecon:stop:*"):
                    # Extract session ID from stop key
                    session_id = stop_key.decode('utf-8').split(':')[-1]
                    session_key = self._get_session_key(session_id)

                    # If session doesn't exist but stop signal does, clean it up
                    if not self.redis_client.exists(session_key):
                        self.redis_client.delete(stop_key)
                        print(f"Cleaned up orphaned stop signal for session {session_id}")

                # Prune registry entries for sessions that expired in Redis. Stop
                # keys are not refreshed on activity, so they usually expire before
                # the session and the loop above never sees those sessions.
                with self.lock:
                    tracked_ids = list(self.active_socketio_connections)
                for session_id in tracked_ids:
                    if not self.redis_client.exists(self._get_session_key(session_id)):
                        with self.lock:
                            self.active_socketio_connections.pop(session_id, None)
                        print(f"Cleaned up orphaned socketio for session {session_id}")

            except Exception as e:
                print(f"Error in cleanup loop: {e}")

            time.sleep(300)  # Sleep for 5 minutes

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get session manager statistics.

        Returns all-zero counters if Redis cannot be queried.
        """
        try:
            session_keys = list(self.redis_client.scan_iter(match="dnsrecon:session:*"))
            stop_keys = list(self.redis_client.scan_iter(match="dnsrecon:stop:*"))

            active_sessions = len(session_keys)
            running_scans = 0

            for session_key in session_keys:
                session_id = session_key.decode('utf-8').split(':')[-1]
                status = self.get_session_status_only(session_id)
                if status == 'running':
                    running_scans += 1

            return {
                'total_active_sessions': active_sessions,
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys),
                'active_socketio_connections': len(self.active_socketio_connections)
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'running_scans': 0,
                'total_stop_signals': 0,
                'active_socketio_connections': 0
            }


# Global session manager instance
session_manager = SessionManager(session_timeout_minutes=60)