# dnsrecon/core/session_manager.py
import threading
import time
import uuid
import redis
import pickle
import hashlib
from typing import Dict, Optional, Any, List, Tuple

from core.scanner import Scanner


class UserIdentifier:
    """Handles user identification for session management."""

    @staticmethod
    def generate_user_fingerprint(client_ip: str, user_agent: str) -> str:
        """
        Generate a unique fingerprint for a user based on IP and User-Agent.

        Args:
            client_ip: Client IP address
            user_agent: User-Agent header value

        Returns:
            Unique user fingerprint hash
        """
        # Create deterministic user identifier
        user_data = f"{client_ip}:{user_agent[:100]}"  # Limit UA to 100 chars
        fingerprint = hashlib.sha256(user_data.encode()).hexdigest()[:16]  # 16 char fingerprint
        return f"user_{fingerprint}"

    @staticmethod
    def extract_request_info(request) -> Tuple[str, str]:
        """
        Extract client IP and User-Agent from Flask request.

        Args:
            request: Flask request object

        Returns:
            Tuple of (client_ip, user_agent)
        """
        # Handle proxy headers for real IP
        client_ip = request.headers.get('X-Forwarded-For', '').split(',')[0].strip()
        if not client_ip:
            client_ip = request.headers.get('X-Real-IP', '')
        if not client_ip:
            client_ip = request.remote_addr or 'unknown'

        user_agent = request.headers.get('User-Agent', 'unknown')
        return client_ip, user_agent
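
# Illustrative sketch only (not used elsewhere in this module): the fingerprint is
# deterministic, so repeated requests from the same IP/User-Agent pair always map
# to the same per-user session slot. The address and agent strings below are
# arbitrary example values.
def _example_fingerprint_is_deterministic() -> bool:
    fp_first = UserIdentifier.generate_user_fingerprint('198.51.100.4', 'Mozilla/5.0')
    fp_second = UserIdentifier.generate_user_fingerprint('198.51.100.4', 'Mozilla/5.0')
    return fp_first == fp_second and fp_first.startswith('user_')
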

class SessionConsolidator:
    """Handles consolidation of session data when replacing sessions."""

    @staticmethod
    def consolidate_scanner_data(old_scanner: 'Scanner', new_scanner: 'Scanner') -> 'Scanner':
        """
        Consolidate useful data from old scanner into new scanner.

        Args:
            old_scanner: Scanner from terminated session
            new_scanner: New scanner instance

        Returns:
            Enhanced new scanner with consolidated data
        """
        try:
            # Consolidate graph data if old scanner has valuable data
            if old_scanner and hasattr(old_scanner, 'graph') and old_scanner.graph:
                old_stats = old_scanner.graph.get_statistics()
                if old_stats['basic_metrics']['total_nodes'] > 0:
                    print(f"Consolidating graph data: {old_stats['basic_metrics']['total_nodes']} nodes, "
                          f"{old_stats['basic_metrics']['total_edges']} edges")

                    # Transfer nodes and edges to new scanner's graph
                    for node_id, node_data in old_scanner.graph.graph.nodes(data=True):
                        # Add node to new graph with all attributes
                        new_scanner.graph.graph.add_node(node_id, **node_data)

                    for source, target, edge_data in old_scanner.graph.graph.edges(data=True):
                        # Add edge to new graph with all attributes
                        new_scanner.graph.graph.add_edge(source, target, **edge_data)

                    # Update correlation index
                    if hasattr(old_scanner.graph, 'correlation_index'):
                        new_scanner.graph.correlation_index = old_scanner.graph.correlation_index.copy()

                    # Update timestamps
                    new_scanner.graph.creation_time = old_scanner.graph.creation_time
                    new_scanner.graph.last_modified = old_scanner.graph.last_modified

            # Consolidate provider statistics
            if old_scanner and hasattr(old_scanner, 'providers') and old_scanner.providers:
                for old_provider in old_scanner.providers:
                    # Find matching provider in new scanner
                    matching_new_provider = None
                    for new_provider in new_scanner.providers:
                        if new_provider.get_name() == old_provider.get_name():
                            matching_new_provider = new_provider
                            break

                    if matching_new_provider:
                        # Transfer cumulative statistics
                        matching_new_provider.total_requests += old_provider.total_requests
                        matching_new_provider.successful_requests += old_provider.successful_requests
                        matching_new_provider.failed_requests += old_provider.failed_requests
                        matching_new_provider.total_relationships_found += old_provider.total_relationships_found

                        # Transfer cache statistics if available
                        if hasattr(old_provider, 'cache_hits'):
                            matching_new_provider.cache_hits += getattr(old_provider, 'cache_hits', 0)
                            matching_new_provider.cache_misses += getattr(old_provider, 'cache_misses', 0)

                        print(f"Consolidated {old_provider.get_name()} provider stats: "
                              f"{old_provider.total_requests} requests")

            return new_scanner

        except Exception as e:
            print(f"Warning: Error during session consolidation: {e}")
            return new_scanner


class SessionManager:
    """
    Manages single scanner session per user using Redis with user identification.
    Enforces one active session per user for consistent state management.
    """

    def __init__(self, session_timeout_minutes: int = 60):
        """Initialize session manager with Redis backend and user tracking."""
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        self.lock = threading.Lock()

        # User identification helper
        self.user_identifier = UserIdentifier()
        self.consolidator = SessionConsolidator()

        # Start cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

        print(f"SessionManager initialized with Redis backend, user tracking, "
              f"and {session_timeout_minutes}min timeout")

    def __getstate__(self):
        """Prepare SessionManager for pickling."""
        state = self.__dict__.copy()
        # Exclude unpickleable attributes
        unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
        return state

    def __setstate__(self, state):
        """Restore SessionManager after unpickling."""
        self.__dict__.update(state)
        # Re-initialize unpickleable attributes
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.lock = threading.Lock()
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def _get_session_key(self, session_id: str) -> str:
        """Generate Redis key for a session."""
        return f"dnsrecon:session:{session_id}"

    def _get_user_session_key(self, user_fingerprint: str) -> str:
        """Generate Redis key for user -> session mapping."""
        return f"dnsrecon:user:{user_fingerprint}"

    def _get_stop_signal_key(self, session_id: str) -> str:
        """Generate Redis key for session stop signal."""
        return f"dnsrecon:stop:{session_id}"
    def create_or_replace_user_session(self, client_ip: str, user_agent: str) -> str:
        """
        Create new session for user, replacing any existing session.
        Consolidates data from previous session if it exists.

        Args:
            client_ip: Client IP address
            user_agent: User-Agent header

        Returns:
            New session ID
        """
        user_fingerprint = self.user_identifier.generate_user_fingerprint(client_ip, user_agent)
        new_session_id = str(uuid.uuid4())

        print(f"=== CREATING/REPLACING SESSION FOR USER {user_fingerprint} ===")

        try:
            # Check for existing user session
            existing_session_id = self._get_user_current_session(user_fingerprint)
            old_scanner = None

            if existing_session_id:
                print(f"Found existing session {existing_session_id} for user {user_fingerprint}")
                # Get old scanner data for consolidation
                old_scanner = self.get_session(existing_session_id)
                # Terminate old session
                self._terminate_session_internal(existing_session_id, cleanup_user_mapping=False)
                print(f"Terminated old session {existing_session_id}")

            # Create new session config and scanner
            from core.session_config import create_session_config
            session_config = create_session_config()
            new_scanner = Scanner(session_config=session_config)

            # Set session ID on scanner for cross-process operations
            new_scanner.session_id = new_session_id

            # Consolidate data from old session if available
            if old_scanner:
                new_scanner = self.consolidator.consolidate_scanner_data(old_scanner, new_scanner)
                print("Consolidated data from previous session")

            # Create session data
            session_data = {
                'scanner': new_scanner,
                'config': session_config,
                'created_at': time.time(),
                'last_activity': time.time(),
                'status': 'active',
                'user_fingerprint': user_fingerprint,
                'client_ip': client_ip,
                'user_agent': user_agent[:200]  # Truncate for storage
            }

            # Store session in Redis
            session_key = self._get_session_key(new_session_id)
            serialized_data = pickle.dumps(session_data)
            self.redis_client.setex(session_key, self.session_timeout, serialized_data)

            # Update user -> session mapping
            user_session_key = self._get_user_session_key(user_fingerprint)
            self.redis_client.setex(user_session_key, self.session_timeout, new_session_id.encode('utf-8'))

            # Initialize stop signal
            stop_key = self._get_stop_signal_key(new_session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')

            print(f"Created new session {new_session_id} for user {user_fingerprint}")
            return new_session_id

        except Exception as e:
            print(f"ERROR: Failed to create session for user {user_fingerprint}: {e}")
            raise

    def _get_user_current_session(self, user_fingerprint: str) -> Optional[str]:
        """Get current session ID for a user."""
        try:
            user_session_key = self._get_user_session_key(user_fingerprint)
            session_id_bytes = self.redis_client.get(user_session_key)
            if session_id_bytes:
                return session_id_bytes.decode('utf-8')
            return None
        except Exception as e:
            print(f"Error getting user session: {e}")
            return None

    def set_stop_signal(self, session_id: str) -> bool:
        """Set stop signal for session (cross-process safe)."""
        try:
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'1')
            print(f"Stop signal set for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
            return False

    def is_stop_requested(self, session_id: str) -> bool:
        """Check if stop is requested for session (cross-process safe)."""
        try:
            stop_key = self._get_stop_signal_key(session_id)
            value = self.redis_client.get(stop_key)
            return value == b'1' if value is not None else False
        except Exception as e:
            print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
            return False
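    # Cross-process usage sketch (an assumed worker pattern, not defined in this
    # module): because the stop flag lives in Redis rather than in process memory,
    # a scan worker running in another process could poll it between units of work,
    # for example:
    #
    #     while not session_manager.is_stop_requested(session_id):
    #         ...perform the next unit of scan work...
    #
    # clear_stop_signal() below resets the flag to b'0' before a new run.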
    def clear_stop_signal(self, session_id: str) -> bool:
        """Clear stop signal for session."""
        try:
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')
            print(f"Stop signal cleared for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
            return False

    def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve and deserialize session data from Redis."""
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = self.redis_client.get(session_key)
            if serialized_data:
                session_data = pickle.loads(serialized_data)
                # Ensure scanner has correct session ID
                if 'scanner' in session_data and session_data['scanner']:
                    session_data['scanner'].session_id = session_id
                return session_data
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """Serialize and save session data to Redis with updated TTL."""
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = pickle.dumps(session_data)
            result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)

            # Also refresh user mapping TTL if available
            if 'user_fingerprint' in session_data:
                user_session_key = self._get_user_session_key(session_data['user_fingerprint'])
                self.redis_client.setex(user_session_key, self.session_timeout, session_id.encode('utf-8'))

            return result
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """Update scanner object in session with immediate persistence."""
        try:
            session_data = self._get_session_data(session_id)
            if session_data:
                # Ensure scanner has session ID
                scanner.session_id = session_id
                session_data['scanner'] = scanner
                session_data['last_activity'] = time.time()

                success = self._save_session_data(session_id, session_data)
                if success:
                    print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
                else:
                    print(f"WARNING: Failed to save scanner state for session {session_id}")
                return success
            else:
                print(f"WARNING: Session {session_id} not found for scanner update")
                return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
        """Quickly update scanner status for immediate GUI feedback."""
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                session_data['scanner'].status = status
                session_data['last_activity'] = time.time()

                success = self._save_session_data(session_id, session_data)
                if success:
                    print(f"Scanner status updated to '{status}' for session {session_id}")
                else:
                    print(f"WARNING: Failed to save status update for session {session_id}")
                return success
            return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
            return False
    def get_session(self, session_id: str) -> Optional[Scanner]:
        """Get scanner instance for session with session ID management."""
        if not session_id:
            return None

        session_data = self._get_session_data(session_id)
        if not session_data or session_data.get('status') != 'active':
            return None

        # Update last activity and save back to Redis
        session_data['last_activity'] = time.time()
        self._save_session_data(session_id, session_data)

        scanner = session_data.get('scanner')
        if scanner:
            # Ensure scanner can check Redis-based stop signal
            scanner.session_id = session_id

        return scanner

    def get_session_status_only(self, session_id: str) -> Optional[str]:
        """Get scanner status without refreshing activity or returning the scanner object."""
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                return session_data['scanner'].status
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session status for {session_id}: {e}")
            return None

    def terminate_session(self, session_id: str) -> bool:
        """Terminate specific session with reliable stop signal and immediate status update."""
        return self._terminate_session_internal(session_id, cleanup_user_mapping=True)

    def _terminate_session_internal(self, session_id: str, cleanup_user_mapping: bool = True) -> bool:
        """Internal session termination with configurable user mapping cleanup."""
        print(f"=== TERMINATING SESSION {session_id} ===")

        try:
            # Set stop signal first
            self.set_stop_signal(session_id)

            # Update scanner status immediately for GUI feedback
            self.update_scanner_status(session_id, 'stopped')

            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"Session {session_id} not found")
                return False

            scanner = session_data.get('scanner')
            if scanner and scanner.status == 'running':
                print(f"Stopping scan for session: {session_id}")
                scanner.stop_scan()
                self.update_session_scanner(session_id, scanner)

                # Wait for graceful shutdown
                time.sleep(0.5)

            # Clean up user mapping if requested
            if cleanup_user_mapping and 'user_fingerprint' in session_data:
                user_session_key = self._get_user_session_key(session_data['user_fingerprint'])
                self.redis_client.delete(user_session_key)
                print(f"Cleaned up user mapping for {session_data['user_fingerprint']}")

            # Delete session data and stop signal
            session_key = self._get_session_key(session_id)
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.delete(session_key)
            self.redis_client.delete(stop_key)

            print(f"Terminated and removed session from Redis: {session_id}")
            return True

        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            return False

    def _cleanup_loop(self) -> None:
        """Background thread that cleans up orphaned stop signals and user mappings
        (session data itself expires via its Redis TTL)."""
        while True:
            try:
                # Clean up orphaned stop signals
                stop_keys = self.redis_client.keys("dnsrecon:stop:*")
                for stop_key in stop_keys:
                    session_id = stop_key.decode('utf-8').split(':')[-1]
                    session_key = self._get_session_key(session_id)
                    if not self.redis_client.exists(session_key):
                        self.redis_client.delete(stop_key)
                        print(f"Cleaned up orphaned stop signal for session {session_id}")

                # Clean up orphaned user mappings
                user_keys = self.redis_client.keys("dnsrecon:user:*")
                for user_key in user_keys:
                    session_id_bytes = self.redis_client.get(user_key)
                    if session_id_bytes:
                        session_id = session_id_bytes.decode('utf-8')
                        session_key = self._get_session_key(session_id)
                        if not self.redis_client.exists(session_key):
                            self.redis_client.delete(user_key)
                            print(f"Cleaned up orphaned user mapping for session {session_id}")

            except Exception as e:
                print(f"Error in cleanup loop: {e}")

            time.sleep(300)  # Sleep for 5 minutes
    def list_active_sessions(self) -> List[Dict[str, Any]]:
        """List all active sessions for admin purposes."""
        try:
            session_keys = self.redis_client.keys("dnsrecon:session:*")
            sessions = []

            for session_key in session_keys:
                session_id = session_key.decode('utf-8').split(':')[-1]
                session_data = self._get_session_data(session_id)
                if session_data:
                    scanner = session_data.get('scanner')
                    sessions.append({
                        'session_id': session_id,
                        'user_fingerprint': session_data.get('user_fingerprint', 'unknown'),
                        'client_ip': session_data.get('client_ip', 'unknown'),
                        'created_at': session_data.get('created_at'),
                        'last_activity': session_data.get('last_activity'),
                        'scanner_status': scanner.status if scanner else 'unknown',
                        'current_target': scanner.current_target if scanner else None
                    })

            return sessions
        except Exception as e:
            print(f"ERROR: Failed to list active sessions: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """Get session manager statistics."""
        try:
            session_keys = self.redis_client.keys("dnsrecon:session:*")
            user_keys = self.redis_client.keys("dnsrecon:user:*")
            stop_keys = self.redis_client.keys("dnsrecon:stop:*")

            active_sessions = len(session_keys)
            unique_users = len(user_keys)
            running_scans = 0

            for session_key in session_keys:
                session_id = session_key.decode('utf-8').split(':')[-1]
                status = self.get_session_status_only(session_id)
                if status == 'running':
                    running_scans += 1

            return {
                'total_active_sessions': active_sessions,
                'unique_users': unique_users,
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys),
                'average_sessions_per_user': round(active_sessions / unique_users, 2) if unique_users > 0 else 0
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'unique_users': 0,
                'running_scans': 0,
                'total_stop_signals': 0,
                'average_sessions_per_user': 0
            }

    def get_session_info(self, session_id: str) -> Dict[str, Any]:
        """Get detailed information about a specific session."""
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                return {'error': 'Session not found'}

            scanner = session_data.get('scanner')
            return {
                'session_id': session_id,
                'user_fingerprint': session_data.get('user_fingerprint', 'unknown'),
                'client_ip': session_data.get('client_ip', 'unknown'),
                'user_agent': session_data.get('user_agent', 'unknown'),
                'created_at': session_data.get('created_at'),
                'last_activity': session_data.get('last_activity'),
                'status': session_data.get('status'),
                'scanner_status': scanner.status if scanner else 'unknown',
                'current_target': scanner.current_target if scanner else None,
                'session_age_minutes': round((time.time() - session_data.get('created_at', time.time())) / 60, 1)
            }
        except Exception as e:
            print(f"ERROR: Failed to get session info for {session_id}: {e}")
            return {'error': f'Failed to get session info: {str(e)}'}


# Global session manager instance
session_manager = SessionManager(session_timeout_minutes=60)
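
# Illustrative lifecycle sketch (assumes a reachable local Redis instance plus the
# real core.scanner and core.session_config modules this file already imports):
# create a session for a synthetic client, inspect it, then tear it down. The IP
# and User-Agent values are arbitrary examples.
if __name__ == '__main__':
    demo_session_id = session_manager.create_or_replace_user_session(
        client_ip='203.0.113.10',        # arbitrary example address
        user_agent='example-agent/1.0',  # arbitrary example User-Agent
    )
    print(session_manager.get_session_info(demo_session_id))
    session_manager.terminate_session(demo_session_id)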