412 lines
16 KiB
Python
412 lines
16 KiB
Python
# dnsrecon/core/session_manager.py
|
|
|
|
import threading
|
|
import time
|
|
import uuid
|
|
import redis
|
|
import pickle
|
|
from typing import Dict, Optional, Any, List
|
|
|
|
from core.scanner import Scanner
|
|
|
|
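# WARNING: pickle can execute arbitrary code while deserializing, so it is only
# safe for data this application wrote itself. Sessions (including Scanner
# objects) are pickled into Redis and unpickled on every read, which means
# anyone with write access to the Redis instance can run code in this process.
# Keep Redis reachable only by trusted processes, and never unpickle data from
# any other source.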
|
|
|
|
class SessionManager:
    """
    Manages multiple scanner instances for concurrent user sessions using Redis.

    Each session is stored under ``dnsrecon:session:<id>`` as a pickled dict
    with the keys ``scanner``, ``config``, ``created_at``, ``last_activity``
    and ``status``. A companion key ``dnsrecon:stop:<id>`` holds ``b'1'`` when
    a stop has been requested and ``b'0'`` otherwise. Both keys are written
    with a TTL of ``session_timeout`` seconds. The session key's TTL is
    refreshed on every save, so idle sessions simply expire.

    NOTE(review): session updates are read-modify-write cycles without
    WATCH/MULTI. Two processes saving the same session concurrently can
    overwrite each other's changes (last writer wins).
    """

    # Redis key prefixes; the SCAN patterns below are derived from these.
    _SESSION_KEY_PREFIX = "dnsrecon:session:"
    _STOP_KEY_PREFIX = "dnsrecon:stop:"

    # Seconds the background thread sleeps between orphaned-stop-signal sweeps.
    cleanup_interval_seconds = 300

    # Seconds terminate_session() waits after stopping a running scan before
    # deleting its keys. Can be overridden per instance.
    terminate_grace_seconds = 0.5

    def __init__(self, session_timeout_minutes: int = 60):
        """
        Initialize session manager with a Redis backend.

        Args:
            session_timeout_minutes: Idle lifetime of a session. It is
                converted to seconds and used as the TTL of its Redis keys.

        The Redis client is created with db 0 and raw (bytes) responses, using
        redis-py's default host/port. A daemon thread that periodically
        removes orphaned stop signals is started immediately.
        """
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        # Not acquired anywhere in this class; kept for callers/compatibility.
        self.lock = threading.Lock()

        # Start cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

        print(f"SessionManager initialized with Redis backend and {session_timeout_minutes}min timeout")

    def __getstate__(self):
        """
        Prepare SessionManager for pickling.

        The Redis client, lock and cleanup thread cannot be pickled. They are
        dropped from a copy of the instance state and rebuilt in
        ``__setstate__``.
        """
        state = self.__dict__.copy()
        for attr in ('lock', 'cleanup_thread', 'redis_client'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        """
        Restore SessionManager after unpickling.

        Recreates the Redis client and the lock, and starts a new cleanup
        thread. As a result, every unpickled copy runs its own background
        sweep.
        """
        self.__dict__.update(state)
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.lock = threading.Lock()
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def _get_session_key(self, session_id: str) -> str:
        """Return the Redis key holding the pickled data for ``session_id``."""
        return f"{self._SESSION_KEY_PREFIX}{session_id}"

    def _get_stop_signal_key(self, session_id: str) -> str:
        """Return the Redis key holding the stop flag for ``session_id``."""
        return f"{self._STOP_KEY_PREFIX}{session_id}"

    @staticmethod
    def _session_id_from_key(key) -> str:
        """
        Extract the session ID (the last ``:``-separated part) from a Redis key.

        Args:
            key: Key as returned by Redis. With ``decode_responses=False`` it
                arrives as ``bytes``; a ``str`` is also accepted.
        """
        if isinstance(key, bytes):
            key = key.decode('utf-8')
        return key.rsplit(':', 1)[-1]

    def _scan_keys(self, pattern: str) -> List[bytes]:
        """
        Return all keys matching ``pattern`` using incremental SCAN.

        SCAN is used instead of KEYS because KEYS blocks the server for the
        whole keyspace walk. SCAN may report a key more than once, so
        duplicates are dropped while the first-seen order is kept.
        """
        return list(dict.fromkeys(self.redis_client.scan_iter(match=pattern)))

    def create_session(self) -> str:
        """
        Create a new user session and store it in Redis.

        Builds a new Scanner from a fresh session config. It then stores the
        pickled session dict and an initial stop signal of ``b'0'``, both with
        the session TTL.

        Returns:
            The new session's UUID4 string.

        Raises:
            Any exception from config/scanner construction or from Redis. It
            is logged and then re-raised.
        """
        session_id = str(uuid.uuid4())
        print(f"=== CREATING SESSION {session_id} IN REDIS ===")

        try:
            from core.session_config import create_session_config
            session_config = create_session_config()
            scanner_instance = Scanner(session_config=session_config)

            # Set the session ID on the scanner for cross-process stop signal management
            scanner_instance.session_id = session_id

            # One timestamp so created_at and last_activity start out identical.
            now = time.time()
            session_data = {
                'scanner': scanner_instance,
                'config': session_config,
                'created_at': now,
                'last_activity': now,
                'status': 'active'
            }

            # Serialize the entire session data dictionary using pickle
            serialized_data = pickle.dumps(session_data)

            # Store in Redis
            session_key = self._get_session_key(session_id)
            self.redis_client.setex(session_key, self.session_timeout, serialized_data)

            # Initialize stop signal as False
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')

            print(f"Session {session_id} stored in Redis with stop signal initialized")
            return session_id

        except Exception as e:
            print(f"ERROR: Failed to create session {session_id}: {e}")
            raise

    def set_stop_signal(self, session_id: str) -> bool:
        """
        Set the stop signal for a session (cross-process safe).

        The key is (re)written with the session TTL even if the session itself
        does not exist. Such orphans are removed by the cleanup thread.

        Args:
            session_id: Session identifier

        Returns:
            bool: True if signal was set successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            # Set stop signal to '1' with the same TTL as the session
            self.redis_client.setex(stop_key, self.session_timeout, b'1')
            print(f"Stop signal set for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
            return False

    def is_stop_requested(self, session_id: str) -> bool:
        """
        Check if stop is requested for a session (cross-process safe).

        Args:
            session_id: Session identifier

        Returns:
            bool: True only if the stop key currently holds ``b'1'``. A
            missing or expired key, or a Redis error, yields False.
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            # A missing key returns None, which never equals b'1'.
            return self.redis_client.get(stop_key) == b'1'
        except Exception as e:
            print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
            return False

    def clear_stop_signal(self, session_id: str) -> bool:
        """
        Clear the stop signal for a session.

        Args:
            session_id: Session identifier

        Returns:
            bool: True if signal was cleared successfully
        """
        try:
            stop_key = self._get_stop_signal_key(session_id)
            self.redis_client.setex(stop_key, self.session_timeout, b'0')
            print(f"Stop signal cleared for session {session_id}")
            return True
        except Exception as e:
            print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
            return False

    def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve and unpickle session data from Redis.

        Unpickling runs whatever is stored at the key. This is only safe while
        Redis is writable solely by trusted processes.

        Returns:
            The session dict, or None if the key is missing or empty, or if
            Redis or unpickling fails (the failure is logged).
        """
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = self.redis_client.get(session_key)
            if not serialized_data:
                return None

            session_data = pickle.loads(serialized_data)
            # Ensure the scanner has the correct session ID for stop signal checking
            if 'scanner' in session_data and session_data['scanner']:
                session_data['scanner'].session_id = session_id
            return session_data
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Serialize and save session data back to Redis, refreshing its TTL.

        Returns:
            bool: True if save was successful
        """
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = pickle.dumps(session_data)
            return bool(self.redis_client.setex(session_key, self.session_timeout, serialized_data))
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """
        Replace the scanner object in a session and persist it immediately.

        Args:
            session_id: Session identifier.
            scanner: Scanner to store. Its ``session_id`` is set to
                ``session_id`` before saving.

        Returns:
            bool: True if the session existed and was saved.
        """
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"WARNING: Session {session_id} not found for scanner update")
                return False

            # Ensure scanner has the session ID
            scanner.session_id = session_id
            session_data['scanner'] = scanner
            session_data['last_activity'] = time.time()

            # Immediately save to Redis for GUI updates
            success = self._save_session_data(session_id, session_data)
            if success:
                print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
            else:
                print(f"WARNING: Failed to save scanner state for session {session_id}")
            return success
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
        """
        Quickly update just the scanner status for immediate GUI feedback.

        Args:
            session_id: Session identifier
            status: New scanner status

        Returns:
            bool: True if the session had a scanner and was saved. False if
            the session or its scanner is missing, or if saving failed.
        """
        try:
            session_data = self._get_session_data(session_id)
            scanner = session_data.get('scanner') if session_data else None
            if scanner is None:
                return False

            scanner.status = status
            session_data['last_activity'] = time.time()

            success = self._save_session_data(session_id, session_data)
            if success:
                print(f"Scanner status updated to '{status}' for session {session_id}")
            else:
                print(f"WARNING: Failed to save status update for session {session_id}")
            return success
        except Exception as e:
            print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
            return False

    def get_session(self, session_id: str) -> Optional['Scanner']:
        """
        Get the scanner for an active session and refresh its activity time.

        Reading a session writes it back with a new ``last_activity``. This
        also renews the session key's TTL.

        Returns:
            The session's scanner. Returns None if ``session_id`` is empty,
            the session is missing, or its status is not ``'active'``.
        """
        if not session_id:
            return None

        session_data = self._get_session_data(session_id)
        if not session_data or session_data.get('status') != 'active':
            return None

        # Update last activity and save back to Redis
        session_data['last_activity'] = time.time()
        self._save_session_data(session_id, session_data)

        scanner = session_data.get('scanner')
        if scanner:
            # Ensure the scanner can check the Redis-based stop signal
            scanner.session_id = session_id
        return scanner

    def get_session_status_only(self, session_id: str) -> Optional[str]:
        """
        Get just the scanner status without refreshing the session.

        Args:
            session_id: Session identifier

        Returns:
            Scanner status string, or None if the session or its scanner is
            not found.
        """
        try:
            session_data = self._get_session_data(session_id)
            if not session_data:
                return None
            scanner = session_data.get('scanner')
            return scanner.status if scanner is not None else None
        except Exception as e:
            print(f"ERROR: Failed to get session status for {session_id}: {e}")
            return None

    def terminate_session(self, session_id: str) -> bool:
        """
        Stop a session's scan (if running) and remove it from Redis.

        Steps:
            1. Set the Redis stop signal.
            2. If the stored scanner is running, call its ``stop_scan()`` and
               persist it.
            3. Mark the status 'stopped' for GUI feedback.
            4. Wait ``terminate_grace_seconds`` (running scans only).
            5. Delete the session and stop-signal keys.

        Returns:
            True if the session existed and was removed. False if it was not
            found or an error occurred.
        """
        print(f"=== TERMINATING SESSION {session_id} ===")

        try:
            # First, set the stop signal so anything polling Redis winds down.
            self.set_stop_signal(session_id)

            # Snapshot the session BEFORE the status is overwritten below:
            # checking for 'running' after forcing 'stopped' can never match.
            session_data = self._get_session_data(session_id)
            if not session_data:
                print(f"Session {session_id} not found")
                return False

            scanner = session_data.get('scanner')
            was_running = scanner is not None and scanner.status == 'running'
            if was_running:
                print(f"Stopping scan for session: {session_id}")
                # The scanner will check the Redis stop signal
                scanner.stop_scan()
                self.update_session_scanner(session_id, scanner)

            # Update scanner status to stopped immediately for GUI feedback
            self.update_scanner_status(session_id, 'stopped')

            if was_running:
                # Wait a moment for graceful shutdown
                time.sleep(self.terminate_grace_seconds)

            # Delete session data and stop signal from Redis
            self.redis_client.delete(
                self._get_session_key(session_id),
                self._get_stop_signal_key(session_id),
            )

            print(f"Terminated and removed session from Redis: {session_id}")
            return True

        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            return False

    def _cleanup_orphaned_stop_signals(self) -> int:
        """
        Delete stop-signal keys whose session key no longer exists.

        Returns:
            Number of orphaned stop signals removed.
        """
        removed = 0
        for stop_key in self._scan_keys(f"{self._STOP_KEY_PREFIX}*"):
            session_id = self._session_id_from_key(stop_key)
            # If session doesn't exist but stop signal does, clean it up
            if not self.redis_client.exists(self._get_session_key(session_id)):
                self.redis_client.delete(stop_key)
                removed += 1
                print(f"Cleaned up orphaned stop signal for session {session_id}")
        return removed

    def _cleanup_loop(self) -> None:
        """
        Background thread body: periodically remove orphaned stop signals.

        Sessions themselves are not swept here; they vanish when their Redis
        TTL expires. Errors are logged and the loop keeps running.
        """
        while True:
            try:
                self._cleanup_orphaned_stop_signals()
            except Exception as e:
                print(f"Error in cleanup loop: {e}")
            time.sleep(self.cleanup_interval_seconds)

    def list_active_sessions(self) -> List[Dict[str, Any]]:
        """
        List all stored sessions for admin purposes.

        Returns:
            One dict per readable session with ``session_id``,
            ``created_at``, ``last_activity``, ``scanner_status`` and
            ``current_target``. Returns an empty list on error.
        """
        try:
            sessions = []
            for session_key in self._scan_keys(f"{self._SESSION_KEY_PREFIX}*"):
                session_id = self._session_id_from_key(session_key)
                session_data = self._get_session_data(session_id)
                if not session_data:
                    # Expired or unreadable between SCAN and GET.
                    continue

                scanner = session_data.get('scanner')
                sessions.append({
                    'session_id': session_id,
                    'created_at': session_data.get('created_at'),
                    'last_activity': session_data.get('last_activity'),
                    'scanner_status': scanner.status if scanner else 'unknown',
                    'current_target': scanner.current_target if scanner else None
                })
            return sessions
        except Exception as e:
            print(f"ERROR: Failed to list active sessions: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get session manager statistics.

        Returns:
            Dict with ``total_active_sessions`` (stored session keys),
            ``running_scans`` (sessions whose scanner status is 'running') and
            ``total_stop_signals`` (stored stop keys). All are zero on error.
        """
        try:
            session_keys = self._scan_keys(f"{self._SESSION_KEY_PREFIX}*")
            stop_keys = self._scan_keys(f"{self._STOP_KEY_PREFIX}*")

            running_scans = sum(
                1
                for session_key in session_keys
                if self.get_session_status_only(self._session_id_from_key(session_key)) == 'running'
            )

            return {
                'total_active_sessions': len(session_keys),
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys)
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'running_scans': 0,
                'total_stop_signals': 0
            }
|
|
|
|
# Process-wide SessionManager singleton shared by every importer of this module.
# Constructing it starts the background cleanup thread (see SessionManager.__init__).
session_manager: SessionManager = SessionManager(
    session_timeout_minutes=60,
)