dnsrecon/core/session_manager.py
overcuriousity 4378146d0c it
2025-09-14 13:14:02 +02:00

576 lines
25 KiB
Python

# dnsrecon/core/session_manager.py
import threading
import time
import uuid
import redis
import pickle
import hashlib
from typing import Dict, Optional, Any, List, Tuple
from core.scanner import Scanner
class UserIdentifier:
    """Derives a stable per-user identifier from request metadata.

    The identifier is a truncated SHA-256 hash of the client IP and the
    (truncated) User-Agent string, so it is deterministic for a given pair.
    """

    @staticmethod
    def generate_user_fingerprint(client_ip: str, user_agent: str) -> str:
        """
        Generate a deterministic fingerprint for a user from IP and User-Agent.

        Args:
            client_ip: Client IP address.
            user_agent: User-Agent header value. ``None`` is treated like the
                ``'unknown'`` placeholder that ``extract_request_info`` uses
                for a missing header, instead of raising ``TypeError``.

        Returns:
            ``"user_"`` followed by the first 16 hex characters of the
            SHA-256 digest of ``"<client_ip>:<first 100 chars of user_agent>"``.
        """
        if user_agent is None:
            user_agent = 'unknown'
        # Only the first 100 characters of the User-Agent take part in the hash.
        user_data = f"{client_ip}:{user_agent[:100]}"
        fingerprint = hashlib.sha256(user_data.encode()).hexdigest()[:16]
        return f"user_{fingerprint}"

    @staticmethod
    def extract_request_info(request) -> Tuple[str, str]:
        """
        Extract client IP and User-Agent from a Flask request.

        The IP is taken from the first entry of ``X-Forwarded-For``, then from
        ``X-Real-IP``, then from ``request.remote_addr``; ``'unknown'`` is used
        when none of them yields a value. Header values are stripped of
        surrounding whitespace.

        Args:
            request: Flask request object (anything exposing ``headers.get``
                and ``remote_addr`` works).

        Returns:
            Tuple of (client_ip, user_agent).
        """
        headers = request.headers
        # NOTE(review): both proxy headers are client-supplied unless a trusted
        # reverse proxy overwrites them - confirm the deployment guarantees that.
        forwarded_for = headers.get('X-Forwarded-For', '') or ''
        client_ip = forwarded_for.split(',')[0].strip()
        if not client_ip:
            client_ip = (headers.get('X-Real-IP', '') or '').strip()
        if not client_ip:
            client_ip = request.remote_addr or 'unknown'
        user_agent = headers.get('User-Agent', 'unknown')
        return client_ip, user_agent
class SessionConsolidator:
    """Carries graph data and provider counters over when a session is replaced."""

    @staticmethod
    def consolidate_scanner_data(old_scanner: 'Scanner', new_scanner: 'Scanner') -> 'Scanner':
        """
        Consolidate useful data from old scanner into new scanner.

        Graph data and provider statistics are merged as two independent
        steps: a failure in one is reported as a warning and does not prevent
        the other from running. This method never raises.

        Args:
            old_scanner: Scanner from the terminated session (may be ``None``).
            new_scanner: New scanner instance; it is mutated in place.

        Returns:
            ``new_scanner``, enhanced with whatever could be consolidated.
        """
        steps = (
            SessionConsolidator._merge_graph,
            SessionConsolidator._merge_provider_stats,
        )
        for step in steps:
            try:
                step(old_scanner, new_scanner)
            except Exception as e:
                print(f"Warning: Error during session consolidation: {e}")
        return new_scanner

    @staticmethod
    def _merge_graph(old_scanner: 'Scanner', new_scanner: 'Scanner') -> None:
        """Copy nodes, edges, correlation index and timestamps into the new graph."""
        old_graph = getattr(old_scanner, 'graph', None) if old_scanner else None
        if not old_graph:
            return

        metrics = old_graph.get_statistics()['basic_metrics']
        if metrics['total_nodes'] <= 0:
            # Nothing worth transferring from an empty graph.
            return

        print(f"Consolidating graph data: {metrics['total_nodes']} nodes, {metrics['total_edges']} edges")
        new_graph = new_scanner.graph

        # Transfer nodes and edges together with all of their attributes.
        for node_id, node_data in old_graph.graph.nodes(data=True):
            new_graph.graph.add_node(node_id, **node_data)
        for source, target, edge_data in old_graph.graph.edges(data=True):
            new_graph.graph.add_edge(source, target, **edge_data)

        if hasattr(old_graph, 'correlation_index'):
            # NOTE(review): .copy() is shallow, so nested containers stay
            # shared with the old graph - confirm that is acceptable.
            new_graph.correlation_index = old_graph.correlation_index.copy()

        # Keep the original graph's timestamps so its history is preserved.
        new_graph.creation_time = old_graph.creation_time
        new_graph.last_modified = old_graph.last_modified

    @staticmethod
    def _merge_provider_stats(old_scanner: 'Scanner', new_scanner: 'Scanner') -> None:
        """Add each old provider's counters onto the same-named new provider."""
        old_providers = getattr(old_scanner, 'providers', None) if old_scanner else None
        if not old_providers:
            return

        # Build the name lookup once; setdefault keeps the first provider when
        # several share a name, matching a first-match linear search.
        new_by_name = {}
        for provider in new_scanner.providers:
            new_by_name.setdefault(provider.get_name(), provider)

        for old_provider in old_providers:
            try:
                name = old_provider.get_name()
                target = new_by_name.get(name)
                if target is None:
                    continue

                # Transfer cumulative statistics.
                target.total_requests += old_provider.total_requests
                target.successful_requests += old_provider.successful_requests
                target.failed_requests += old_provider.failed_requests
                target.total_relationships_found += old_provider.total_relationships_found

                # Cache counters are optional on either side; a missing
                # counter is treated as 0 instead of aborting the merge.
                if hasattr(old_provider, 'cache_hits'):
                    target.cache_hits = (
                        getattr(target, 'cache_hits', 0) + getattr(old_provider, 'cache_hits', 0)
                    )
                    target.cache_misses = (
                        getattr(target, 'cache_misses', 0) + getattr(old_provider, 'cache_misses', 0)
                    )

                print(f"Consolidated {name} provider stats: {old_provider.total_requests} requests")
            except Exception as e:
                # One malformed provider must not block the remaining ones.
                print(f"Warning: Error during session consolidation: {e}")
class SessionManager:
    """
    Manages a single scanner session per user, persisted in Redis.

    Three kinds of keys are used, all written with the session timeout as TTL:

    * ``dnsrecon:session:<session_id>`` - pickled session dict (scanner,
      config, timestamps, status, user info).
    * ``dnsrecon:user:<fingerprint>``   - the user's current session ID.
    * ``dnsrecon:stop:<session_id>``    - stop flag, ``b'1'`` when requested.

    A daemon thread periodically removes stop flags and user mappings whose
    session key no longer exists.

    ``self.lock`` is created per instance but no method of this class acquires
    it, so creating/replacing a session is not serialized between callers.
    """

    # Seconds between two passes of the background cleanup thread.
    CLEANUP_INTERVAL_SECONDS = 300
    # Seconds to wait after stopping a running scan before deleting its session.
    SHUTDOWN_GRACE_SECONDS = 0.5

    def __init__(self, session_timeout_minutes: int = 60):
        """
        Initialize session manager with Redis backend and user tracking.

        Args:
            session_timeout_minutes: TTL applied to every Redis key written by
                this manager, in minutes.
        """
        self.redis_client = self._create_redis_client()
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        self.lock = threading.Lock()
        # Helpers for user identification and session data carry-over
        self.user_identifier = UserIdentifier()
        self.consolidator = SessionConsolidator()
        self._start_cleanup_thread()
        print(f"SessionManager initialized with Redis backend, user tracking, and {session_timeout_minutes}min timeout")

    @staticmethod
    def _create_redis_client():
        """Create the Redis client (db 0, raw bytes responses)."""
        return redis.StrictRedis(db=0, decode_responses=False)

    def _start_cleanup_thread(self) -> None:
        """Create and start the daemon thread running ``_cleanup_loop``."""
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def __getstate__(self):
        """Prepare SessionManager for pickling by dropping unpicklable attributes."""
        unpicklable_attrs = ('lock', 'cleanup_thread', 'redis_client')
        return {name: value for name, value in self.__dict__.items() if name not in unpicklable_attrs}

    def __setstate__(self, state):
        """Restore SessionManager after unpickling and rebuild runtime-only attributes."""
        self.__dict__.update(state)
        self.redis_client = self._create_redis_client()
        self.lock = threading.Lock()
        # NOTE(review): every unpickled copy starts its own never-ending
        # cleanup thread; if managers are unpickled repeatedly (e.g. as part
        # of a pickled scanner) this accumulates threads - confirm intent.
        self._start_cleanup_thread()

    def _get_session_key(self, session_id: str) -> str:
        """Generate Redis key for a session."""
        return f"dnsrecon:session:{session_id}"

    def _get_user_session_key(self, user_fingerprint: str) -> str:
        """Generate Redis key for user -> session mapping."""
        return f"dnsrecon:user:{user_fingerprint}"

    def _get_stop_signal_key(self, session_id: str) -> str:
        """Generate Redis key for session stop signal."""
        return f"dnsrecon:stop:{session_id}"

    @staticmethod
    def _session_id_from_key(key: bytes) -> str:
        """Return the session ID, i.e. the part after the last ':' of a Redis key."""
        return key.decode('utf-8').split(':')[-1]

    def _scan_keys(self, pattern: str) -> List[bytes]:
        """
        Return all Redis keys matching ``pattern``.

        Uses incremental SCAN instead of KEYS so the Redis server is not
        blocked while the whole keyspace is walked. SCAN may report a key more
        than once, hence the de-duplication.
        """
        return list(set(self.redis_client.scan_iter(match=pattern)))

    def create_or_replace_user_session(self, client_ip: str, user_agent: str) -> str:
        """
        Create new session for user, replacing any existing session.

        If the user already has a session, its scanner is loaded, the session
        is terminated, and the old scanner's data is consolidated into the
        new scanner.

        Args:
            client_ip: Client IP address.
            user_agent: User-Agent header.

        Returns:
            New session ID.

        Raises:
            Exception: Any error raised while creating or storing the session
                is logged and re-raised.
        """
        user_fingerprint = self.user_identifier.generate_user_fingerprint(client_ip, user_agent)
        new_session_id = str(uuid.uuid4())
        print(f"=== CREATING/REPLACING SESSION FOR USER {user_fingerprint} ===")
        try:
            # Check for existing user session
            existing_session_id = self._get_user_current_session(user_fingerprint)
            old_scanner = None
            if existing_session_id:
                print(f"Found existing session {existing_session_id} for user {user_fingerprint}")
                # Load the old scanner first: termination deletes it from Redis.
                old_scanner = self.get_session(existing_session_id)
                # The user mapping is kept; it is overwritten further down.
                self._terminate_session_internal(existing_session_id, cleanup_user_mapping=False)
                print(f"Terminated old session {existing_session_id}")

            # Create new session config and scanner (imported at call time).
            from core.session_config import create_session_config
            session_config = create_session_config()
            new_scanner = Scanner(session_config=session_config)
            # Set session ID on scanner for cross-process operations
            new_scanner.session_id = new_session_id

            # Consolidate data from old session if available
            if old_scanner:
                new_scanner = self.consolidator.consolidate_scanner_data(old_scanner, new_scanner)
                print("Consolidated data from previous session")

            now = time.time()
            session_data = {
                'scanner': new_scanner,
                'config': session_config,
                'created_at': now,
                'last_activity': now,
                'status': 'active',
                'user_fingerprint': user_fingerprint,
                'client_ip': client_ip,
                'user_agent': user_agent[:200],  # Truncate for storage
            }

            # Store session, user -> session mapping and a cleared stop signal.
            self.redis_client.setex(
                self._get_session_key(new_session_id), self.session_timeout, pickle.dumps(session_data)
            )
            self.redis_client.setex(
                self._get_user_session_key(user_fingerprint), self.session_timeout, new_session_id.encode('utf-8')
            )
            self.redis_client.setex(self._get_stop_signal_key(new_session_id), self.session_timeout, b'0')

            print(f"Created new session {new_session_id} for user {user_fingerprint}")
            return new_session_id
        except Exception as e:
            print(f"ERROR: Failed to create session for user {user_fingerprint}: {e}")
            raise

    def _get_user_current_session(self, user_fingerprint: str) -> Optional[str]:
        """Get current session ID for a user, or ``None`` if unknown or on error."""
        try:
            session_id_bytes = self.redis_client.get(self._get_user_session_key(user_fingerprint))
            if session_id_bytes:
                return session_id_bytes.decode('utf-8')
            return None
        except Exception as e:
            print(f"Error getting user session: {e}")
            return None

    def set_stop_signal(self, session_id: str) -> bool:
        """Set stop signal for session (cross-process safe). Returns success."""
        try:
            self.redis_client.setex(self._get_stop_signal_key(session_id), self.session_timeout, b'1')
            print(f"Stop signal set for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
            return False

    def is_stop_requested(self, session_id: str) -> bool:
        """Check if stop is requested for session; ``False`` if unset or on error."""
        try:
            # A missing key yields None, which simply compares unequal to b'1'.
            return self.redis_client.get(self._get_stop_signal_key(session_id)) == b'1'
        except Exception as e:
            print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
            return False

    def clear_stop_signal(self, session_id: str) -> bool:
        """Reset the stop signal for session to ``b'0'``. Returns success."""
        try:
            self.redis_client.setex(self._get_stop_signal_key(session_id), self.session_timeout, b'0')
            print(f"Stop signal cleared for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
            return False

    def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve and deserialize session data from Redis.

        Returns:
            The session dict, or ``None`` if the key is missing or loading fails.
        """
        try:
            serialized_data = self.redis_client.get(self._get_session_key(session_id))
            if not serialized_data:
                return None
            # SECURITY: pickle.loads can execute code embedded in the payload.
            # This is only safe while the Redis instance is writable by
            # trusted processes exclusively.
            session_data = pickle.loads(serialized_data)
            # Ensure scanner has correct session ID
            if 'scanner' in session_data and session_data['scanner']:
                session_data['scanner'].session_id = session_id
            return session_data
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Serialize and save session data to Redis with a refreshed TTL.

        The user -> session mapping TTL is refreshed as well when the session
        dict carries a ``user_fingerprint``.

        Returns:
            ``True`` if Redis acknowledged the session write, else ``False``.
        """
        try:
            result = self.redis_client.setex(
                self._get_session_key(session_id), self.session_timeout, pickle.dumps(session_data)
            )
            if 'user_fingerprint' in session_data:
                user_session_key = self._get_user_session_key(session_data['user_fingerprint'])
                self.redis_client.setex(user_session_key, self.session_timeout, session_id.encode('utf-8'))
            return bool(result)
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """Replace the scanner object stored in the session and persist it immediately."""
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"WARNING: Session {session_id} not found for scanner update")
                return False

            # Ensure scanner has session ID
            scanner.session_id = session_id
            session_data['scanner'] = scanner
            session_data['last_activity'] = time.time()
            success = self._save_session_data(session_id, session_data)
            if success:
                print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
            else:
                print(f"WARNING: Failed to save scanner state for session {session_id}")
            return success
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
        """Set the stored scanner's ``status`` and persist it (for immediate GUI feedback)."""
        try:
            session_data = self._get_session_data(session_id)
            if not (session_data and 'scanner' in session_data):
                return False

            session_data['scanner'].status = status
            session_data['last_activity'] = time.time()
            success = self._save_session_data(session_id, session_data)
            if success:
                print(f"Scanner status updated to '{status}' for session {session_id}")
            else:
                print(f"WARNING: Failed to save status update for session {session_id}")
            return success
        except Exception as e:
            print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
            return False

    def get_session(self, session_id: str) -> Optional['Scanner']:
        """
        Get scanner instance for an active session.

        Reading a session also refreshes its ``last_activity`` and TTL by
        writing the whole session back to Redis.

        Returns:
            The scanner, or ``None`` if the ID is empty, the session does not
            exist, or its status is not ``'active'``.
        """
        if not session_id:
            return None
        session_data = self._get_session_data(session_id)
        if not session_data or session_data.get('status') != 'active':
            return None

        # NOTE(review): this read-modify-write can overwrite a concurrent
        # update made by another process between the load and the save.
        session_data['last_activity'] = time.time()
        self._save_session_data(session_id, session_data)

        scanner = session_data.get('scanner')
        if scanner:
            # Ensure scanner can check Redis-based stop signal
            scanner.session_id = session_id
        return scanner

    def get_session_status_only(self, session_id: str) -> Optional[str]:
        """
        Get the scanner status without refreshing the session's activity/TTL.

        The session is still fully loaded and unpickled; unlike
        ``get_session`` nothing is written back.
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                return session_data['scanner'].status
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session status for {session_id}: {e}")
            return None

    def terminate_session(self, session_id: str) -> bool:
        """Terminate a session and remove it, including its user mapping."""
        return self._terminate_session_internal(session_id, cleanup_user_mapping=True)

    def _terminate_session_internal(self, session_id: str, cleanup_user_mapping: bool = True) -> bool:
        """
        Terminate a session and delete its Redis keys.

        Args:
            session_id: Session to terminate.
            cleanup_user_mapping: Also delete the user -> session mapping.

        Returns:
            ``True`` if the session existed and was removed, ``False`` if it
            was not found or an error occurred.
        """
        print(f"=== TERMINATING SESSION {session_id} ===")
        try:
            # Set stop signal first so workers in other processes can react.
            self.set_stop_signal(session_id)

            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"Session {session_id} not found")
                return False

            scanner = session_data.get('scanner')
            # Capture the running state BEFORE the stored status is changed to
            # 'stopped'; checking afterwards would never see 'running'.
            was_running = bool(scanner) and scanner.status == 'running'

            # Update scanner status immediately for GUI feedback
            self.update_scanner_status(session_id, 'stopped')

            if was_running:
                print(f"Stopping scan for session: {session_id}")
                scanner.stop_scan()
                self.update_session_scanner(session_id, scanner)
                # Wait for graceful shutdown
                time.sleep(self.SHUTDOWN_GRACE_SECONDS)

            # Clean up user mapping if requested
            if cleanup_user_mapping and 'user_fingerprint' in session_data:
                self.redis_client.delete(self._get_user_session_key(session_data['user_fingerprint']))
                print(f"Cleaned up user mapping for {session_data['user_fingerprint']}")

            # Delete session data and stop signal
            self.redis_client.delete(self._get_session_key(session_id))
            self.redis_client.delete(self._get_stop_signal_key(session_id))
            print(f"Terminated and removed session from Redis: {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            return False

    def _cleanup_loop(self) -> None:
        """Background thread body: remove orphaned keys, sleep, repeat forever."""
        while True:
            try:
                self._cleanup_orphans()
            except Exception as e:
                print(f"Error in cleanup loop: {e}")
            time.sleep(self.CLEANUP_INTERVAL_SECONDS)

    def _cleanup_orphans(self) -> None:
        """Delete stop signals and user mappings whose session key no longer exists."""
        # Clean up orphaned stop signals
        for stop_key in self._scan_keys("dnsrecon:stop:*"):
            session_id = self._session_id_from_key(stop_key)
            if not self.redis_client.exists(self._get_session_key(session_id)):
                self.redis_client.delete(stop_key)
                print(f"Cleaned up orphaned stop signal for session {session_id}")

        # Clean up orphaned user mappings
        for user_key in self._scan_keys("dnsrecon:user:*"):
            session_id_bytes = self.redis_client.get(user_key)
            if not session_id_bytes:
                continue
            session_id = session_id_bytes.decode('utf-8')
            if not self.redis_client.exists(self._get_session_key(session_id)):
                self.redis_client.delete(user_key)
                print(f"Cleaned up orphaned user mapping for session {session_id}")

    def list_active_sessions(self) -> List[Dict[str, Any]]:
        """List all stored sessions for admin purposes; empty list on error."""
        try:
            sessions = []
            for session_key in self._scan_keys("dnsrecon:session:*"):
                session_id = self._session_id_from_key(session_key)
                session_data = self._get_session_data(session_id)
                if not session_data:
                    continue
                scanner = session_data.get('scanner')
                sessions.append({
                    'session_id': session_id,
                    'user_fingerprint': session_data.get('user_fingerprint', 'unknown'),
                    'client_ip': session_data.get('client_ip', 'unknown'),
                    'created_at': session_data.get('created_at'),
                    'last_activity': session_data.get('last_activity'),
                    'scanner_status': scanner.status if scanner else 'unknown',
                    'current_target': scanner.current_target if scanner else None,
                })
            return sessions
        except Exception as e:
            print(f"ERROR: Failed to list active sessions: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """Get session manager statistics; all counters are 0 on error."""
        try:
            session_keys = self._scan_keys("dnsrecon:session:*")
            active_sessions = len(session_keys)
            unique_users = len(self._scan_keys("dnsrecon:user:*"))
            total_stop_signals = len(self._scan_keys("dnsrecon:stop:*"))

            running_scans = sum(
                1
                for session_key in session_keys
                if self.get_session_status_only(self._session_id_from_key(session_key)) == 'running'
            )
            return {
                'total_active_sessions': active_sessions,
                'unique_users': unique_users,
                'running_scans': running_scans,
                'total_stop_signals': total_stop_signals,
                'average_sessions_per_user': round(active_sessions / unique_users, 2) if unique_users > 0 else 0,
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'unique_users': 0,
                'running_scans': 0,
                'total_stop_signals': 0,
                'average_sessions_per_user': 0,
            }

    def get_session_info(self, session_id: str) -> Dict[str, Any]:
        """Get detailed information about a session, or a dict with an ``'error'`` key."""
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                return {'error': 'Session not found'}

            scanner = session_data.get('scanner')
            now = time.time()
            return {
                'session_id': session_id,
                'user_fingerprint': session_data.get('user_fingerprint', 'unknown'),
                'client_ip': session_data.get('client_ip', 'unknown'),
                'user_agent': session_data.get('user_agent', 'unknown'),
                'created_at': session_data.get('created_at'),
                'last_activity': session_data.get('last_activity'),
                'status': session_data.get('status'),
                'scanner_status': scanner.status if scanner else 'unknown',
                'current_target': scanner.current_target if scanner else None,
                # A session without 'created_at' is reported with age 0.
                'session_age_minutes': round((now - session_data.get('created_at', now)) / 60, 1),
            }
        except Exception as e:
            print(f"ERROR: Failed to get session info for {session_id}: {e}")
            return {'error': f'Failed to get session info: {str(e)}'}
# Global session manager instance shared by everything that imports this module.
# NOTE: it is created at import time, so importing this module builds the Redis
# client and starts the background cleanup thread as a side effect.
session_manager = SessionManager(session_timeout_minutes=60)