iteration on WebSocket (ws) implementation

This commit is contained in:
overcuriousity
2025-09-20 16:52:05 +02:00
parent 75a595c9cb
commit c4e6a8998a
9 changed files with 1224 additions and 290 deletions

View File

@@ -4,8 +4,7 @@
Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
Now fully compatible with the unified ProviderResult data model.
UPDATED: Fixed correlation exclusion keys to match actual attribute names.
UPDATED: Removed export_json() method - now handled by ExportManager.
FIXED: Added proper pickle support to prevent weakref serialization errors.
"""
import re
from datetime import datetime, timezone
@@ -33,6 +32,7 @@ class GraphManager:
Thread-safe graph manager for DNSRecon infrastructure mapping.
Uses NetworkX for in-memory graph storage with confidence scoring.
Compatible with unified ProviderResult data model.
FIXED: Added proper pickle support to handle NetworkX graph serialization.
"""
def __init__(self):
@@ -40,6 +40,57 @@ class GraphManager:
self.graph = nx.DiGraph()
self.creation_time = datetime.now(timezone.utc).isoformat()
self.last_modified = self.creation_time
def __getstate__(self):
"""Prepare GraphManager for pickling by converting NetworkX graph to serializable format."""
state = self.__dict__.copy()
# Convert NetworkX graph to a serializable format
if hasattr(self, 'graph') and self.graph:
# Extract all nodes with their data
nodes_data = {}
for node_id, attrs in self.graph.nodes(data=True):
nodes_data[node_id] = dict(attrs)
# Extract all edges with their data
edges_data = []
for source, target, attrs in self.graph.edges(data=True):
edges_data.append({
'source': source,
'target': target,
'attributes': dict(attrs)
})
# Replace the NetworkX graph with serializable data
state['_graph_nodes'] = nodes_data
state['_graph_edges'] = edges_data
del state['graph']
return state
def __setstate__(self, state):
"""Restore GraphManager after unpickling by reconstructing NetworkX graph."""
# Restore basic attributes
self.__dict__.update(state)
# Reconstruct NetworkX graph from serializable data
self.graph = nx.DiGraph()
# Restore nodes
if hasattr(self, '_graph_nodes'):
for node_id, attrs in self._graph_nodes.items():
self.graph.add_node(node_id, **attrs)
del self._graph_nodes
# Restore edges
if hasattr(self, '_graph_edges'):
for edge_data in self._graph_edges:
self.graph.add_edge(
edge_data['source'],
edge_data['target'],
**edge_data['attributes']
)
del self._graph_edges
def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None,
description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:
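For reference, the pattern this hunk introduces — dump the DiGraph to plain dicts and lists in __getstate__, rebuild it in __setstate__ — in a minimal, self-contained sketch. The class and attribute names here are illustrative, not the project's:

import pickle
import networkx as nx

class PicklableGraph:
    """Sketch: serialize a DiGraph as plain node/edge data instead of pickling the graph object."""
    def __init__(self):
        self.graph = nx.DiGraph()

    def __getstate__(self):
        state = self.__dict__.copy()
        state['_graph_nodes'] = {n: dict(a) for n, a in self.graph.nodes(data=True)}
        state['_graph_edges'] = [
            {'source': u, 'target': v, 'attributes': dict(a)}
            for u, v, a in self.graph.edges(data=True)
        ]
        del state['graph']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.graph = nx.DiGraph()
        # Pop the temporary keys back off the instance as they are consumed
        for node_id, attrs in self.__dict__.pop('_graph_nodes', {}).items():
            self.graph.add_node(node_id, **attrs)
        for edge in self.__dict__.pop('_graph_edges', []):
            self.graph.add_edge(edge['source'], edge['target'], **edge['attributes'])

g = PicklableGraph()
g.graph.add_edge('example.com', '93.184.216.34', relationship_type='dns_a_record')
restored = pickle.loads(pickle.dumps(g))
assert restored.graph.has_edge('example.com', '93.184.216.34')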

View File

@@ -40,6 +40,7 @@ class ForensicLogger:
"""
Thread-safe forensic logging system for DNSRecon.
Maintains detailed audit trail of all reconnaissance activities.
FIXED: Enhanced pickle support to prevent weakref issues in logging handlers.
"""
def __init__(self, session_id: str = ""):
@@ -65,45 +66,74 @@ class ForensicLogger:
'target_domains': set()
}
# Configure standard logger
# Configure standard logger with simple setup to avoid weakrefs
self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
self.logger.setLevel(logging.INFO)
# Create formatter for structured logging
# Create minimal formatter
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Add console handler if not already present
# Add console handler only if not already present (avoid duplicate handlers)
if not self.logger.handlers:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
def __getstate__(self):
"""Prepare ForensicLogger for pickling by excluding unpicklable objects."""
"""
FIXED: Prepare ForensicLogger for pickling by excluding problematic objects.
"""
state = self.__dict__.copy()
# Remove the unpickleable 'logger' attribute
if 'logger' in state:
del state['logger']
if 'lock' in state:
del state['lock']
# Remove potentially unpickleable attributes that may contain weakrefs
unpicklable_attrs = ['logger', 'lock']
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
# Convert sets to lists for JSON serialization compatibility
if 'session_metadata' in state:
metadata = state['session_metadata'].copy()
if 'providers_used' in metadata and isinstance(metadata['providers_used'], set):
metadata['providers_used'] = list(metadata['providers_used'])
if 'target_domains' in metadata and isinstance(metadata['target_domains'], set):
metadata['target_domains'] = list(metadata['target_domains'])
state['session_metadata'] = metadata
return state
def __setstate__(self, state):
"""Restore ForensicLogger after unpickling by reconstructing logger."""
"""
FIXED: Restore ForensicLogger after unpickling by reconstructing components.
"""
self.__dict__.update(state)
# Re-initialize the 'logger' attribute
# Re-initialize threading lock
self.lock = threading.Lock()
# Re-initialize logger with minimal setup
self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
self.logger.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Only add handler if not already present
if not self.logger.handlers:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
self.lock = threading.Lock()
# Convert lists back to sets if needed
if 'session_metadata' in self.__dict__:
metadata = self.session_metadata
if 'providers_used' in metadata and isinstance(metadata['providers_used'], list):
metadata['providers_used'] = set(metadata['providers_used'])
if 'target_domains' in metadata and isinstance(metadata['target_domains'], list):
metadata['target_domains'] = set(metadata['target_domains'])
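The set/list conversion above is symmetric; a compact round-trip illustration (the metadata keys are the ones used in this file, the values are made up):

metadata = {'providers_used': {'dns', 'crtsh'}, 'target_domains': {'example.com'}}

# __getstate__ direction: sets become lists (JSON-friendly, order-insensitive)
pickled_form = {k: list(v) if isinstance(v, set) else v for k, v in metadata.items()}

# __setstate__ direction: lists become sets again so .add() and membership checks work
restored = {k: set(v) if isinstance(v, list) else v for k, v in pickled_form.items()}
assert restored == metadata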
def _generate_session_id(self) -> str:
"""Generate unique session identifier."""
@@ -143,18 +173,23 @@ class ForensicLogger:
discovery_context=discovery_context
)
self.api_requests.append(api_request)
self.session_metadata['total_requests'] += 1
self.session_metadata['providers_used'].add(provider)
with self.lock:
self.api_requests.append(api_request)
self.session_metadata['total_requests'] += 1
self.session_metadata['providers_used'].add(provider)
if target_indicator:
self.session_metadata['target_domains'].add(target_indicator)
if target_indicator:
self.session_metadata['target_domains'].add(target_indicator)
# Log to standard logger
if error:
self.logger.error(f"API Request Failed.")
else:
self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
# Log to standard logger with error handling
try:
if error:
self.logger.error(f"API Request Failed - {provider}: {url}")
else:
self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
except Exception:
# If logging fails, continue without breaking the application
pass
def log_relationship_discovery(self, source_node: str, target_node: str,
relationship_type: str, confidence_score: float,
@@ -183,29 +218,44 @@ class ForensicLogger:
discovery_method=discovery_method
)
self.relationships.append(relationship)
self.session_metadata['total_relationships'] += 1
with self.lock:
self.relationships.append(relationship)
self.session_metadata['total_relationships'] += 1
self.logger.info(
f"Relationship Discovered - {source_node} -> {target_node} "
f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
)
# Log to standard logger with error handling
try:
self.logger.info(
f"Relationship Discovered - {source_node} -> {target_node} "
f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
)
except Exception:
# If logging fails, continue without breaking the application
pass
def log_scan_start(self, target_domain: str, recursion_depth: int,
enabled_providers: List[str]) -> None:
"""Log the start of a reconnaissance scan."""
self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
self.session_metadata['target_domains'].update(target_domain)
try:
self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
with self.lock:
self.session_metadata['target_domains'].add(target_domain)
except Exception:
pass
def log_scan_complete(self) -> None:
"""Log the completion of a reconnaissance scan."""
self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
with self.lock:
self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
# Convert sets to lists for serialization
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
self.logger.info(f"Scan Complete - Session: {self.session_id}")
try:
self.logger.info(f"Scan Complete - Session: {self.session_id}")
except Exception:
pass
def export_audit_trail(self) -> Dict[str, Any]:
"""
@@ -214,12 +264,13 @@ class ForensicLogger:
Returns:
Dictionary containing complete session audit trail
"""
return {
'session_metadata': self.session_metadata.copy(),
'api_requests': [asdict(req) for req in self.api_requests],
'relationships': [asdict(rel) for rel in self.relationships],
'export_timestamp': datetime.now(timezone.utc).isoformat()
}
with self.lock:
return {
'session_metadata': self.session_metadata.copy(),
'api_requests': [asdict(req) for req in self.api_requests],
'relationships': [asdict(rel) for rel in self.relationships],
'export_timestamp': datetime.now(timezone.utc).isoformat()
}
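export_audit_trail leans on dataclasses.asdict() to flatten each logged record into a plain dict; a minimal sketch with an invented stand-in for the logger's real record type:

from dataclasses import dataclass, asdict
from datetime import datetime, timezone

@dataclass
class APIRequestRecord:  # illustrative stand-in, not the project's dataclass
    provider: str
    url: str
    status_code: int

record = APIRequestRecord('dns', 'https://example.com/lookup', 200)
audit_trail = {
    'api_requests': [asdict(record)],  # -> [{'provider': 'dns', 'url': ..., 'status_code': 200}]
    'export_timestamp': datetime.now(timezone.utc).isoformat(),
}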
def get_forensic_summary(self) -> Dict[str, Any]:
"""
@@ -229,7 +280,13 @@ class ForensicLogger:
Dictionary containing summary statistics
"""
provider_stats = {}
for provider in self.session_metadata['providers_used']:
# Ensure providers_used is a set for iteration
providers_used = self.session_metadata['providers_used']
if isinstance(providers_used, list):
providers_used = set(providers_used)
for provider in providers_used:
provider_requests = [req for req in self.api_requests if req.provider == provider]
provider_relationships = [rel for rel in self.relationships if rel.provider == provider]

View File

@@ -35,6 +35,7 @@ class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
UNIFIED: Combines comprehensive features with improved display formatting.
FIXED: Enhanced threading object initialization to prevent None references.
"""
def __init__(self, session_config=None, socketio=None):
@@ -44,6 +45,11 @@ class Scanner:
if session_config is None:
from core.session_config import create_session_config
session_config = create_session_config()
# FIXED: Initialize all threading objects first
self._initialize_threading_objects()
# Set socketio (but will be set to None for storage)
self.socketio = socketio
self.config = session_config
@@ -53,17 +59,12 @@ class Scanner:
self.current_target = None
self.current_depth = 0
self.max_depth = 2
self.stop_event = threading.Event()
self.scan_thread = None
self.session_id: Optional[str] = None # Will be set by session manager
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
self.initial_targets = set()
# Thread-safe processing tracking (from Document 1)
self.currently_processing = set()
self.processing_lock = threading.Lock()
# Display-friendly processing list (from Document 2)
self.currently_processing_display = []
@@ -81,9 +82,10 @@ class Scanner:
self.max_workers = self.config.max_concurrent_requests
self.executor = None
# Status logger thread with improved formatting
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
# Initialize collections that will be recreated during unpickling
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
# Initialize providers with session config
self._initialize_providers()
@@ -99,12 +101,24 @@ class Scanner:
traceback.print_exc()
raise
def _initialize_threading_objects(self):
"""
FIXED: Initialize all threading objects with proper error handling.
This method can be called during both __init__ and __setstate__.
"""
self.stop_event = threading.Event()
self.processing_lock = threading.Lock()
self.status_logger_stop_event = threading.Event()
self.status_logger_thread = None
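The reason this helper exists: Event, Lock, and live Thread objects cannot be pickled (attempting it raises TypeError: cannot pickle '_thread.lock' object), so they are dropped in __getstate__ and recreated on both construction and unpickling. A stripped-down sketch of the same shape:

import pickle
import threading

class Worker:
    def __init__(self):
        self._initialize_threading_objects()
        self.results = []

    def _initialize_threading_objects(self):
        # Recreated fresh in __init__ and __setstate__; never serialized
        self.stop_event = threading.Event()
        self.processing_lock = threading.Lock()

    def __getstate__(self):
        state = self.__dict__.copy()
        for attr in ('stop_event', 'processing_lock'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._initialize_threading_objects()

restored = pickle.loads(pickle.dumps(Worker()))
assert not restored.stop_event.is_set()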
def _is_stop_requested(self) -> bool:
"""
Check if stop is requested using both local and Redis-based signals.
This ensures reliable termination across process boundaries.
FIXED: Added None check for stop_event.
"""
if self.stop_event.is_set():
# FIXED: Ensure stop_event exists before checking
if hasattr(self, 'stop_event') and self.stop_event and self.stop_event.is_set():
return True
if self.session_id:
@@ -112,16 +126,24 @@ class Scanner:
from core.session_manager import session_manager
return session_manager.is_stop_requested(self.session_id)
except Exception as e:
# Fall back to local event
return self.stop_event.is_set()
# Fall back to local event if it exists
if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
return self.stop_event.is_set()
# Final fallback
if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
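The two-level check matters because the local Event only reaches threads in the current worker process; the Redis key (format dnsrecon:stop:{session_id}, per the session manager below) is what lets a stop issued in one process terminate a scan running in another. A hedged sketch of both halves — the TTL and value encoding are assumptions:

import threading
import redis

redis_client = redis.StrictRedis(db=0)
stop_event = threading.Event()

def set_stop_signal(session_id: str) -> None:
    stop_event.set()  # reaches threads in this process immediately
    # reaches every other process polling the key (TTL is illustrative)
    redis_client.setex(f'dnsrecon:stop:{session_id}', 3600, b'1')

def is_stop_requested(session_id: str) -> bool:
    if stop_event.is_set():
        return True
    try:
        return redis_client.get(f'dnsrecon:stop:{session_id}') == b'1'
    except redis.RedisError:
        return stop_event.is_set()  # Redis unavailable: fall back to the local event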
def _set_stop_signal(self) -> None:
"""
Set stop signal both locally and in Redis.
FIXED: Added None check for stop_event.
"""
self.stop_event.set()
# FIXED: Ensure stop_event exists before setting
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.set()
if self.session_id:
try:
@@ -162,17 +184,21 @@ class Scanner:
"""Restore object after unpickling by reconstructing threading objects."""
self.__dict__.update(state)
self.stop_event = threading.Event()
# FIXED: Ensure all threading objects are properly initialized
self._initialize_threading_objects()
# Re-initialize other objects
self.scan_thread = None
self.executor = None
self.processing_lock = threading.Lock()
self.task_queue = PriorityQueue()
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
self.logger = get_forensic_logger()
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
self.socketio = None
# FIXED: Initialize socketio as None but preserve ability to set it
if not hasattr(self, 'socketio'):
self.socketio = None
# Initialize missing attributes with defaults
if not hasattr(self, 'providers') or not self.providers:
self._initialize_providers()
@@ -182,11 +208,36 @@ class Scanner:
if not hasattr(self, 'currently_processing_display'):
self.currently_processing_display = []
if not hasattr(self, 'target_retries'):
self.target_retries = defaultdict(int)
if not hasattr(self, 'scan_failed_due_to_retries'):
self.scan_failed_due_to_retries = False
if not hasattr(self, 'initial_targets'):
self.initial_targets = set()
# Ensure providers have stop events
if hasattr(self, 'providers'):
for provider in self.providers:
if hasattr(provider, 'set_stop_event'):
if hasattr(provider, 'set_stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event)
def _ensure_threading_objects_exist(self):
"""
FIXED: Utility method to ensure threading objects exist before use.
Call this before any method that might use threading objects.
"""
if not hasattr(self, 'stop_event') or self.stop_event is None:
print("WARNING: Threading objects not initialized, recreating...")
self._initialize_threading_objects()
if not hasattr(self, 'processing_lock') or self.processing_lock is None:
self.processing_lock = threading.Lock()
if not hasattr(self, 'task_queue') or self.task_queue is None:
self.task_queue = PriorityQueue()
def _initialize_providers(self) -> None:
"""Initialize all available providers based on session configuration."""
self.providers = []
@@ -224,7 +275,9 @@ class Scanner:
print(f" Available: {is_available}")
if is_available:
provider.set_stop_event(self.stop_event)
# FIXED: Ensure stop_event exists before setting it
if hasattr(self, 'stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event)
if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph)
self.providers.append(provider)
@@ -254,15 +307,25 @@ class Scanner:
BOLD = "\033[1m"
last_status_str = ""
while not self.status_logger_stop_event.is_set():
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
while not (hasattr(self, 'status_logger_stop_event') and
self.status_logger_stop_event and
self.status_logger_stop_event.is_set()):
try:
with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
self.currently_processing_display = in_flight_tasks.copy()
# FIXED: Check if processing_lock exists before using
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
self.currently_processing_display = in_flight_tasks.copy()
else:
in_flight_tasks = list(getattr(self, 'currently_processing', []))
status_str = (
f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
f"{CYAN}Queued: {self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0}{ENDC} | "
f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
f"Skipped: {self.tasks_skipped} | "
@@ -290,22 +353,30 @@ class Scanner:
time.sleep(2)
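The loop above amounts to a daemon thread that reprints a single ANSI-colored summary line every couple of seconds, skipping repeats via last_status_str. A self-contained sketch of that shape (the color assignments are illustrative):

import threading
import time

def status_logger(stop: threading.Event, stats: dict) -> None:
    BOLD, CYAN, GREEN, ENDC = '\033[1m', '\033[96m', '\033[92m', '\033[0m'
    last_status_str = ''
    while not stop.is_set():
        status_str = (f"{BOLD}Scan Status: {stats['status'].upper()}{ENDC} | "
                      f"{CYAN}Queued: {stats['queued']}{ENDC} | "
                      f"{GREEN}Completed: {stats['completed']}{ENDC}")
        if status_str != last_status_str:  # only print when something changed
            print(status_str)
            last_status_str = status_str
        time.sleep(2)

stop = threading.Event()
thread = threading.Thread(target=status_logger, daemon=True,
                          args=(stop, {'status': 'running', 'queued': 4, 'completed': 12}))
thread.start()
time.sleep(0.1)
stop.set()
thread.join()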
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
"""
FIXED: Enhanced start_scan with proper threading object initialization and socketio management.
"""
# FIXED: Ensure threading objects exist before proceeding
self._ensure_threading_objects_exist()
if self.scan_thread and self.scan_thread.is_alive():
self.logger.logger.info("Stopping existing scan before starting new one")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# Clean up processing state
with self.processing_lock:
self.currently_processing.clear()
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
# Clear task queue
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
except:
break
if hasattr(self, 'task_queue') and self.task_queue:
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
except:
break
# Shutdown executor
if self.executor:
@@ -322,14 +393,26 @@ class Scanner:
self.logger.logger.warning("Previous scan thread did not terminate cleanly")
self.status = ScanStatus.IDLE
self.stop_event.clear()
# FIXED: Ensure stop_event exists before clearing
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.clear()
if self.session_id:
from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id)
# FIXED: Restore socketio connection if missing
if not hasattr(self, 'socketio') or not self.socketio:
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"✓ Restored socketio connection for scan start")
with self.processing_lock:
self.currently_processing.clear()
# FIXED: Safe cleanup with existence checks
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
@@ -397,7 +480,10 @@ class Scanner:
)
self.scan_thread.start()
self.status_logger_stop_event.clear()
# FIXED: Ensure status_logger_stop_event exists before clearing
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.clear()
self.status_logger_thread = threading.Thread(
target=self._status_logger_thread,
daemon=True,
@@ -451,6 +537,13 @@ class Scanner:
return 10 # Very low rate limit = very low priority
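Tasks enter the queue as (timestamp, priority, (provider_name, target, depth)) tuples, so PriorityQueue orders by enqueue time first and uses the provider priority from _get_priority as the tie-breaker — lower numbers dequeue first. A future timestamp would defer a task past a rate-limit window, though how re-enqueues are timed is not shown in this hunk. A small demonstration:

import time
from queue import PriorityQueue

task_queue = PriorityQueue()
now = time.time()

# Equal timestamps: the second tuple element (priority) breaks the tie
task_queue.put((now, 10, ('rate_limited_provider', 'example.com', 0)))
task_queue.put((now, 1, ('dns', 'example.com', 0)))
assert task_queue.get()[2][0] == 'dns'

# Deferral sketch: a task enqueued 30s in the future sorts behind everything current
task_queue.put((now + 30.0, 10, ('rate_limited_provider', 'example.com', 0)))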
def _execute_scan(self, target: str, max_depth: int) -> None:
"""
FIXED: Enhanced execute_scan with proper threading object handling.
"""
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
update_counter = 0 # Track updates for throttling
last_update_time = time.time()
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping
@@ -482,8 +575,13 @@ class Scanner:
print(f"\n=== PHASE 1: Running non-correlation providers ===")
while not self._is_stop_requested():
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
# FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
@@ -536,10 +634,23 @@ class Scanner:
continue
# Thread-safe processing state management
with self.processing_lock:
processing_key = (provider_name, target_item)
# FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested():
break
processing_key = (provider_name, target_item)
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
@@ -558,7 +669,12 @@ class Scanner:
if provider and not isinstance(provider, CorrelationProvider):
new_targets, _, success = self._process_provider_task(provider, target_item, depth)
update_counter += 1
current_time = time.time()
if (update_counter % 5 == 0) or (current_time - last_update_time > 3.0):
self._update_session_state()
last_update_time = current_time
update_counter = 0
if self._is_stop_requested():
break
@@ -603,9 +719,13 @@ class Scanner:
self.indicators_completed += 1
finally:
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
# FIXED: Safe processing lock usage for cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
# PHASE 2: Run correlations on all discovered nodes
if not self._is_stop_requested():
@@ -618,8 +738,9 @@ class Scanner:
self.logger.logger.error(f"Scan failed: {e}")
finally:
# Comprehensive cleanup (same as before)
with self.processing_lock:
self.currently_processing.clear()
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
while not self.task_queue.empty():
@@ -635,7 +756,9 @@ class Scanner:
else:
self.status = ScanStatus.COMPLETED
self.status_logger_stop_event.set()
# FIXED: Safe stop event handling
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0)
@@ -689,8 +812,13 @@ class Scanner:
while not self._is_stop_requested() and correlation_tasks:
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
# FIXED: Safe processing check
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
@@ -722,10 +850,23 @@ class Scanner:
correlation_tasks.remove(task_tuple)
continue
with self.processing_lock:
processing_key = (provider_name, target_item)
# FIXED: Safe processing management
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested():
break
processing_key = (provider_name, target_item)
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
@@ -754,51 +895,214 @@ class Scanner:
correlation_tasks.remove(task_tuple)
finally:
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
# FIXED: Safe cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
# Rest of the methods remain the same but with similar threading object safety checks...
# I'll include the key methods that need fixes:
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# FIXED: Safe cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
self._update_session_state()
return True
except Exception as e:
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive graph data for real-time updates."""
try:
# FIXED: Safe processing state access
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
else:
currently_processing_count = len(getattr(self, 'currently_processing', []))
currently_processing_list = list(getattr(self, 'currently_processing', []))
# FIXED: Always include complete graph data for real-time updates
graph_data = self.get_graph_data()
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph': graph_data, # FIXED: Always include complete graph data
'task_queue_size': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception as e:
traceback.print_exc()
return {
'status': 'error',
'message': 'Failed to get status',
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
}
def _update_session_state(self) -> None:
"""
FIXED: Update the scanner state in Redis and emit real-time WebSocket updates.
Enhanced with better error handling and socketio management.
"""
if self.session_id:
try:
# Get current status for WebSocket emission
current_status = self.get_scan_status()
# FIXED: Emit real-time update via WebSocket with better error handling
socketio_available = False
if hasattr(self, 'socketio') and self.socketio:
try:
print(f"📡 EMITTING WebSocket update: {current_status.get('status', 'unknown')} - "
f"Nodes: {len(current_status.get('graph', {}).get('nodes', []))}, "
f"Edges: {len(current_status.get('graph', {}).get('edges', []))}")
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted successfully")
socketio_available = True
except Exception as ws_error:
print(f"⚠️ WebSocket emission failed: {ws_error}")
# Try to get socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print("🔄 Attempting to use registered socketio connection...")
registered_socketio.emit('scan_update', current_status)
self.socketio = registered_socketio # Update our reference
print("✅ WebSocket update emitted via registered connection")
socketio_available = True
else:
print("⚠️ No registered socketio connection found")
except Exception as fallback_error:
print(f"⚠️ Fallback socketio emission also failed: {fallback_error}")
else:
# Try to restore socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print(f"🔄 Restoring socketio connection for session {self.session_id}")
self.socketio = registered_socketio
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted via restored connection")
socketio_available = True
else:
print(f"⚠️ No socketio connection available for session {self.session_id}")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio connection: {restore_error}")
if not socketio_available:
print(f"⚠️ Real-time updates unavailable for session {self.session_id}")
# Update session in Redis for persistence (always do this)
try:
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception as redis_error:
print(f"⚠️ Failed to update session in Redis: {redis_error}")
except Exception as e:
print(f"⚠️ Error updating session state: {e}")
import traceback
traceback.print_exc()
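Condensed, the emission logic above is: try the scanner's own socketio reference, and on failure or absence fall back to the connection registered with the session manager, caching it for next time. An illustrative helper (not project code) capturing that flow:

def emit_scan_update(scanner, payload: dict) -> bool:
    """Sketch: emit via the held connection, else via the session registry."""
    from core.session_manager import session_manager

    if getattr(scanner, 'socketio', None):
        try:
            scanner.socketio.emit('scan_update', payload)
            return True
        except Exception:
            pass  # stale or broken reference; fall through to the registry
    registered = session_manager.get_socketio_connection(scanner.session_id)
    if registered:
        registered.emit('scan_update', payload)
        scanner.socketio = registered  # cache for subsequent emits
        return True
    return False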
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
"""
Manages the entire process for a given target and provider.
This version is generalized to handle all relationships dynamically.
FIXED: Manages the entire process for a given target and provider with enhanced real-time updates.
"""
if self._is_stop_requested():
return set(), set(), False
is_ip = _is_valid_ip(target)
target_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, target_type)
self._initialize_provider_states(target)
new_targets = set()
provider_successful = True
try:
provider_result = self._execute_provider_query(provider, target, is_ip)
if provider_result is None:
provider_successful = False
elif not self._is_stop_requested():
# Pass all relationships to be processed
discovered, is_large_entity = self._process_provider_result_unified(
target, provider, provider_result, depth
)
new_targets.update(discovered)
# FIXED: Emit real-time update after processing provider result
if discovered or provider_result.get_relationship_count() > 0:
# Ensure we have socketio connection for real-time updates
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio connection during provider processing")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio during provider processing: {restore_error}")
self._update_session_state()
except Exception as e:
provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e))
return new_targets, set(), provider_successful
def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
"""
The "worker" function that directly communicates with the provider to fetch data.
"""
"""The "worker" function that directly communicates with the provider to fetch data."""
provider_name = provider.get_name()
start_time = datetime.now(timezone.utc)
@@ -825,9 +1129,7 @@ class Scanner:
def _create_large_entity_from_result(self, source_node: str, provider_name: str,
provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
"""
Creates a large entity node, tags all member nodes, and returns its ID and members.
"""
"""Creates a large entity node, tags all member nodes, and returns its ID and members."""
members = {rel.target_node for rel in provider_result.relationships
if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}
@@ -860,7 +1162,7 @@ class Scanner:
def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
"""
Removes a node from a large entity, allowing it to be processed normally.
FIXED: Removes a node from a large entity with immediate real-time update.
"""
if not self.graph.graph.has_node(node_id):
return False
@@ -879,7 +1181,6 @@ class Scanner:
for provider in eligible_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
# Use current depth of the large entity if available, else 0
depth = 0
if self.graph.graph.has_node(large_entity_id):
le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
@@ -890,6 +1191,19 @@ class Scanner:
self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
self.total_tasks_ever_enqueued += 1
# FIXED: Emit real-time update after extraction with socketio management
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for node extraction update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for extraction: {restore_error}")
self._update_session_state()
return True
return False
@@ -897,8 +1211,7 @@ class Scanner:
def _process_provider_result_unified(self, target: str, provider: BaseProvider,
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
"""
Process a unified ProviderResult object to update the graph.
This version dynamically re-routes edges to a large entity container.
FIXED: Process a unified ProviderResult object to update the graph with enhanced real-time updates.
"""
provider_name = provider.get_name()
discovered_targets = set()
@@ -918,6 +1231,10 @@ class Scanner:
target, provider_name, provider_result, current_depth
)
# Track if we added anything significant
nodes_added = 0
edges_added = 0
for i, relationship in enumerate(provider_result.relationships):
if i % 5 == 0 and self._is_stop_requested():
break
@@ -949,17 +1266,20 @@ class Scanner:
max_depth_reached = current_depth >= self.max_depth
# Add actual nodes to the graph (they might be hidden by the UI)
self.graph.add_node(source_node_id, source_type)
self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached})
if self.graph.add_node(source_node_id, source_type):
nodes_added += 1
if self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}):
nodes_added += 1
# Add the visual edge to the graph
self.graph.add_edge(
if self.graph.add_edge(
visual_source, visual_target,
relationship.relationship_type,
relationship.confidence,
provider_name,
relationship.raw_data
)
):
edges_added += 1
if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
if target_node_id not in large_entity_members:
@@ -987,88 +1307,32 @@ class Scanner:
if not self.graph.graph.has_node(node_id):
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
nodes_added += 1
else:
existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list
return discovered_targets, is_large_entity
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# FIXED: Emit real-time update if we added anything significant
if nodes_added > 0 or edges_added > 0:
print(f"🔄 Added {nodes_added} nodes, {edges_added} edges - triggering real-time update")
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
# Ensure we have socketio connection for immediate update
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for immediate update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for immediate update: {restore_error}")
self._update_session_state()
return True
except Exception as e:
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def _update_session_state(self) -> None:
"""
Update the scanner state in Redis for GUI updates.
"""
if self.session_id:
try:
if self.socketio:
self.socketio.emit('scan_update', self.get_scan_status())
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception:
pass
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive processing information."""
try:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph': self.get_graph_data(),
'task_queue_size': self.task_queue.qsize(),
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize(),
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception:
traceback.print_exc()
return { 'status': 'error', 'message': 'Failed to get status' }
return discovered_targets, is_large_entity
def _initialize_provider_states(self, target: str) -> None:
"""
FIXED: Safer provider state initialization with error handling.
"""
"""FIXED: Safer provider state initialization with error handling."""
try:
if not self.graph.graph.has_node(target):
return
@@ -1081,11 +1345,8 @@ class Scanner:
except Exception as e:
self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")
def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
"""
FIXED: Improved provider eligibility checking with better filtering.
"""
"""FIXED: Improved provider eligibility checking with better filtering."""
if dns_only:
return [p for p in self.providers if p.get_name() == 'dns']
@@ -1124,9 +1385,7 @@ class Scanner:
return eligible
def _already_queried_provider(self, target: str, provider_name: str) -> bool:
"""
FIXED: More robust check for already queried providers with proper error handling.
"""
"""FIXED: More robust check for already queried providers with proper error handling."""
try:
if not self.graph.graph.has_node(target):
return False
@@ -1145,9 +1404,7 @@ class Scanner:
def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: Optional[str], start_time: datetime) -> None:
"""
FIXED: More robust provider state updates with validation.
"""
"""FIXED: More robust provider state updates with validation."""
try:
if not self.graph.graph.has_node(target):
self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
@@ -1174,7 +1431,8 @@ class Scanner:
}
# Update last modified time for forensic integrity
self.last_modified = datetime.now(timezone.utc).isoformat()
if hasattr(self, 'last_modified'):
self.last_modified = datetime.now(timezone.utc).isoformat()
except Exception as e:
self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
@@ -1191,9 +1449,14 @@ class Scanner:
return 0.0
# Add small buffer for tasks still in queue to avoid showing 100% too early
queue_size = max(0, self.task_queue.qsize())
with self.processing_lock:
active_tasks = len(self.currently_processing)
queue_size = max(0, self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0)
# FIXED: Safe processing count
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
active_tasks = len(self.currently_processing)
else:
active_tasks = len(getattr(self, 'currently_processing', []))
# Adjust total to account for remaining work
adjusted_total = max(self.total_tasks_ever_enqueued,
@@ -1210,12 +1473,13 @@ class Scanner:
return 0.0
def get_graph_data(self) -> Dict[str, Any]:
"""Get current graph data formatted for frontend visualization."""
graph_data = self.graph.get_graph_data()
graph_data['initial_targets'] = list(self.initial_targets)
return graph_data
def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
"""Get comprehensive information about all available providers."""
info = {}
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
for filename in os.listdir(provider_dir):

View File

@@ -6,6 +6,7 @@ import uuid
import redis
import pickle
from typing import Dict, Optional, Any
import copy
from core.scanner import Scanner
from config import config
@@ -13,7 +14,7 @@ from config import config
class SessionManager:
"""
FIXED: Manages multiple scanner instances for concurrent user sessions using Redis.
Now more conservative about session creation to preserve API keys and configuration.
Enhanced to properly maintain WebSocket connections throughout scan lifecycle.
"""
def __init__(self, session_timeout_minutes: int = 0):
@@ -30,6 +31,9 @@ class SessionManager:
# FIXED: Add a creation lock to prevent race conditions
self.creation_lock = threading.Lock()
# Track active socketio connections per session
self.active_socketio_connections = {}
# Start cleanup thread
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start()
@@ -40,7 +44,7 @@ class SessionManager:
"""Prepare SessionManager for pickling."""
state = self.__dict__.copy()
# Exclude unpickleable attributes - Redis client and threading objects
unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock']
unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections']
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
@@ -53,6 +57,7 @@ class SessionManager:
self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
self.lock = threading.Lock()
self.creation_lock = threading.Lock()
self.active_socketio_connections = {}
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start()
@@ -64,22 +69,70 @@ class SessionManager:
"""Generates the Redis key for a session's stop signal."""
return f"dnsrecon:stop:{session_id}"
def register_socketio_connection(self, session_id: str, socketio) -> None:
"""
FIXED: Register a socketio connection for a session.
This ensures the connection is maintained throughout the session lifecycle.
"""
with self.lock:
self.active_socketio_connections[session_id] = socketio
print(f"Registered socketio connection for session {session_id}")
def get_socketio_connection(self, session_id: str):
"""
FIXED: Get the active socketio connection for a session.
"""
with self.lock:
return self.active_socketio_connections.get(session_id)
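Live Socket.IO connections cannot be pickled into Redis alongside the scanner, which is why they live in this process-local registry instead. The intended lifecycle, inferred from this file (assuming `socketio` is the application's Flask-SocketIO instance):

# create_session(socketio=...) registers the connection up front
session_id = session_manager.create_session(socketio=socketio)

# any later request in this process can retrieve it for real-time emits
connection = session_manager.get_socketio_connection(session_id)
if connection:
    connection.emit('scan_update', {'status': 'running'})

# session termination and the periodic cleanup loop drop the entry again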
def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner:
"""
FIXED: Prepare scanner for storage by ensuring proper cleanup of unpicklable objects.
Now preserves socketio connection info for restoration.
"""
# Set the session ID on the scanner for cross-process stop signal management
scanner.session_id = session_id
# FIXED: Don't set socketio to None if we want to preserve real-time updates
# Instead, we'll restore it when loading the scanner
scanner.socketio = None
# Force cleanup of any threading objects that might cause issues
if hasattr(scanner, 'stop_event'):
scanner.stop_event = None
if hasattr(scanner, 'scan_thread'):
scanner.scan_thread = None
if hasattr(scanner, 'executor'):
scanner.executor = None
if hasattr(scanner, 'status_logger_thread'):
scanner.status_logger_thread = None
if hasattr(scanner, 'status_logger_stop_event'):
scanner.status_logger_stop_event = None
return scanner
def create_session(self, socketio=None) -> str:
"""
FIXED: Create a new user session with thread-safe creation to prevent duplicates.
FIXED: Create a new user session with enhanced WebSocket management.
"""
# FIXED: Use creation lock to prevent race conditions
with self.creation_lock:
session_id = str(uuid.uuid4())
print(f"=== CREATING SESSION {session_id} IN REDIS ===")
# FIXED: Register socketio connection first
if socketio:
self.register_socketio_connection(session_id, socketio)
try:
from core.session_config import create_session_config
session_config = create_session_config()
scanner_instance = Scanner(session_config=session_config, socketio=socketio)
# Set the session ID on the scanner for cross-process stop signal management
scanner_instance.session_id = session_id
# Create scanner WITHOUT socketio to avoid weakref issues
scanner_instance = Scanner(session_config=session_config, socketio=None)
# Prepare scanner for storage (removes problematic objects)
scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id)
session_data = {
'scanner': scanner_instance,
@@ -89,12 +142,24 @@ class SessionManager:
'status': 'active'
}
# Serialize the entire session data dictionary using pickle
serialized_data = pickle.dumps(session_data)
# Test serialization before storing to catch issues early
try:
test_serialization = pickle.dumps(session_data)
print(f"Session serialization test successful ({len(test_serialization)} bytes)")
except Exception as pickle_error:
print(f"PICKLE TEST FAILED: {pickle_error}")
# Try to identify the problematic object
for key, value in session_data.items():
try:
pickle.dumps(value)
print(f" {key}: OK")
except Exception as item_error:
print(f" {key}: FAILED - {item_error}")
raise pickle_error
# Store in Redis
session_key = self._get_session_key(session_id)
self.redis_client.setex(session_key, self.session_timeout, serialized_data)
self.redis_client.setex(session_key, self.session_timeout, test_serialization)
# Initialize stop signal as False
stop_key = self._get_stop_signal_key(session_id)
@@ -106,6 +171,8 @@ class SessionManager:
except Exception as e:
print(f"ERROR: Failed to create session {session_id}: {e}")
import traceback
traceback.print_exc()
raise
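End to end, the storage path is: null the live handles via _prepare_scanner_for_storage, pickle the whole session dict, and setex it into Redis under the session key with the configured TTL; loading reverses this, with Scanner.__setstate__ rebuilding the threading objects and the registry re-supplying socketio. A minimal sketch of the round trip — the session key format and timeout are assumptions; the stop-signal key format is the one shown above:

import pickle
import time
import redis

redis_client = redis.StrictRedis(db=0, decode_responses=False)
SESSION_TIMEOUT = 3600  # illustrative; the manager derives this from its constructor

def save_session(session_id: str, scanner) -> None:
    session_data = {'scanner': scanner, 'last_activity': time.time(), 'status': 'active'}
    blob = pickle.dumps(session_data)  # Scanner.__getstate__ strips unpicklables
    redis_client.setex(f'dnsrecon:session:{session_id}', SESSION_TIMEOUT, blob)

def load_session(session_id: str):
    blob = redis_client.get(f'dnsrecon:session:{session_id}')
    if blob is None:
        return None  # expired or never created
    session_data = pickle.loads(blob)  # Scanner.__setstate__ rebuilds threading objects
    session_data['scanner'].session_id = session_id
    return session_data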
def set_stop_signal(self, session_id: str) -> bool:
@@ -175,31 +242,63 @@ class SessionManager:
# Ensure the scanner has the correct session ID for stop signal checking
if 'scanner' in session_data and session_data['scanner']:
session_data['scanner'].session_id = session_id
# FIXED: Restore socketio connection from our registry
socketio_conn = self.get_socketio_connection(session_id)
if socketio_conn:
session_data['scanner'].socketio = socketio_conn
print(f"Restored socketio connection for session {session_id}")
else:
print(f"No socketio connection found for session {session_id}")
session_data['scanner'].socketio = None
return session_data
return None
except Exception as e:
print(f"ERROR: Failed to get session data for {session_id}: {e}")
import traceback
traceback.print_exc()
return None
def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
"""
Serializes and saves session data back to Redis with updated TTL.
FIXED: Now preserves socketio connection during storage.
Returns:
bool: True if save was successful
"""
try:
session_key = self._get_session_key(session_id)
serialized_data = pickle.dumps(session_data)
# Create a deep copy to avoid modifying the original scanner object
session_data_to_save = copy.deepcopy(session_data)
# Prepare scanner for storage if it exists
if 'scanner' in session_data_to_save and session_data_to_save['scanner']:
# FIXED: Preserve the original socketio connection before preparing for storage
original_socketio = session_data_to_save['scanner'].socketio
session_data_to_save['scanner'] = self._prepare_scanner_for_storage(
session_data_to_save['scanner'],
session_id
)
# FIXED: If we had a socketio connection, make sure it's registered
if original_socketio and session_id not in self.active_socketio_connections:
self.register_socketio_connection(session_id, original_socketio)
serialized_data = pickle.dumps(session_data_to_save)
result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
return result
except Exception as e:
print(f"ERROR: Failed to save session data for {session_id}: {e}")
import traceback
traceback.print_exc()
return False
def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
"""
Updates just the scanner object in a session with immediate persistence.
FIXED: Updates just the scanner object in a session with immediate persistence.
Now maintains socketio connection throughout the update process.
Returns:
bool: True if update was successful
@@ -207,21 +306,27 @@ class SessionManager:
try:
session_data = self._get_session_data(session_id)
if session_data:
# Ensure scanner has the session ID
scanner.session_id = session_id
# FIXED: Preserve socketio connection before preparing for storage
original_socketio = scanner.socketio
# Prepare scanner for storage
scanner = self._prepare_scanner_for_storage(scanner, session_id)
session_data['scanner'] = scanner
session_data['last_activity'] = time.time()
# FIXED: Restore socketio connection after preparation
if original_socketio:
self.register_socketio_connection(session_id, original_socketio)
session_data['scanner'].socketio = original_socketio
# Immediately save to Redis for GUI updates
success = self._save_session_data(session_id, session_data)
if success:
# Only log occasionally to reduce noise
if hasattr(self, '_last_update_log'):
if time.time() - self._last_update_log > 5: # Log every 5 seconds max
#print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
self._last_update_log = time.time()
else:
#print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
self._last_update_log = time.time()
else:
print(f"WARNING: Failed to save scanner state for session {session_id}")
@@ -231,6 +336,8 @@ class SessionManager:
return False
except Exception as e:
print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
import traceback
traceback.print_exc()
return False
def update_scanner_status(self, session_id: str, status: str) -> bool:
@@ -263,7 +370,7 @@ class SessionManager:
def get_session(self, session_id: str) -> Optional[Scanner]:
"""
Get scanner instance for a session from Redis with session ID management.
FIXED: Get scanner instance for a session from Redis with proper socketio restoration.
"""
if not session_id:
return None
@@ -281,6 +388,15 @@ class SessionManager:
if scanner:
# Ensure the scanner can check the Redis-based stop signal
scanner.session_id = session_id
# FIXED: Restore socketio connection from our registry
socketio_conn = self.get_socketio_connection(session_id)
if socketio_conn:
scanner.socketio = socketio_conn
print(f"✓ Restored socketio connection for session {session_id}")
else:
scanner.socketio = None
print(f"⚠️ No socketio connection found for session {session_id}")
return scanner
@@ -333,6 +449,12 @@ class SessionManager:
# Wait a moment for graceful shutdown
time.sleep(0.5)
# FIXED: Clean up socketio connection
with self.lock:
if session_id in self.active_socketio_connections:
del self.active_socketio_connections[session_id]
print(f"Cleaned up socketio connection for session {session_id}")
# Delete session data and stop signal from Redis
session_key = self._get_session_key(session_id)
stop_key = self._get_stop_signal_key(session_id)
@@ -344,6 +466,8 @@ class SessionManager:
except Exception as e:
print(f"ERROR: Failed to terminate session {session_id}: {e}")
import traceback
traceback.print_exc()
return False
def _cleanup_loop(self) -> None:
@@ -364,6 +488,12 @@ class SessionManager:
self.redis_client.delete(stop_key)
print(f"Cleaned up orphaned stop signal for session {session_id}")
# Also clean up socketio connection
with self.lock:
if session_id in self.active_socketio_connections:
del self.active_socketio_connections[session_id]
print(f"Cleaned up orphaned socketio for session {session_id}")
except Exception as e:
print(f"Error in cleanup loop: {e}")
@@ -387,14 +517,16 @@ class SessionManager:
return {
'total_active_sessions': active_sessions,
'running_scans': running_scans,
'total_stop_signals': len(stop_keys)
'total_stop_signals': len(stop_keys),
'active_socketio_connections': len(self.active_socketio_connections)
}
except Exception as e:
print(f"ERROR: Failed to get statistics: {e}")
return {
'total_active_sessions': 0,
'running_scans': 0,
'total_stop_signals': 0
'total_stop_signals': 0,
'active_socketio_connections': 0
}
# Global session manager instance