# Source listing metadata: 576 lines, 25 KiB, Python
# dnsrecon/core/session_manager.py
|
|
|
|
import threading
|
|
import time
|
|
import uuid
|
|
import redis
|
|
import pickle
|
|
import hashlib
|
|
from typing import Dict, Optional, Any, List, Tuple
|
|
|
|
from core.scanner import Scanner
|
|
|
|
|
|
class UserIdentifier:
    """Handles user identification for session management.

    A "user" is identified purely by client IP plus User-Agent. The result is
    a convenience key for grouping requests, not an authentication mechanism.
    """

    @staticmethod
    def generate_user_fingerprint(client_ip: str, user_agent: str) -> str:
        """
        Generate a deterministic fingerprint for a user based on IP and User-Agent.

        Args:
            client_ip: Client IP address
            user_agent: User-Agent header value; ``None`` is treated as empty

        Returns:
            ``"user_"`` followed by the first 16 hex characters of the SHA-256
            digest of ``"<client_ip>:<first 100 chars of user_agent>"``
        """
        # A missing header can reach this point as None; slicing None would
        # raise TypeError, so normalise it to an empty string first.
        ua_prefix = (user_agent or '')[:100]  # Limit UA to 100 chars
        user_data = f"{client_ip}:{ua_prefix}"
        fingerprint = hashlib.sha256(user_data.encode()).hexdigest()[:16]  # 16 char fingerprint
        return f"user_{fingerprint}"

    @staticmethod
    def extract_request_info(request) -> Tuple[str, str]:
        """
        Extract client IP and User-Agent from a Flask-style request.

        The IP is taken from the first entry of ``X-Forwarded-For``, then from
        ``X-Real-IP``, then from ``request.remote_addr``; ``'unknown'`` is used
        when none of them yields a value.

        Args:
            request: Object exposing ``headers.get(name, default)`` and
                ``remote_addr``

        Returns:
            Tuple of (client_ip, user_agent)
        """
        headers = request.headers

        # NOTE(security): both proxy headers are supplied by the client unless
        # a trusted reverse proxy overwrites them, so the resulting IP can be
        # spoofed when the application is reachable directly.
        # X-Forwarded-For has the form "client, proxy1, proxy2".
        client_ip = headers.get('X-Forwarded-For', '').split(',')[0].strip()
        if not client_ip:
            # Strip so that a whitespace-only header falls through to remote_addr.
            client_ip = headers.get('X-Real-IP', '').strip()
        if not client_ip:
            client_ip = request.remote_addr or 'unknown'

        user_agent = headers.get('User-Agent', 'unknown')

        return client_ip, user_agent
|
|
|
|
|
|
class SessionConsolidator:
    """Carries reusable state over from a replaced session into its successor."""

    @staticmethod
    def consolidate_scanner_data(old_scanner: 'Scanner', new_scanner: 'Scanner') -> 'Scanner':
        """
        Merge graph contents and provider counters of an old scanner into a new one.

        The operation is best-effort: any exception is reported on stdout and
        the new scanner is returned in whatever state it had reached.

        Args:
            old_scanner: Scanner from the terminated session (may be falsy)
            new_scanner: Freshly created scanner that receives the data

        Returns:
            The same ``new_scanner`` object, enriched where possible
        """
        try:
            # --- Graph: copy nodes, edges, correlation index and timestamps ---
            prior_graph = getattr(old_scanner, 'graph', None) if old_scanner else None
            if prior_graph:
                metrics = prior_graph.get_statistics()['basic_metrics']
                if metrics['total_nodes'] > 0:
                    print(f"Consolidating graph data: {metrics['total_nodes']} nodes, {metrics['total_edges']} edges")

                    # Nodes first, then edges, each with its full attribute dict.
                    for node_id, attrs in prior_graph.graph.nodes(data=True):
                        new_scanner.graph.graph.add_node(node_id, **attrs)
                    for src, dst, attrs in prior_graph.graph.edges(data=True):
                        new_scanner.graph.graph.add_edge(src, dst, **attrs)

                    # The correlation index is replaced by a shallow copy of the old one.
                    if hasattr(prior_graph, 'correlation_index'):
                        new_scanner.graph.correlation_index = prior_graph.correlation_index.copy()

                    # Keep the original graph's history rather than the new graph's.
                    new_scanner.graph.creation_time = prior_graph.creation_time
                    new_scanner.graph.last_modified = prior_graph.last_modified

            # --- Providers: add the old cumulative counters onto the new ones ---
            prior_providers = getattr(old_scanner, 'providers', None) if old_scanner else None
            if prior_providers:
                for prior in prior_providers:
                    # First provider in the new scanner carrying the same name, if any.
                    counterpart = next(
                        (
                            candidate
                            for candidate in new_scanner.providers
                            if candidate.get_name() == prior.get_name()
                        ),
                        None,
                    )
                    if not counterpart:
                        continue

                    counterpart.total_requests += prior.total_requests
                    counterpart.successful_requests += prior.successful_requests
                    counterpart.failed_requests += prior.failed_requests
                    counterpart.total_relationships_found += prior.total_relationships_found

                    # Cache counters are optional on the old provider.
                    if hasattr(prior, 'cache_hits'):
                        counterpart.cache_hits += getattr(prior, 'cache_hits', 0)
                        counterpart.cache_misses += getattr(prior, 'cache_misses', 0)

                    print(f"Consolidated {prior.get_name()} provider stats: {prior.total_requests} requests")

            return new_scanner

        except Exception as e:
            print(f"Warning: Error during session consolidation: {e}")
            return new_scanner
|
|
|
|
|
|
class SessionManager:
    """
    Manages single scanner session per user using Redis with user identification.
    Enforces one active session per user for consistent state management.

    Three kinds of Redis keys are used, all written with the session timeout
    as TTL:

    * ``dnsrecon:session:<session_id>`` -> pickled session dict
    * ``dnsrecon:user:<fingerprint>``   -> current session id of that user
    * ``dnsrecon:stop:<session_id>``    -> ``b'1'`` when a stop is requested, else ``b'0'``
    """

    _SESSION_PREFIX = "dnsrecon:session:"
    _USER_PREFIX = "dnsrecon:user:"
    _STOP_PREFIX = "dnsrecon:stop:"

    # Pause between two passes of the background cleanup loop.
    _CLEANUP_INTERVAL_SECONDS = 300

    def __init__(self, session_timeout_minutes: int = 60):
        """
        Initialize session manager with Redis backend and user tracking.

        Args:
            session_timeout_minutes: TTL applied to every key this manager writes
        """
        self.redis_client = self._create_redis_client()
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        # Kept for compatibility; no method of this class acquires it.
        self.lock = threading.Lock()

        # User identification helper
        self.user_identifier = UserIdentifier()
        self.consolidator = SessionConsolidator()

        # Start cleanup thread
        self.cleanup_thread = self._start_cleanup_thread()

        print(f"SessionManager initialized with Redis backend, user tracking, and {session_timeout_minutes}min timeout")

    @staticmethod
    def _create_redis_client():
        """Build the Redis client (database 0, raw bytes responses)."""
        return redis.StrictRedis(db=0, decode_responses=False)

    def _start_cleanup_thread(self) -> threading.Thread:
        """Create and start the daemon thread that runs the cleanup loop."""
        worker = threading.Thread(target=self._cleanup_loop, daemon=True)
        worker.start()
        return worker

    def __getstate__(self):
        """Prepare SessionManager for pickling by dropping unpicklable members."""
        state = self.__dict__.copy()
        for attr in ('lock', 'cleanup_thread', 'redis_client'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        """Restore SessionManager after unpickling and recreate dropped members."""
        self.__dict__.update(state)
        self.redis_client = self._create_redis_client()
        self.lock = threading.Lock()
        # NOTE(review): every unpickled copy starts its own cleanup thread.
        self.cleanup_thread = self._start_cleanup_thread()

    def _get_session_key(self, session_id: str) -> str:
        """Generate Redis key for a session."""
        return f"{self._SESSION_PREFIX}{session_id}"

    def _get_user_session_key(self, user_fingerprint: str) -> str:
        """Generate Redis key for user -> session mapping."""
        return f"{self._USER_PREFIX}{user_fingerprint}"

    def _get_stop_signal_key(self, session_id: str) -> str:
        """Generate Redis key for session stop signal."""
        return f"{self._STOP_PREFIX}{session_id}"

    @staticmethod
    def _session_id_from_key(raw_key: bytes) -> str:
        """Return the trailing ``<id>`` part of a ``dnsrecon:<kind>:<id>`` key."""
        return raw_key.decode('utf-8').split(':')[-1]

    def _scan_keys(self, pattern: str) -> List[bytes]:
        """
        Return the distinct keys matching ``pattern``.

        Uses incremental SCAN instead of KEYS so that a large keyspace does
        not block the Redis server. SCAN may report a key more than once,
        hence the order-preserving de-duplication.
        """
        return list(dict.fromkeys(self.redis_client.scan_iter(match=pattern)))

    def create_or_replace_user_session(self, client_ip: str, user_agent: str) -> str:
        """
        Create new session for user, replacing any existing session.
        Consolidates data from previous session if it exists.

        Args:
            client_ip: Client IP address
            user_agent: User-Agent header

        Returns:
            New session ID

        Raises:
            Exception: any failure is logged to stdout and re-raised unchanged
        """
        user_fingerprint = self.user_identifier.generate_user_fingerprint(client_ip, user_agent)
        new_session_id = str(uuid.uuid4())

        print(f"=== CREATING/REPLACING SESSION FOR USER {user_fingerprint} ===")

        try:
            # Check for existing user session
            existing_session_id = self._get_user_current_session(user_fingerprint)
            old_scanner = None

            if existing_session_id:
                print(f"Found existing session {existing_session_id} for user {user_fingerprint}")
                # Fetch the old scanner before termination deletes it from Redis.
                old_scanner = self.get_session(existing_session_id)
                # The user mapping is overwritten below, so it is not cleaned up here.
                self._terminate_session_internal(existing_session_id, cleanup_user_mapping=False)
                print(f"Terminated old session {existing_session_id}")

            # Create new session config and scanner (imported here, not at module level)
            from core.session_config import create_session_config
            session_config = create_session_config()
            new_scanner = Scanner(session_config=session_config)

            # Set session ID on scanner for cross-process operations
            new_scanner.session_id = new_session_id

            # Consolidate data from old session if available
            if old_scanner:
                new_scanner = self.consolidator.consolidate_scanner_data(old_scanner, new_scanner)
                print(f"Consolidated data from previous session")

            now = time.time()
            session_data = {
                'scanner': new_scanner,
                'config': session_config,
                'created_at': now,
                'last_activity': now,
                'status': 'active',
                'user_fingerprint': user_fingerprint,
                'client_ip': client_ip,
                # Truncate for storage; tolerate a missing header passed as None.
                'user_agent': (user_agent or '')[:200],
            }

            # Store session in Redis
            self.redis_client.setex(
                self._get_session_key(new_session_id),
                self.session_timeout,
                pickle.dumps(session_data),
            )

            # Update user -> session mapping
            self.redis_client.setex(
                self._get_user_session_key(user_fingerprint),
                self.session_timeout,
                new_session_id.encode('utf-8'),
            )

            # Initialize stop signal
            self.redis_client.setex(
                self._get_stop_signal_key(new_session_id),
                self.session_timeout,
                b'0',
            )

            print(f"Created new session {new_session_id} for user {user_fingerprint}")
            return new_session_id

        except Exception as e:
            print(f"ERROR: Failed to create session for user {user_fingerprint}: {e}")
            raise

    def _get_user_current_session(self, user_fingerprint: str) -> Optional[str]:
        """Get current session ID for a user, or None if unmapped or on error."""
        try:
            session_id_bytes = self.redis_client.get(self._get_user_session_key(user_fingerprint))
            if session_id_bytes:
                return session_id_bytes.decode('utf-8')
            return None
        except Exception as e:
            print(f"Error getting user session: {e}")
            return None

    def set_stop_signal(self, session_id: str) -> bool:
        """Set stop signal for session (cross-process safe). Returns False on error."""
        try:
            self.redis_client.setex(self._get_stop_signal_key(session_id), self.session_timeout, b'1')
            print(f"Stop signal set for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
            return False

    def is_stop_requested(self, session_id: str) -> bool:
        """Check if stop is requested for session; a missing key or an error means False."""
        try:
            return self.redis_client.get(self._get_stop_signal_key(session_id)) == b'1'
        except Exception as e:
            print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
            return False

    def clear_stop_signal(self, session_id: str) -> bool:
        """Clear stop signal for session by resetting it to b'0'. Returns False on error."""
        try:
            self.redis_client.setex(self._get_stop_signal_key(session_id), self.session_timeout, b'0')
            print(f"Stop signal cleared for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
            return False

    def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve and deserialize session data from Redis; None if absent or on error."""
        try:
            serialized_data = self.redis_client.get(self._get_session_key(session_id))
            if not serialized_data:
                return None

            # NOTE(security): pickle.loads can execute arbitrary code embedded
            # in the payload. This is only safe while the Redis instance is
            # writable solely by trusted processes.
            session_data = pickle.loads(serialized_data)

            # Ensure scanner has correct session ID
            if 'scanner' in session_data and session_data['scanner']:
                session_data['scanner'].session_id = session_id
            return session_data
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """Serialize and save session data to Redis with updated TTL."""
        try:
            result = self.redis_client.setex(
                self._get_session_key(session_id),
                self.session_timeout,
                pickle.dumps(session_data),
            )

            # Also refresh user mapping TTL if available
            if 'user_fingerprint' in session_data:
                self.redis_client.setex(
                    self._get_user_session_key(session_data['user_fingerprint']),
                    self.session_timeout,
                    session_id.encode('utf-8'),
                )

            return result
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """Update scanner object in session with immediate persistence."""
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"WARNING: Session {session_id} not found for scanner update")
                return False

            # Ensure scanner has session ID
            scanner.session_id = session_id
            session_data['scanner'] = scanner
            session_data['last_activity'] = time.time()

            success = self._save_session_data(session_id, session_data)
            if success:
                print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
            else:
                print(f"WARNING: Failed to save scanner state for session {session_id}")
            return success
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
        """Quickly update scanner status for immediate GUI feedback."""
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                session_data['scanner'].status = status
                session_data['last_activity'] = time.time()

                success = self._save_session_data(session_id, session_data)
                if success:
                    print(f"Scanner status updated to '{status}' for session {session_id}")
                else:
                    print(f"WARNING: Failed to save status update for session {session_id}")
                return success
            return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
            return False

    def get_session(self, session_id: str) -> Optional['Scanner']:
        """
        Get scanner instance for session with session ID management.

        Returns None for an empty id, a missing session, or a session whose
        'status' entry is not 'active'. A successful lookup refreshes
        'last_activity' and the TTL of the stored session.
        """
        if not session_id:
            return None

        session_data = self._get_session_data(session_id)
        if not session_data or session_data.get('status') != 'active':
            return None

        # Update last activity and save back to Redis
        session_data['last_activity'] = time.time()
        self._save_session_data(session_id, session_data)

        scanner = session_data.get('scanner')
        if scanner:
            # Ensure scanner can check Redis-based stop signal
            scanner.session_id = session_id

        return scanner

    def get_session_status_only(self, session_id: str) -> Optional[str]:
        """Get scanner status without touching last-activity or TTL; None if unavailable."""
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                return session_data['scanner'].status
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session status for {session_id}: {e}")
            return None

    def terminate_session(self, session_id: str) -> bool:
        """Terminate specific session and remove its user mapping."""
        return self._terminate_session_internal(session_id, cleanup_user_mapping=True)

    def _terminate_session_internal(self, session_id: str, cleanup_user_mapping: bool = True) -> bool:
        """
        Internal session termination with configurable user mapping cleanup.

        Args:
            session_id: Session to terminate
            cleanup_user_mapping: Also delete the user -> session mapping

        Returns:
            True when the session existed and was removed, False when it was
            not found or an error occurred
        """
        print(f"=== TERMINATING SESSION {session_id} ===")

        try:
            # Set stop signal first
            self.set_stop_signal(session_id)

            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"Session {session_id} not found")
                return False

            scanner = session_data.get('scanner')
            if scanner:
                # Read the status BEFORE overwriting it. Persisting 'stopped'
                # first and re-reading makes the 'running' check always fail,
                # so stop_scan() would never be called.
                was_running = scanner.status == 'running'
                if was_running:
                    print(f"Stopping scan for session: {session_id}")
                    scanner.stop_scan()

                # Update scanner status immediately for GUI feedback
                scanner.status = 'stopped'
                session_data['last_activity'] = time.time()
                if self._save_session_data(session_id, session_data):
                    print(f"Scanner status updated to 'stopped' for session {session_id}")
                else:
                    print(f"WARNING: Failed to save status update for session {session_id}")

                if was_running:
                    # Wait for graceful shutdown
                    time.sleep(0.5)

            # Clean up user mapping if requested
            if cleanup_user_mapping and 'user_fingerprint' in session_data:
                self.redis_client.delete(self._get_user_session_key(session_data['user_fingerprint']))
                print(f"Cleaned up user mapping for {session_data['user_fingerprint']}")

            # Delete session data and stop signal
            self.redis_client.delete(
                self._get_session_key(session_id),
                self._get_stop_signal_key(session_id),
            )

            print(f"Terminated and removed session from Redis: {session_id}")
            return True

        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            return False

    def _cleanup_loop(self) -> None:
        """
        Background loop removing stop signals and user mappings whose session is gone.

        Session keys themselves are never deleted here. Runs forever; errors
        in a pass are logged and the loop continues.
        """
        while True:
            try:
                # Clean up orphaned stop signals
                for stop_key in self._scan_keys(f"{self._STOP_PREFIX}*"):
                    session_id = self._session_id_from_key(stop_key)
                    if not self.redis_client.exists(self._get_session_key(session_id)):
                        self.redis_client.delete(stop_key)
                        print(f"Cleaned up orphaned stop signal for session {session_id}")

                # Clean up orphaned user mappings
                for user_key in self._scan_keys(f"{self._USER_PREFIX}*"):
                    session_id_bytes = self.redis_client.get(user_key)
                    if not session_id_bytes:
                        continue
                    session_id = session_id_bytes.decode('utf-8')
                    if not self.redis_client.exists(self._get_session_key(session_id)):
                        self.redis_client.delete(user_key)
                        print(f"Cleaned up orphaned user mapping for session {session_id}")

            except Exception as e:
                print(f"Error in cleanup loop: {e}")

            time.sleep(self._CLEANUP_INTERVAL_SECONDS)  # Sleep for 5 minutes

    def list_active_sessions(self) -> List[Dict[str, Any]]:
        """List all stored sessions for admin purposes; empty list on error."""
        try:
            sessions = []
            for session_key in self._scan_keys(f"{self._SESSION_PREFIX}*"):
                session_id = self._session_id_from_key(session_key)
                session_data = self._get_session_data(session_id)
                if not session_data:
                    continue

                scanner = session_data.get('scanner')
                sessions.append({
                    'session_id': session_id,
                    'user_fingerprint': session_data.get('user_fingerprint', 'unknown'),
                    'client_ip': session_data.get('client_ip', 'unknown'),
                    'created_at': session_data.get('created_at'),
                    'last_activity': session_data.get('last_activity'),
                    'scanner_status': scanner.status if scanner else 'unknown',
                    'current_target': scanner.current_target if scanner else None,
                })

            return sessions
        except Exception as e:
            print(f"ERROR: Failed to list active sessions: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """Get session manager statistics; all counters are zero on error."""
        try:
            session_keys = self._scan_keys(f"{self._SESSION_PREFIX}*")
            user_keys = self._scan_keys(f"{self._USER_PREFIX}*")
            stop_keys = self._scan_keys(f"{self._STOP_PREFIX}*")

            active_sessions = len(session_keys)
            unique_users = len(user_keys)
            running_scans = sum(
                1
                for session_key in session_keys
                if self.get_session_status_only(self._session_id_from_key(session_key)) == 'running'
            )

            return {
                'total_active_sessions': active_sessions,
                'unique_users': unique_users,
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys),
                'average_sessions_per_user': round(active_sessions / unique_users, 2) if unique_users > 0 else 0,
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'unique_users': 0,
                'running_scans': 0,
                'total_stop_signals': 0,
                'average_sessions_per_user': 0,
            }

    def get_session_info(self, session_id: str) -> Dict[str, Any]:
        """
        Get detailed information about a specific session.

        Returns a dict with a single 'error' entry when the session is
        missing or the lookup fails.
        """
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                return {'error': 'Session not found'}

            scanner = session_data.get('scanner')

            return {
                'session_id': session_id,
                'user_fingerprint': session_data.get('user_fingerprint', 'unknown'),
                'client_ip': session_data.get('client_ip', 'unknown'),
                'user_agent': session_data.get('user_agent', 'unknown'),
                'created_at': session_data.get('created_at'),
                'last_activity': session_data.get('last_activity'),
                'status': session_data.get('status'),
                'scanner_status': scanner.status if scanner else 'unknown',
                'current_target': scanner.current_target if scanner else None,
                'session_age_minutes': round((time.time() - session_data.get('created_at', time.time())) / 60, 1),
            }
        except Exception as e:
            print(f"ERROR: Failed to get session info for {session_id}: {e}")
            return {'error': f'Failed to get session info: {str(e)}'}
|
|
|
|
|
|
# Global session manager instance, created when this module is imported.
# Constructing it also starts the manager's background cleanup thread.
session_manager = SessionManager(session_timeout_minutes=60)