# dnsrecon/core/session_manager.py

import threading
import time
import uuid
import redis
import pickle
from typing import Dict, Optional, Any, List

from core.scanner import Scanner

# WARNING: Using pickle can be a security risk if the data source is not trusted.
# In this case, we are only serializing/deserializing our own trusted Scanner objects,
# which is generally safe. Do not unpickle data from untrusted sources.


class SessionManager:
    """
    Manages multiple scanner instances for concurrent user sessions using Redis.
    Enhanced with reliable cross-process stop signal management and immediate state updates.

    Storage layout:
      * ``dnsrecon:session:<id>`` - pickled dict with keys ``scanner``, ``config``,
        ``created_at``, ``last_activity`` and ``status``.
      * ``dnsrecon:stop:<id>`` - ``b'1'`` when a stop was requested, ``b'0'`` otherwise.

    Both keys are written with SETEX, so Redis itself expires sessions that are
    not touched for ``session_timeout`` seconds.
    """

    # Key prefixes and the glob patterns used to enumerate them.
    _SESSION_KEY_PREFIX = "dnsrecon:session:"
    _STOP_KEY_PREFIX = "dnsrecon:stop:"
    _SESSION_KEY_PATTERN = _SESSION_KEY_PREFIX + "*"
    _STOP_KEY_PATTERN = _STOP_KEY_PREFIX + "*"

    def __init__(self, session_timeout_minutes: int = 60):
        """
        Initialize session manager with a Redis backend.

        Args:
            session_timeout_minutes: Idle lifetime of a session. It is converted to
                seconds and used as the TTL of every session and stop-signal key.
        """
        # Raw bytes responses: session payloads are pickled and must not be decoded.
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        # Not acquired by any method of this class.
        self.lock = threading.Lock()

        # Start cleanup thread (daemon, so it never blocks interpreter exit).
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

        print(f"SessionManager initialized with Redis backend and {session_timeout_minutes}min timeout")

    def __getstate__(self):
        """
        Prepare SessionManager for pickling.

        The Redis client and threading objects cannot be pickled, so they are
        dropped here and recreated by ``__setstate__``.
        """
        state = self.__dict__.copy()
        for attr in ('lock', 'cleanup_thread', 'redis_client'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        """
        Restore SessionManager after unpickling.

        Recreates the Redis client and lock and starts a fresh cleanup thread.
        NOTE(review): every unpickled copy starts its own cleanup thread.
        """
        self.__dict__.update(state)
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.lock = threading.Lock()
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def _get_session_key(self, session_id: str) -> str:
        """Generates the Redis key for a session."""
        return f"{self._SESSION_KEY_PREFIX}{session_id}"

    def _get_stop_signal_key(self, session_id: str) -> str:
        """Generates the Redis key for a session's stop signal."""
        return f"{self._STOP_KEY_PREFIX}{session_id}"

    @staticmethod
    def _session_id_from_key(key) -> str:
        """
        Extract the session ID (the text after the last ':') from a Redis key.

        Keys returned by this client are ``bytes`` because it is created with
        ``decode_responses=False``; ``str`` keys are accepted as well.
        """
        if isinstance(key, bytes):
            key = key.decode('utf-8')
        return key.rsplit(':', 1)[-1]

    def create_session(self) -> str:
        """
        Create a new user session and store it in Redis.

        Builds a fresh session config and Scanner, stores the pickled session dict
        and initialises the stop signal to ``b'0'``, both with the session TTL.

        Returns:
            str: The new session ID (a UUID4 string).

        Raises:
            Exception: Any error while building or storing the session is logged
                and re-raised.
        """
        session_id = str(uuid.uuid4())
        print(f"=== CREATING SESSION {session_id} IN REDIS ===")

        try:
            from core.session_config import create_session_config
            session_config = create_session_config()
            scanner_instance = Scanner(session_config=session_config)

            # Set the session ID on the scanner for cross-process stop signal management
            scanner_instance.session_id = session_id

            session_data = {
                'scanner': scanner_instance,
                'config': session_config,
                'created_at': time.time(),
                'last_activity': time.time(),
                'status': 'active'
            }

            # Serialize the entire session data dictionary using pickle
            serialized_data = pickle.dumps(session_data)

            # Store in Redis. The session key is written before the stop key, so an
            # existing stop key without a session key always means the session is gone.
            session_key = self._get_session_key(session_id)
            self.redis_client.setex(session_key, self.session_timeout, serialized_data)

            # Initialize stop signal as False
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')

            print(f"Session {session_id} stored in Redis with stop signal initialized")
            return session_id

        except Exception as e:
            print(f"ERROR: Failed to create session {session_id}: {e}")
            raise

    def set_stop_signal(self, session_id: str) -> bool:
        """
        Set the stop signal for a session (cross-process safe).

        Args:
            session_id: Session identifier

        Returns:
            bool: True if signal was set successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            # Set stop signal to '1' with the same TTL as the session
            self.redis_client.setex(stop_key, self.session_timeout, b'1')
            print(f"Stop signal set for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
            return False

    def is_stop_requested(self, session_id: str) -> bool:
        """
        Check if stop is requested for a session (cross-process safe).

        Args:
            session_id: Session identifier

        Returns:
            bool: True if stop is requested. A missing key or a Redis error both
            count as "not requested".
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            # A missing key yields None, which never equals b'1'.
            return self.redis_client.get(stop_key) == b'1'
        except Exception as e:
            print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
            return False

    def clear_stop_signal(self, session_id: str) -> bool:
        """
        Clear the stop signal for a session.

        Args:
            session_id: Session identifier

        Returns:
            bool: True if signal was cleared successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')
            print(f"Stop signal cleared for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
            return False

    def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves and deserializes session data from Redis.

        Returns:
            The session dict, or None if the key is missing/empty or any error occurs.
        """
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = self.redis_client.get(session_key)
            if serialized_data:
                session_data = pickle.loads(serialized_data)
                # Ensure the scanner has the correct session ID for stop signal checking
                if 'scanner' in session_data and session_data['scanner']:
                    session_data['scanner'].session_id = session_id
                return session_data
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Serializes and saves session data back to Redis with updated TTL.

        Returns:
            bool: True if save was successful
        """
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = pickle.dumps(session_data)
            result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
            return bool(result)
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """
        Updates just the scanner object in a session with immediate persistence.

        Args:
            session_id: Session identifier
            scanner: Scanner instance to store (its ``session_id`` is set to match)

        Returns:
            bool: True if update was successful
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data:
                # Ensure scanner has the session ID
                scanner.session_id = session_id
                session_data['scanner'] = scanner
                session_data['last_activity'] = time.time()

                # Immediately save to Redis for GUI updates
                success = self._save_session_data(session_id, session_data)
                if success:
                    print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
                else:
                    print(f"WARNING: Failed to save scanner state for session {session_id}")
                return success
            else:
                print(f"WARNING: Session {session_id} not found for scanner update")
                return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
        """
        Quickly update just the scanner status for immediate GUI feedback.

        Args:
            session_id: Session identifier
            status: New scanner status

        Returns:
            bool: True if update was successful
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                session_data['scanner'].status = status
                session_data['last_activity'] = time.time()

                success = self._save_session_data(session_id, session_data)
                if success:
                    print(f"Scanner status updated to '{status}' for session {session_id}")
                else:
                    print(f"WARNING: Failed to save status update for session {session_id}")
                return success
            return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
            return False

    def get_session(self, session_id: str) -> Optional[Scanner]:
        """
        Get scanner instance for a session from Redis with enhanced session ID management.

        Also bumps ``last_activity`` and rewrites the session, which refreshes its TTL.

        Returns:
            The session's Scanner, or None for an empty ID, a missing session,
            or a session whose ``status`` is not ``'active'``.
        """
        if not session_id:
            return None

        session_data = self._get_session_data(session_id)
        if not session_data or session_data.get('status') != 'active':
            return None

        # Update last activity and save back to Redis.
        # NOTE(review): this read-modify-write of the whole pickled dict is not atomic;
        # a concurrent update_session_scanner() from another process between the read
        # above and this write would be overwritten by the older scanner copy read here.
        session_data['last_activity'] = time.time()
        self._save_session_data(session_id, session_data)

        scanner = session_data.get('scanner')
        if scanner:
            # Ensure the scanner can check the Redis-based stop signal
            scanner.session_id = session_id

        return scanner

    def get_session_status_only(self, session_id: str) -> Optional[str]:
        """
        Get just the scanner status without updating session activity.

        Args:
            session_id: Session identifier

        Returns:
            Scanner status string or None if not found
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                return session_data['scanner'].status
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session status for {session_id}: {e}")
            return None

    def terminate_session(self, session_id: str) -> bool:
        """
        Terminate a specific session in Redis with reliable stop signal and immediate status update.

        The steps run in this order:
          1. Raise the cross-process stop signal.
          2. Publish status ``'stopped'``.
          3. If the scanner was running, ask it to stop.
          4. Delete the session key and the stop-signal key.

        Returns:
            bool: True if the session existed and was removed; False if it was not
            found or an error occurred.
        """
        print(f"=== TERMINATING SESSION {session_id} ===")

        try:
            # First, set the stop signal so scanners in any process can see it.
            self.set_stop_signal(session_id)

            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"Session {session_id} not found")
                return False

            scanner = session_data.get('scanner')
            # Capture the running state BEFORE the status is overwritten below.
            # Checking after update_scanner_status() always saw 'stopped', which made
            # the graceful-stop branch unreachable.
            was_running = bool(scanner) and scanner.status == 'running'

            # Update scanner status to stopped immediately for GUI feedback
            self.update_scanner_status(session_id, 'stopped')

            if was_running:
                print(f"Stopping scan for session: {session_id}")
                # The scanner will check the Redis stop signal
                scanner.stop_scan()
                # Update the scanner state immediately.
                # NOTE(review): this persists whatever status stop_scan() leaves on
                # the local copy; confirm stop_scan() sets a terminal status.
                self.update_session_scanner(session_id, scanner)
                # Wait a moment for graceful shutdown
                time.sleep(0.5)

            # Delete session data and stop signal from Redis
            session_key = self._get_session_key(session_id)
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.delete(session_key)
            self.redis_client.delete(stop_key)

            print(f"Terminated and removed session from Redis: {session_id}")
            return True

        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            return False

    def _cleanup_loop(self) -> None:
        """
        Background thread that removes orphaned stop signals every 5 minutes.

        Idle sessions need no sweeping here, because every session write uses SETEX
        and Redis expires them on its own. A stop key whose session key no longer
        exists is deleted. Errors are logged and the loop keeps running.
        """
        while True:
            try:
                # SCAN (incremental) instead of KEYS, which blocks Redis on large keyspaces.
                # Materialize first so deletions don't interleave with the cursor walk.
                stop_keys = list(self.redis_client.scan_iter(match=self._STOP_KEY_PATTERN))
                for stop_key in stop_keys:
                    session_id = self._session_id_from_key(stop_key)
                    session_key = self._get_session_key(session_id)

                    # If session doesn't exist but stop signal does, clean it up
                    if not self.redis_client.exists(session_key):
                        self.redis_client.delete(stop_key)
                        print(f"Cleaned up orphaned stop signal for session {session_id}")

            except Exception as e:
                print(f"Error in cleanup loop: {e}")

            time.sleep(300)  # Sleep for 5 minutes

    def list_active_sessions(self) -> List[Dict[str, Any]]:
        """
        List all active sessions for admin purposes.

        Returns:
            One summary dict per session still readable from Redis, or an empty
            list on error.
        """
        try:
            sessions = []
            for session_key in self.redis_client.scan_iter(match=self._SESSION_KEY_PATTERN):
                session_id = self._session_id_from_key(session_key)
                session_data = self._get_session_data(session_id)
                # The key may have expired between SCAN and GET; skip it then.
                if not session_data:
                    continue

                scanner = session_data.get('scanner')
                sessions.append({
                    'session_id': session_id,
                    'created_at': session_data.get('created_at'),
                    'last_activity': session_data.get('last_activity'),
                    'scanner_status': scanner.status if scanner else 'unknown',
                    'current_target': scanner.current_target if scanner else None
                })

            return sessions
        except Exception as e:
            print(f"ERROR: Failed to list active sessions: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get session manager statistics.

        Returns:
            Counts of session keys, sessions whose scanner status is ``'running'``,
            and stop-signal keys. All counts are 0 on error.
        """
        try:
            session_keys = list(self.redis_client.scan_iter(match=self._SESSION_KEY_PATTERN))
            stop_keys = list(self.redis_client.scan_iter(match=self._STOP_KEY_PATTERN))

            running_scans = sum(
                1
                for session_key in session_keys
                if self.get_session_status_only(self._session_id_from_key(session_key)) == 'running'
            )

            return {
                'total_active_sessions': len(session_keys),
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys)
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'running_scans': 0,
                'total_stop_signals': 0
            }


# Global session manager instance
session_manager = SessionManager(session_timeout_minutes=60)