# dnscope/core/session_manager.py
import copy
import pickle
import threading
import time
import uuid
from typing import Any, Dict, Optional

import redis

from core.scanner import Scanner
from config import config


class SessionManager:
"""
FIXED: Manages multiple scanner instances for concurrent user sessions using Redis.
Enhanced to properly maintain WebSocket connections throughout scan lifecycle.
"""

    def __init__(self, session_timeout_minutes: Optional[int] = None):
        """Initialize the session manager with a Redis backend."""
        if session_timeout_minutes is None:
            session_timeout_minutes = config.session_timeout_minutes
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        self.lock = threading.Lock()
        # Creation lock prevents race conditions when two requests create sessions at once
        self.creation_lock = threading.Lock()
        # Track active Socket.IO connections per session (kept in memory; not picklable)
        self.active_socketio_connections = {}

        # Start background cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()
        print(f"SessionManager initialized with Redis backend and {session_timeout_minutes}min timeout")

    def __getstate__(self):
        """Prepare SessionManager for pickling."""
        state = self.__dict__.copy()
        # Exclude unpicklable attributes: Redis client, locks, threads, and live Socket.IO handles
        unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
        return state

    def __setstate__(self, state):
        """Restore SessionManager after unpickling."""
        self.__dict__.update(state)
        # Re-initialize the unpicklable attributes dropped in __getstate__
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.lock = threading.Lock()
        self.creation_lock = threading.Lock()
        self.active_socketio_connections = {}
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def _get_session_key(self, session_id: str) -> str:
        """Generate the Redis key for a session."""
        return f"dnsrecon:session:{session_id}"

    def _get_stop_signal_key(self, session_id: str) -> str:
        """Generate the Redis key for a session's stop signal."""
        return f"dnsrecon:stop:{session_id}"

    def register_socketio_connection(self, session_id: str, socketio) -> None:
        """
        Register a Socket.IO connection for a session so it can be restored
        whenever the scanner is loaded from Redis during the session lifecycle.
        """
        with self.lock:
            self.active_socketio_connections[session_id] = socketio
            print(f"Registered socketio connection for session {session_id}")

    def get_socketio_connection(self, session_id: str):
        """Return the active Socket.IO connection for a session, if any."""
        with self.lock:
            return self.active_socketio_connections.get(session_id)

    def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner:
        """
        Prepare a scanner for storage by clearing unpicklable objects.

        The Socket.IO handle is dropped here and restored from the connection
        registry when the scanner is loaded again, so real-time updates survive
        the round trip through Redis.
        """
        # Set the session ID on the scanner for cross-process stop signal management
        scanner.session_id = session_id

        # Socket.IO handles are not picklable; drop the reference and restore it on load
        scanner.socketio = None

        # Clear threading objects that cannot be pickled
        for attr in ('stop_event', 'scan_thread', 'executor', 'status_logger_thread', 'status_logger_stop_event'):
            if hasattr(scanner, attr):
                setattr(scanner, attr, None)

        return scanner

    def create_session(self, socketio=None) -> str:
        """Create a new user session, register its WebSocket connection, and store it in Redis."""
        # Creation lock prevents race conditions when sessions are created concurrently
        with self.creation_lock:
            session_id = str(uuid.uuid4())
            print(f"=== CREATING SESSION {session_id} IN REDIS ===")

            # Register the Socket.IO connection first so it can be restored later
            if socketio:
                self.register_socketio_connection(session_id, socketio)

            try:
                from core.session_config import create_session_config
                session_config = create_session_config()

                # Create the scanner WITHOUT socketio to avoid weakref pickling issues
                scanner_instance = Scanner(session_config=session_config, socketio=None)

                # Prepare scanner for storage (removes unpicklable objects)
                scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id)

                session_data = {
                    'scanner': scanner_instance,
                    'config': session_config,
                    'created_at': time.time(),
                    'last_activity': time.time(),
                    'status': 'active'
                }

                # Test serialization before storing to catch pickling issues early
                try:
                    serialized_session = pickle.dumps(session_data)
                    print(f"Session serialization test successful ({len(serialized_session)} bytes)")
                except Exception as pickle_error:
                    print(f"PICKLE TEST FAILED: {pickle_error}")
                    # Try to identify the problematic object
                    for key, value in session_data.items():
                        try:
                            pickle.dumps(value)
                            print(f"  {key}: OK")
                        except Exception as item_error:
                            print(f"  {key}: FAILED - {item_error}")
                    raise pickle_error

                # Store in Redis
                session_key = self._get_session_key(session_id)
                self.redis_client.setex(session_key, self.session_timeout, serialized_session)

                # Initialize stop signal as False
                stop_key = self._get_stop_signal_key(session_id)
                self.redis_client.setex(stop_key, self.session_timeout, b'0')

                print(f"Session {session_id} stored in Redis with stop signal initialized")
                print(f"Session has {len(scanner_instance.providers)} providers: {[p.get_name() for p in scanner_instance.providers]}")
                return session_id

            except Exception as e:
                print(f"ERROR: Failed to create session {session_id}: {e}")
                import traceback
                traceback.print_exc()
                raise

    def set_stop_signal(self, session_id: str) -> bool:
        """
        Set the stop signal for a session (cross-process safe).

        Args:
            session_id: Session identifier

        Returns:
            bool: True if the signal was set successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            # Set stop signal to '1' with the same TTL as the session
            self.redis_client.setex(stop_key, self.session_timeout, b'1')
            print(f"Stop signal set for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
            return False

    def is_stop_requested(self, session_id: str) -> bool:
        """
        Check whether a stop has been requested for a session (cross-process safe).

        Args:
            session_id: Session identifier

        Returns:
            bool: True if a stop is requested
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            value = self.redis_client.get(stop_key)
            return value == b'1' if value is not None else False
        except Exception as e:
            print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
            return False

    def clear_stop_signal(self, session_id: str) -> bool:
        """
        Clear the stop signal for a session.

        Args:
            session_id: Session identifier

        Returns:
            bool: True if the signal was cleared successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')
            print(f"Stop signal cleared for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
            return False

    def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve and deserialize session data from Redis."""
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = self.redis_client.get(session_key)
            if serialized_data:
                session_data = pickle.loads(serialized_data)

                # Ensure the scanner carries the session ID used for stop signal checks
                if 'scanner' in session_data and session_data['scanner']:
                    session_data['scanner'].session_id = session_id

                    # Restore the Socket.IO connection from the in-memory registry
                    socketio_conn = self.get_socketio_connection(session_id)
                    if socketio_conn:
                        session_data['scanner'].socketio = socketio_conn
                        print(f"Restored socketio connection for session {session_id}")
                    else:
                        print(f"No socketio connection found for session {session_id}")
                        session_data['scanner'].socketio = None

                return session_data
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Serialize session data and save it back to Redis with a refreshed TTL,
        preserving the Socket.IO connection registration across storage.

        Returns:
            bool: True if the save was successful
        """
        try:
            session_key = self._get_session_key(session_id)

            # Deep copy so the caller's scanner object is not modified
            session_data_to_save = copy.deepcopy(session_data)

            # Prepare the scanner for storage if present
            if 'scanner' in session_data_to_save and session_data_to_save['scanner']:
                # Remember the Socket.IO handle before it is stripped for pickling
                original_socketio = session_data_to_save['scanner'].socketio
                session_data_to_save['scanner'] = self._prepare_scanner_for_storage(
                    session_data_to_save['scanner'],
                    session_id
                )
                # If a Socket.IO handle existed, make sure it stays registered
                if original_socketio and session_id not in self.active_socketio_connections:
                    self.register_socketio_connection(session_id, original_socketio)

            serialized_data = pickle.dumps(session_data_to_save)
            result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
            return result
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """
        Update just the scanner object in a session, persisting immediately and
        keeping the Socket.IO connection registered through the update.

        Returns:
            bool: True if the update was successful
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data:
                # Remember the Socket.IO handle before it is stripped for storage
                original_socketio = scanner.socketio

                # Prepare scanner for storage
                scanner = self._prepare_scanner_for_storage(scanner, session_id)
                session_data['scanner'] = scanner
                session_data['last_activity'] = time.time()

                # Restore the Socket.IO connection after preparation
                if original_socketio:
                    self.register_socketio_connection(session_id, original_socketio)
                    session_data['scanner'].socketio = original_socketio

                # Immediately save to Redis so GUI polling sees fresh state
                success = self._save_session_data(session_id, session_data)
                if success:
                    # Throttle bookkeeping: remember when the last update was recorded
                    # (per-update log messages were dropped to reduce noise)
                    if hasattr(self, '_last_update_log'):
                        if time.time() - self._last_update_log > 5:  # At most every 5 seconds
                            self._last_update_log = time.time()
                    else:
                        self._last_update_log = time.time()
                else:
                    print(f"WARNING: Failed to save scanner state for session {session_id}")
                return success
            else:
                print(f"WARNING: Session {session_id} not found for scanner update")
                return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
        """
        Quickly update just the scanner status for immediate GUI feedback.

        Args:
            session_id: Session identifier
            status: New scanner status

        Returns:
            bool: True if the update was successful
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                session_data['scanner'].status = status
                session_data['last_activity'] = time.time()
                success = self._save_session_data(session_id, session_data)
                if success:
                    print(f"Scanner status updated to '{status}' for session {session_id}")
                else:
                    print(f"WARNING: Failed to save status update for session {session_id}")
                return success
            return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
            return False

    def get_session(self, session_id: str) -> Optional[Scanner]:
        """Get the scanner instance for a session from Redis, restoring its Socket.IO connection."""
        if not session_id:
            return None

        session_data = self._get_session_data(session_id)
        if not session_data or session_data.get('status') != 'active':
            return None

        # Update last activity and write back to Redis
        session_data['last_activity'] = time.time()
        self._save_session_data(session_id, session_data)

        scanner = session_data.get('scanner')
        if scanner:
            # Ensure the scanner can check the Redis-based stop signal
            scanner.session_id = session_id

            # Restore the Socket.IO connection from the registry
            socketio_conn = self.get_socketio_connection(session_id)
            if socketio_conn:
                scanner.socketio = socketio_conn
                print(f"✓ Restored socketio connection for session {session_id}")
            else:
                scanner.socketio = None
                print(f"⚠️ No socketio connection found for session {session_id}")

        return scanner

    def get_session_status_only(self, session_id: str) -> Optional[str]:
        """
        Get just the scanner status, without refreshing activity or restoring the scanner
        (lighter than get_session, though the session blob is still deserialized).

        Args:
            session_id: Session identifier

        Returns:
            Scanner status string, or None if not found
        """
        try:
            session_data = self._get_session_data(session_id)
            if session_data and 'scanner' in session_data:
                return session_data['scanner'].status
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session status for {session_id}: {e}")
            return None

    def terminate_session(self, session_id: str) -> bool:
        """
        Terminate a specific session: set the stop signal, push an immediate status
        update for the GUI, then remove the session data from Redis.
        """
        print(f"=== TERMINATING SESSION {session_id} ===")
        try:
            # First, set the stop signal
            self.set_stop_signal(session_id)

            # Update scanner status to 'stopped' immediately for GUI feedback
            self.update_scanner_status(session_id, 'stopped')

            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"Session {session_id} not found")
                return False

            scanner = session_data.get('scanner')
            if scanner and scanner.status == 'running':
                print(f"Stopping scan for session: {session_id}")
                # The scanner will also check the Redis stop signal
                scanner.stop_scan()
                # Persist the updated scanner state immediately
                self.update_session_scanner(session_id, scanner)
                # Give the scan a moment to shut down gracefully
                time.sleep(0.5)

            # Clean up the Socket.IO connection registration
            with self.lock:
                if session_id in self.active_socketio_connections:
                    del self.active_socketio_connections[session_id]
                    print(f"Cleaned up socketio connection for session {session_id}")

            # Delete session data and stop signal from Redis
            session_key = self._get_session_key(session_id)
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.delete(session_key)
            self.redis_client.delete(stop_key)

            print(f"Terminated and removed session from Redis: {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _cleanup_loop(self) -> None:
        """
        Background thread that removes orphaned stop signals and stale Socket.IO
        registrations; session data itself expires via its Redis TTL.
        """
        while True:
            try:
                # Find stop signals whose session no longer exists
                stop_keys = self.redis_client.keys("dnsrecon:stop:*")
                for stop_key in stop_keys:
                    # Extract the session ID from the stop key
                    session_id = stop_key.decode('utf-8').split(':')[-1]
                    session_key = self._get_session_key(session_id)

                    # If the session is gone but its stop signal remains, clean it up
                    if not self.redis_client.exists(session_key):
                        self.redis_client.delete(stop_key)
                        print(f"Cleaned up orphaned stop signal for session {session_id}")

                        # Also drop the Socket.IO registration
                        with self.lock:
                            if session_id in self.active_socketio_connections:
                                del self.active_socketio_connections[session_id]
                                print(f"Cleaned up orphaned socketio for session {session_id}")
            except Exception as e:
                print(f"Error in cleanup loop: {e}")

            time.sleep(300)  # Sleep for 5 minutes

    def get_statistics(self) -> Dict[str, Any]:
        """Get session manager statistics."""
        try:
            session_keys = self.redis_client.keys("dnsrecon:session:*")
            stop_keys = self.redis_client.keys("dnsrecon:stop:*")

            active_sessions = len(session_keys)
            running_scans = 0
            for session_key in session_keys:
                session_id = session_key.decode('utf-8').split(':')[-1]
                status = self.get_session_status_only(session_id)
                if status == 'running':
                    running_scans += 1

            return {
                'total_active_sessions': active_sessions,
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys),
                'active_socketio_connections': len(self.active_socketio_connections)
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'running_scans': 0,
                'total_stop_signals': 0,
                'active_socketio_connections': 0
            }


# Global session manager instance
session_manager = SessionManager(session_timeout_minutes=60)
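

# ----------------------------------------------------------------------------
# Illustrative usage: a minimal sketch of the session lifecycle, not part of
# the module's public API. It assumes a local Redis instance is reachable on
# the default port and that core.scanner.Scanner and config are importable;
# run it only in a throwaway environment, since it creates and deletes a real
# session in Redis.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Create a session without a Socket.IO connection (real callers pass one
    # so GUI updates can be pushed in real time).
    sid = session_manager.create_session(socketio=None)

    # Load the scanner back from Redis; this also refreshes last_activity.
    scanner = session_manager.get_session(sid)
    print(f"Loaded scanner with status: {scanner.status if scanner else 'not found'}")

    # Cross-process stop handling: any worker holding the session ID can set
    # or poll the Redis-backed stop flag.
    session_manager.set_stop_signal(sid)
    print(f"Stop requested: {session_manager.is_stop_requested(sid)}")

    # Tear the session down again and show the aggregate statistics.
    session_manager.terminate_session(sid)
    print(session_manager.get_statistics())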