it
This commit is contained in:
326
core/scanner.py
326
core/scanner.py
@@ -2,8 +2,9 @@
|
||||
|
||||
import threading
|
||||
import traceback
|
||||
import time
|
||||
from typing import List, Set, Dict, Any, Tuple
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -27,7 +28,7 @@ class ScanStatus:
|
||||
class Scanner:
|
||||
"""
|
||||
Main scanning orchestrator for DNSRecon passive reconnaissance.
|
||||
REFACTORED: Simplified recursion with forensic provider state tracking.
|
||||
Enhanced with reliable cross-process termination capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self, session_config=None):
|
||||
@@ -49,6 +50,7 @@ class Scanner:
|
||||
self.max_depth = 2
|
||||
self.stop_event = threading.Event()
|
||||
self.scan_thread = None
|
||||
self.session_id = None # Will be set by session manager
|
||||
|
||||
# Scanning progress tracking
|
||||
self.total_indicators_found = 0
|
||||
@@ -82,6 +84,42 @@ class Scanner:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def _is_stop_requested(self) -> bool:
|
||||
"""
|
||||
Check if stop is requested using both local and Redis-based signals.
|
||||
This ensures reliable termination across process boundaries.
|
||||
"""
|
||||
# Check local threading event first (fastest)
|
||||
if self.stop_event.is_set():
|
||||
return True
|
||||
|
||||
# Check Redis-based stop signal if session ID is available
|
||||
if self.session_id:
|
||||
try:
|
||||
from core.session_manager import session_manager
|
||||
return session_manager.is_stop_requested(self.session_id)
|
||||
except Exception as e:
|
||||
print(f"Error checking Redis stop signal: {e}")
|
||||
# Fall back to local event
|
||||
return self.stop_event.is_set()
|
||||
|
||||
return False
|
||||
|
||||
def _set_stop_signal(self) -> None:
|
||||
"""
|
||||
Set stop signal both locally and in Redis.
|
||||
"""
|
||||
# Set local event
|
||||
self.stop_event.set()
|
||||
|
||||
# Set Redis signal if session ID is available
|
||||
if self.session_id:
|
||||
try:
|
||||
from core.session_manager import session_manager
|
||||
session_manager.set_stop_signal(self.session_id)
|
||||
except Exception as e:
|
||||
print(f"Error setting Redis stop signal: {e}")
|
||||
|
||||
def __getstate__(self):
|
||||
"""Prepare object for pickling by excluding unpicklable attributes."""
|
||||
state = self.__dict__.copy()
|
||||
@@ -159,8 +197,9 @@ class Scanner:
|
||||
print("Session configuration updated")
|
||||
|
||||
def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool:
|
||||
"""Start a new reconnaissance scan with forensic tracking."""
|
||||
"""Start a new reconnaissance scan with immediate GUI feedback."""
|
||||
print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
|
||||
print(f"Session ID: {self.session_id}")
|
||||
print(f"Initial scanner status: {self.status}")
|
||||
|
||||
# Clean up previous scan thread if needed
|
||||
@@ -172,11 +211,13 @@ class Scanner:
|
||||
if self.scan_thread.is_alive():
|
||||
print("ERROR: The previous scan thread is unresponsive and could not be stopped.")
|
||||
self.status = ScanStatus.FAILED
|
||||
self._update_session_state()
|
||||
return False
|
||||
print("Previous scan thread terminated successfully.")
|
||||
|
||||
# Reset state for new scan
|
||||
self.status = ScanStatus.IDLE
|
||||
self._update_session_state() # Update GUI immediately
|
||||
print("Scanner state is now clean for a new scan.")
|
||||
|
||||
try:
|
||||
@@ -191,11 +232,20 @@ class Scanner:
|
||||
self.current_target = target_domain.lower().strip()
|
||||
self.max_depth = max_depth
|
||||
self.current_depth = 0
|
||||
|
||||
# Clear both local and Redis stop signals
|
||||
self.stop_event.clear()
|
||||
if self.session_id:
|
||||
from core.session_manager import session_manager
|
||||
session_manager.clear_stop_signal(self.session_id)
|
||||
|
||||
self.total_indicators_found = 0
|
||||
self.indicators_processed = 0
|
||||
self.current_indicator = self.current_target
|
||||
|
||||
# Update GUI with scan preparation
|
||||
self._update_session_state()
|
||||
|
||||
# Start new forensic session
|
||||
print(f"Starting new forensic session for scanner {id(self)}...")
|
||||
self.logger = new_session()
|
||||
@@ -216,16 +266,20 @@ class Scanner:
|
||||
print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
|
||||
traceback.print_exc()
|
||||
self.status = ScanStatus.FAILED
|
||||
self._update_session_state() # Update failed status immediately
|
||||
return False
|
||||
|
||||
def _execute_scan(self, target_domain: str, max_depth: int) -> None:
|
||||
"""Execute the reconnaissance scan with simplified recursion and forensic tracking."""
|
||||
"""Execute the reconnaissance scan with frequent state updates for GUI."""
|
||||
print(f"_execute_scan started for {target_domain} with depth {max_depth}")
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||
processed_targets = set()
|
||||
|
||||
try:
|
||||
self.status = ScanStatus.RUNNING
|
||||
# Immediate status update for GUI
|
||||
self._update_session_state()
|
||||
|
||||
enabled_providers = [provider.get_name() for provider in self.providers]
|
||||
self.logger.log_scan_start(target_domain, max_depth, enabled_providers)
|
||||
self.graph.add_node(target_domain, NodeType.DOMAIN)
|
||||
@@ -235,11 +289,13 @@ class Scanner:
|
||||
all_discovered_targets = {target_domain}
|
||||
|
||||
for depth in range(max_depth + 1):
|
||||
if self.stop_event.is_set():
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested at depth {depth}")
|
||||
break
|
||||
|
||||
self.current_depth = depth
|
||||
self._update_session_state()
|
||||
|
||||
targets_to_process = current_level_targets - processed_targets
|
||||
if not targets_to_process:
|
||||
print("No new targets to process at this level.")
|
||||
@@ -247,8 +303,9 @@ class Scanner:
|
||||
|
||||
print(f"Processing depth level {depth} with {len(targets_to_process)} new targets")
|
||||
self.total_indicators_found += len(targets_to_process)
|
||||
self._update_session_state()
|
||||
|
||||
target_results = self._process_targets_concurrent_forensic(
|
||||
target_results = self._process_targets_sequential_with_stop_checks(
|
||||
targets_to_process, processed_targets, all_discovered_targets, depth
|
||||
)
|
||||
processed_targets.update(targets_to_process)
|
||||
@@ -256,31 +313,57 @@ class Scanner:
|
||||
next_level_targets = set()
|
||||
for _target, new_targets in target_results:
|
||||
all_discovered_targets.update(new_targets)
|
||||
# This is the critical change: add all new targets to the next level
|
||||
next_level_targets.update(new_targets)
|
||||
|
||||
current_level_targets = next_level_targets
|
||||
# Filter out already processed targets before the next iteration
|
||||
current_level_targets = next_level_targets - processed_targets
|
||||
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested after processing depth {depth}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Scan execution failed with error: {e}")
|
||||
traceback.print_exc()
|
||||
self.status = ScanStatus.FAILED
|
||||
self._update_session_state() # Update failed status immediately
|
||||
self.logger.logger.error(f"Scan failed: {e}")
|
||||
finally:
|
||||
if self.stop_event.is_set():
|
||||
if self._is_stop_requested():
|
||||
self.status = ScanStatus.STOPPED
|
||||
else:
|
||||
self.status = ScanStatus.COMPLETED
|
||||
|
||||
# Final status update for GUI
|
||||
self._update_session_state()
|
||||
|
||||
self.logger.log_scan_complete()
|
||||
self.executor.shutdown(wait=False, cancel_futures=True)
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False, cancel_futures=True)
|
||||
stats = self.graph.get_statistics()
|
||||
print("Final scan statistics:")
|
||||
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
|
||||
print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
|
||||
print(f" - Targets processed: {len(processed_targets)}")
|
||||
|
||||
def _update_session_state(self) -> None:
|
||||
"""
|
||||
Update the scanner state in Redis for GUI updates.
|
||||
This ensures the web interface sees real-time updates.
|
||||
"""
|
||||
if self.session_id:
|
||||
try:
|
||||
from core.session_manager import session_manager
|
||||
success = session_manager.update_session_scanner(self.session_id, self)
|
||||
if not success:
|
||||
print(f"WARNING: Failed to update session state for {self.session_id}")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update session state: {e}")
|
||||
|
||||
def _initialize_provider_states(self, target: str) -> None:
|
||||
"""Initialize provider states for forensic tracking."""
|
||||
if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node()
|
||||
if not self.graph.graph.has_node(target):
|
||||
return
|
||||
|
||||
node_data = self.graph.graph.nodes[target]
|
||||
@@ -292,7 +375,6 @@ class Scanner:
|
||||
def _should_recurse_on_target(self, target: str, processed_targets: Set[str], all_discovered: Set[str]) -> bool:
|
||||
"""
|
||||
Simplified recursion logic: only recurse on valid IPs and domains that haven't been processed.
|
||||
FORENSIC PRINCIPLE: Clear, simple rules for what gets recursed.
|
||||
"""
|
||||
# Don't recurse on already processed targets
|
||||
if target in processed_targets:
|
||||
@@ -318,51 +400,129 @@ class Scanner:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _process_targets_concurrent_forensic(self, targets: Set[str], processed_targets: Set[str],
|
||||
all_discovered: Set[str], current_depth: int) -> List[Tuple[str, Set[str]]]:
|
||||
"""Process multiple targets concurrently with forensic provider state tracking."""
|
||||
def _process_targets_sequential_with_stop_checks(self, targets: Set[str], processed_targets: Set[str],
|
||||
all_discovered: Set[str], current_depth: int) -> List[Tuple[str, Set[str]]]:
|
||||
"""
|
||||
Process targets with controlled concurrency for both responsiveness and proper completion.
|
||||
Balances termination responsiveness with avoiding race conditions.
|
||||
"""
|
||||
results = []
|
||||
targets_to_process = targets - processed_targets
|
||||
if not targets_to_process:
|
||||
return results
|
||||
|
||||
print(f"Processing {len(targets_to_process)} targets concurrently with forensic tracking")
|
||||
print(f"Processing {len(targets_to_process)} targets with controlled concurrency")
|
||||
|
||||
future_to_target = {
|
||||
self.executor.submit(self._query_providers_forensic, target, current_depth): target
|
||||
for target in targets_to_process
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_target):
|
||||
if self.stop_event.is_set():
|
||||
future.cancel()
|
||||
continue
|
||||
target = future_to_target[future]
|
||||
try:
|
||||
new_targets = future.result()
|
||||
results.append((target, new_targets))
|
||||
self.indicators_processed += 1
|
||||
print(f"Completed processing target: {target} (found {len(new_targets)} new targets)")
|
||||
except (Exception, CancelledError) as e:
|
||||
print(f"Error processing target {target}: {e}")
|
||||
self._log_target_processing_error(target, str(e))
|
||||
target_list = list(targets_to_process)
|
||||
active_futures: Dict[Future, str] = {}
|
||||
target_index = 0
|
||||
last_gui_update = time.time()
|
||||
|
||||
# Add this block to save the state to Redis
|
||||
from core.session_manager import session_manager
|
||||
if hasattr(self, 'user_session_id'):
|
||||
session_manager.update_session_scanner(self.user_session_id, self)
|
||||
while target_index < len(target_list) or active_futures:
|
||||
# Check stop signal before any new work
|
||||
if self._is_stop_requested():
|
||||
print("Stop requested - canceling active futures and exiting")
|
||||
for future in list(active_futures.keys()):
|
||||
future.cancel()
|
||||
break
|
||||
|
||||
# Submit new futures up to max_workers limit (controlled concurrency)
|
||||
while len(active_futures) < self.max_workers and target_index < len(target_list):
|
||||
if self._is_stop_requested():
|
||||
break
|
||||
|
||||
target = target_list[target_index]
|
||||
self.current_indicator = target
|
||||
print(f"Submitting target {target_index + 1}/{len(target_list)}: {target}")
|
||||
|
||||
future = self.executor.submit(self._query_providers_forensic, target, current_depth)
|
||||
active_futures[future] = target
|
||||
target_index += 1
|
||||
|
||||
# Update GUI periodically
|
||||
current_time = time.time()
|
||||
if target_index % 2 == 0 or (current_time - last_gui_update) > 2.0:
|
||||
self._update_session_state()
|
||||
last_gui_update = current_time
|
||||
|
||||
# Wait for at least one future to complete (but don't wait forever)
|
||||
if active_futures:
|
||||
try:
|
||||
# Wait for the first completion with reasonable timeout
|
||||
completed_future = next(as_completed(active_futures.keys(), timeout=15.0))
|
||||
|
||||
target = active_futures[completed_future]
|
||||
try:
|
||||
new_targets = completed_future.result()
|
||||
results.append((target, new_targets))
|
||||
self.indicators_processed += 1
|
||||
print(f"Completed processing target: {target} (found {len(new_targets)} new targets)")
|
||||
|
||||
# Update GUI after each completion
|
||||
current_time = time.time()
|
||||
if (current_time - last_gui_update) > 1.0:
|
||||
self._update_session_state()
|
||||
last_gui_update = current_time
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing target {target}: {e}")
|
||||
self._log_target_processing_error(target, str(e))
|
||||
|
||||
# Remove the completed future
|
||||
del active_futures[completed_future]
|
||||
|
||||
except StopIteration:
|
||||
# No futures completed within timeout - check stop signal and continue
|
||||
if self._is_stop_requested():
|
||||
print("Stop requested during timeout - canceling futures")
|
||||
for future in list(active_futures.keys()):
|
||||
future.cancel()
|
||||
break
|
||||
# Continue loop to wait for completions
|
||||
|
||||
except Exception as e:
|
||||
# as_completed timeout or other error
|
||||
if self._is_stop_requested():
|
||||
print("Stop requested during future waiting")
|
||||
for future in list(active_futures.keys()):
|
||||
future.cancel()
|
||||
break
|
||||
|
||||
# Check if any futures are actually done (in case of timeout exception)
|
||||
completed_futures = [f for f in active_futures.keys() if f.done()]
|
||||
for completed_future in completed_futures:
|
||||
target = active_futures[completed_future]
|
||||
try:
|
||||
new_targets = completed_future.result()
|
||||
results.append((target, new_targets))
|
||||
self.indicators_processed += 1
|
||||
print(f"Completed processing target: {target} (found {len(new_targets)} new targets)")
|
||||
except Exception as ex:
|
||||
print(f"Error processing target {target}: {ex}")
|
||||
self._log_target_processing_error(target, str(ex))
|
||||
|
||||
del active_futures[completed_future]
|
||||
|
||||
print(f"Completed processing all targets at depth {current_depth}")
|
||||
|
||||
# Final state update
|
||||
self._update_session_state()
|
||||
|
||||
return results
|
||||
|
||||
def _query_providers_forensic(self, target: str, current_depth: int) -> Set[str]:
|
||||
"""
|
||||
Query providers for a target with forensic state tracking and simplified recursion.
|
||||
REFACTORED: Simplified logic with complete forensic audit trail.
|
||||
Query providers for a target with enhanced stop signal checking.
|
||||
"""
|
||||
is_ip = _is_valid_ip(target)
|
||||
target_type = NodeType.IP if is_ip else NodeType.DOMAIN
|
||||
print(f"Querying providers for {target_type.value}: {target} at depth {current_depth}")
|
||||
|
||||
# Early stop check
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested before querying providers for {target}")
|
||||
return set()
|
||||
|
||||
# Initialize node and provider states
|
||||
self.graph.add_node(target, target_type)
|
||||
self._initialize_provider_states(target)
|
||||
@@ -377,34 +537,27 @@ class Scanner:
|
||||
self._log_no_eligible_providers(target, is_ip)
|
||||
return new_targets
|
||||
|
||||
# Query each eligible provider with forensic tracking
|
||||
with ThreadPoolExecutor(max_workers=len(eligible_providers)) as provider_executor:
|
||||
future_to_provider = {
|
||||
provider_executor.submit(self._query_single_provider_forensic, provider, target, is_ip, current_depth): provider
|
||||
for provider in eligible_providers
|
||||
}
|
||||
# Query each eligible provider sequentially with stop checks
|
||||
for provider in eligible_providers:
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested while querying providers for {target}")
|
||||
break
|
||||
|
||||
for future in as_completed(future_to_provider):
|
||||
if self.stop_event.is_set():
|
||||
future.cancel()
|
||||
continue
|
||||
|
||||
provider = future_to_provider[future]
|
||||
try:
|
||||
provider_results = future.result()
|
||||
if provider_results:
|
||||
discovered_targets = self._process_provider_results_forensic(
|
||||
target, provider, provider_results, target_metadata, current_depth
|
||||
)
|
||||
new_targets.update(discovered_targets)
|
||||
except (Exception, CancelledError) as e:
|
||||
self._log_provider_error(target, provider.get_name(), str(e))
|
||||
try:
|
||||
provider_results = self._query_single_provider_forensic(provider, target, is_ip, current_depth)
|
||||
if provider_results and not self._is_stop_requested():
|
||||
discovered_targets = self._process_provider_results_forensic(
|
||||
target, provider, provider_results, target_metadata, current_depth
|
||||
)
|
||||
new_targets.update(discovered_targets)
|
||||
except Exception as e:
|
||||
self._log_provider_error(target, provider.get_name(), str(e))
|
||||
|
||||
# Update node metadata
|
||||
for node_id, metadata_dict in target_metadata.items():
|
||||
if self.graph.graph.has_node(node_id):
|
||||
node_is_ip = _is_valid_ip(node_id)
|
||||
node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN
|
||||
# This call updates the existing node with the new metadata
|
||||
self.graph.add_node(node_id, node_type_to_add, metadata=metadata_dict)
|
||||
|
||||
return new_targets
|
||||
@@ -428,7 +581,7 @@ class Scanner:
|
||||
|
||||
def _already_queried_provider(self, target: str, provider_name: str) -> bool:
|
||||
"""Check if we already queried a provider for a target."""
|
||||
if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node()
|
||||
if not self.graph.graph.has_node(target):
|
||||
return False
|
||||
|
||||
node_data = self.graph.graph.nodes[target]
|
||||
@@ -436,10 +589,15 @@ class Scanner:
|
||||
return provider_name in provider_states
|
||||
|
||||
def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> List:
|
||||
"""Query a single provider with complete forensic logging."""
|
||||
"""Query a single provider with stop signal checking."""
|
||||
provider_name = provider.get_name()
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
# Check stop signal before querying
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested before querying {provider_name} for {target}")
|
||||
return []
|
||||
|
||||
print(f"Querying {provider_name} for {target}")
|
||||
|
||||
# Log attempt
|
||||
@@ -452,6 +610,11 @@ class Scanner:
|
||||
else:
|
||||
results = provider.query_domain(target)
|
||||
|
||||
# Check stop signal after querying
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested after querying {provider_name} for {target}")
|
||||
return []
|
||||
|
||||
# Track successful state
|
||||
self._update_provider_state(target, provider_name, 'success', len(results), None, start_time)
|
||||
|
||||
@@ -467,7 +630,7 @@ class Scanner:
|
||||
def _update_provider_state(self, target: str, provider_name: str, status: str,
|
||||
results_count: int, error: str, start_time: datetime) -> None:
|
||||
"""Update provider state in node metadata for forensic tracking."""
|
||||
if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node()
|
||||
if not self.graph.graph.has_node(target):
|
||||
return
|
||||
|
||||
node_data = self.graph.graph.nodes[target]
|
||||
@@ -489,10 +652,15 @@ class Scanner:
|
||||
|
||||
def _process_provider_results_forensic(self, target: str, provider, results: List,
|
||||
target_metadata: Dict, current_depth: int) -> Set[str]:
|
||||
"""Process provider results with large entity detection and forensic logging."""
|
||||
"""Process provider results with large entity detection and stop signal checking."""
|
||||
provider_name = provider.get_name()
|
||||
discovered_targets = set()
|
||||
|
||||
# Check for stop signal before processing results
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested before processing results from {provider_name} for {target}")
|
||||
return discovered_targets
|
||||
|
||||
# Check for large entity threshold per provider
|
||||
if len(results) > self.config.large_entity_threshold:
|
||||
print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}")
|
||||
@@ -503,8 +671,10 @@ class Scanner:
|
||||
# Process each relationship
|
||||
dns_records_to_create = {}
|
||||
|
||||
for source, rel_target, rel_type, confidence, raw_data in results:
|
||||
if self.stop_event.is_set():
|
||||
for i, (source, rel_target, rel_type, confidence, raw_data) in enumerate(results):
|
||||
# Check stop signal periodically during result processing
|
||||
if i % 10 == 0 and self._is_stop_requested():
|
||||
print(f"Stop requested while processing results from {provider_name} for {target}")
|
||||
break
|
||||
|
||||
# Enhanced forensic logging for each relationship
|
||||
@@ -539,7 +709,7 @@ class Scanner:
|
||||
print(f"Added domain relationship: {source} -> {rel_target} ({rel_type.relationship_name})")
|
||||
discovered_targets.add(rel_target)
|
||||
|
||||
# *** NEW: Enrich the newly discovered domain ***
|
||||
# Enrich the newly discovered domain
|
||||
self._collect_node_metadata_forensic(rel_target, provider_name, rel_type, source, raw_data, target_metadata[rel_target])
|
||||
|
||||
else:
|
||||
@@ -691,25 +861,24 @@ class Scanner:
|
||||
self.logger.logger.warning(f"No eligible providers for {target_type}: {target}")
|
||||
|
||||
def stop_scan(self) -> bool:
|
||||
"""Request immediate scan termination with forensic logging."""
|
||||
"""Request immediate scan termination with immediate GUI feedback."""
|
||||
try:
|
||||
if not self.scan_thread or not self.scan_thread.is_alive():
|
||||
print("No active scan thread to stop.")
|
||||
if self.status == ScanStatus.RUNNING:
|
||||
self.status = ScanStatus.STOPPED
|
||||
return False
|
||||
|
||||
print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
|
||||
self.logger.logger.info("Scan termination requested by user")
|
||||
|
||||
# Set both local and Redis stop signals
|
||||
self._set_stop_signal()
|
||||
self.status = ScanStatus.STOPPED
|
||||
self.stop_event.set()
|
||||
|
||||
# Immediately update GUI with stopped status
|
||||
self._update_session_state()
|
||||
|
||||
# Cancel executor futures if running
|
||||
if self.executor:
|
||||
print("Shutting down executor with immediate cancellation...")
|
||||
self.executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
print("Termination signal sent. The scan thread will stop shortly.")
|
||||
print("Termination signals sent. The scan will stop as soon as possible.")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -774,7 +943,8 @@ class Scanner:
|
||||
'final_status': self.status,
|
||||
'total_indicators_processed': self.indicators_processed,
|
||||
'enabled_providers': list(provider_stats.keys()),
|
||||
'forensic_note': 'Refactored scanner with simplified recursion and forensic tracking'
|
||||
'session_id': self.session_id,
|
||||
'forensic_note': 'Enhanced scanner with reliable cross-process termination'
|
||||
},
|
||||
'graph_data': graph_data,
|
||||
'forensic_audit': audit_trail,
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
import uuid
|
||||
import redis
|
||||
import pickle
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, List
|
||||
|
||||
from core.scanner import Scanner
|
||||
|
||||
@@ -16,7 +16,7 @@ from core.scanner import Scanner
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages multiple scanner instances for concurrent user sessions using Redis.
|
||||
This allows session state to be shared across multiple Gunicorn worker processes.
|
||||
Enhanced with reliable cross-process stop signal management and immediate state updates.
|
||||
"""
|
||||
|
||||
def __init__(self, session_timeout_minutes: int = 60):
|
||||
@@ -57,6 +57,10 @@ class SessionManager:
|
||||
"""Generates the Redis key for a session."""
|
||||
return f"dnsrecon:session:{session_id}"
|
||||
|
||||
def _get_stop_signal_key(self, session_id: str) -> str:
|
||||
"""Generates the Redis key for a session's stop signal."""
|
||||
return f"dnsrecon:stop:{session_id}"
|
||||
|
||||
def create_session(self) -> str:
|
||||
"""
|
||||
Create a new user session and store it in Redis.
|
||||
@@ -69,6 +73,9 @@ class SessionManager:
|
||||
session_config = create_session_config()
|
||||
scanner_instance = Scanner(session_config=session_config)
|
||||
|
||||
# Set the session ID on the scanner for cross-process stop signal management
|
||||
scanner_instance.session_id = session_id
|
||||
|
||||
session_data = {
|
||||
'scanner': scanner_instance,
|
||||
'config': session_config,
|
||||
@@ -84,38 +91,166 @@ class SessionManager:
|
||||
session_key = self._get_session_key(session_id)
|
||||
self.redis_client.setex(session_key, self.session_timeout, serialized_data)
|
||||
|
||||
print(f"Session {session_id} stored in Redis")
|
||||
# Initialize stop signal as False
|
||||
stop_key = self._get_stop_signal_key(session_id)
|
||||
self.redis_client.setex(stop_key, self.session_timeout, b'0')
|
||||
|
||||
print(f"Session {session_id} stored in Redis with stop signal initialized")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to create session {session_id}: {e}")
|
||||
raise
|
||||
|
||||
def set_stop_signal(self, session_id: str) -> bool:
|
||||
"""
|
||||
Set the stop signal for a session (cross-process safe).
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
bool: True if signal was set successfully
|
||||
"""
|
||||
try:
|
||||
stop_key = self._get_stop_signal_key(session_id)
|
||||
# Set stop signal to '1' with the same TTL as the session
|
||||
self.redis_client.setex(stop_key, self.session_timeout, b'1')
|
||||
print(f"Stop signal set for session {session_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def is_stop_requested(self, session_id: str) -> bool:
|
||||
"""
|
||||
Check if stop is requested for a session (cross-process safe).
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
bool: True if stop is requested
|
||||
"""
|
||||
try:
|
||||
stop_key = self._get_stop_signal_key(session_id)
|
||||
value = self.redis_client.get(stop_key)
|
||||
return value == b'1' if value is not None else False
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def clear_stop_signal(self, session_id: str) -> bool:
|
||||
"""
|
||||
Clear the stop signal for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
bool: True if signal was cleared successfully
|
||||
"""
|
||||
try:
|
||||
stop_key = self._get_stop_signal_key(session_id)
|
||||
self.redis_client.setex(stop_key, self.session_timeout, b'0')
|
||||
print(f"Stop signal cleared for session {session_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieves and deserializes session data from Redis."""
|
||||
session_key = self._get_session_key(session_id)
|
||||
serialized_data = self.redis_client.get(session_key)
|
||||
if serialized_data:
|
||||
return pickle.loads(serialized_data)
|
||||
return None
|
||||
try:
|
||||
session_key = self._get_session_key(session_id)
|
||||
serialized_data = self.redis_client.get(session_key)
|
||||
if serialized_data:
|
||||
session_data = pickle.loads(serialized_data)
|
||||
# Ensure the scanner has the correct session ID for stop signal checking
|
||||
if 'scanner' in session_data and session_data['scanner']:
|
||||
session_data['scanner'].session_id = session_id
|
||||
return session_data
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get session data for {session_id}: {e}")
|
||||
return None
|
||||
|
||||
def _save_session_data(self, session_id: str, session_data: Dict[str, Any]):
|
||||
"""Serializes and saves session data back to Redis with updated TTL."""
|
||||
session_key = self._get_session_key(session_id)
|
||||
serialized_data = pickle.dumps(session_data)
|
||||
self.redis_client.setex(session_key, self.session_timeout, serialized_data)
|
||||
def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Serializes and saves session data back to Redis with updated TTL.
|
||||
|
||||
Returns:
|
||||
bool: True if save was successful
|
||||
"""
|
||||
try:
|
||||
session_key = self._get_session_key(session_id)
|
||||
serialized_data = pickle.dumps(session_data)
|
||||
result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to save session data for {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def update_session_scanner(self, session_id: str, scanner: 'Scanner'):
|
||||
"""Updates just the scanner object in a session."""
|
||||
session_data = self._get_session_data(session_id)
|
||||
if session_data:
|
||||
session_data['scanner'] = scanner
|
||||
# We don't need to update last_activity here, as that's for user interaction
|
||||
self._save_session_data(session_id, session_data)
|
||||
def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
|
||||
"""
|
||||
Updates just the scanner object in a session with immediate persistence.
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful
|
||||
"""
|
||||
try:
|
||||
session_data = self._get_session_data(session_id)
|
||||
if session_data:
|
||||
# Ensure scanner has the session ID
|
||||
scanner.session_id = session_id
|
||||
session_data['scanner'] = scanner
|
||||
session_data['last_activity'] = time.time()
|
||||
|
||||
# Immediately save to Redis for GUI updates
|
||||
success = self._save_session_data(session_id, session_data)
|
||||
if success:
|
||||
print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
|
||||
else:
|
||||
print(f"WARNING: Failed to save scanner state for session {session_id}")
|
||||
return success
|
||||
else:
|
||||
print(f"WARNING: Session {session_id} not found for scanner update")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def update_scanner_status(self, session_id: str, status: str) -> bool:
|
||||
"""
|
||||
Quickly update just the scanner status for immediate GUI feedback.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
status: New scanner status
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful
|
||||
"""
|
||||
try:
|
||||
session_data = self._get_session_data(session_id)
|
||||
if session_data and 'scanner' in session_data:
|
||||
session_data['scanner'].status = status
|
||||
session_data['last_activity'] = time.time()
|
||||
|
||||
success = self._save_session_data(session_id, session_data)
|
||||
if success:
|
||||
print(f"Scanner status updated to '{status}' for session {session_id}")
|
||||
else:
|
||||
print(f"WARNING: Failed to save status update for session {session_id}")
|
||||
return success
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Scanner]:
|
||||
"""
|
||||
Get scanner instance for a session from Redis.
|
||||
Get scanner instance for a session from Redis with enhanced session ID management.
|
||||
"""
|
||||
if not session_id:
|
||||
return None
|
||||
@@ -129,37 +264,151 @@ class SessionManager:
|
||||
session_data['last_activity'] = time.time()
|
||||
self._save_session_data(session_id, session_data)
|
||||
|
||||
return session_data.get('scanner')
|
||||
scanner = session_data.get('scanner')
|
||||
if scanner:
|
||||
# Ensure the scanner can check the Redis-based stop signal
|
||||
scanner.session_id = session_id
|
||||
print(f"Retrieved scanner for session {session_id} (status: {scanner.status})")
|
||||
|
||||
return scanner
|
||||
|
||||
def get_session_status_only(self, session_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get just the scanner status without full session retrieval (for performance).
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Scanner status string or None if not found
|
||||
"""
|
||||
try:
|
||||
session_data = self._get_session_data(session_id)
|
||||
if session_data and 'scanner' in session_data:
|
||||
return session_data['scanner'].status
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get session status for {session_id}: {e}")
|
||||
return None
|
||||
|
||||
def terminate_session(self, session_id: str) -> bool:
|
||||
"""
|
||||
Terminate a specific session in Redis.
|
||||
Terminate a specific session in Redis with reliable stop signal and immediate status update.
|
||||
"""
|
||||
session_data = self._get_session_data(session_id)
|
||||
if not session_data:
|
||||
return False
|
||||
print(f"=== TERMINATING SESSION {session_id} ===")
|
||||
|
||||
try:
|
||||
# First, set the stop signal
|
||||
self.set_stop_signal(session_id)
|
||||
|
||||
# Update scanner status to stopped immediately for GUI feedback
|
||||
self.update_scanner_status(session_id, 'stopped')
|
||||
|
||||
session_data = self._get_session_data(session_id)
|
||||
if not session_data:
|
||||
print(f"Session {session_id} not found")
|
||||
return False
|
||||
|
||||
scanner = session_data.get('scanner')
|
||||
if scanner and scanner.status == 'running':
|
||||
scanner.stop_scan()
|
||||
print(f"Stopped scan for session: {session_id}")
|
||||
|
||||
# Delete from Redis
|
||||
session_key = self._get_session_key(session_id)
|
||||
self.redis_client.delete(session_key)
|
||||
|
||||
print(f"Terminated and removed session from Redis: {session_id}")
|
||||
return True
|
||||
scanner = session_data.get('scanner')
|
||||
if scanner and scanner.status == 'running':
|
||||
print(f"Stopping scan for session: {session_id}")
|
||||
# The scanner will check the Redis stop signal
|
||||
scanner.stop_scan()
|
||||
|
||||
# Update the scanner state immediately
|
||||
self.update_session_scanner(session_id, scanner)
|
||||
|
||||
# Wait a moment for graceful shutdown
|
||||
time.sleep(0.5)
|
||||
|
||||
# Delete session data and stop signal from Redis
|
||||
session_key = self._get_session_key(session_id)
|
||||
stop_key = self._get_stop_signal_key(session_id)
|
||||
self.redis_client.delete(session_key)
|
||||
self.redis_client.delete(stop_key)
|
||||
|
||||
print(f"Terminated and removed session from Redis: {session_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to terminate session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def _cleanup_loop(self) -> None:
|
||||
"""
|
||||
Background thread to cleanup inactive sessions.
|
||||
Redis's TTL (setex) handles most of this automatically. This loop is a failsafe.
|
||||
Background thread to cleanup inactive sessions and orphaned stop signals.
|
||||
"""
|
||||
while True:
|
||||
# Redis handles expiration automatically, so this loop can be simplified or removed
|
||||
# For now, we'll keep it as a failsafe check for non-expiring keys if any get created by mistake
|
||||
time.sleep(300) # Sleep for 5 minutes
|
||||
try:
|
||||
# Clean up orphaned stop signals
|
||||
stop_keys = self.redis_client.keys("dnsrecon:stop:*")
|
||||
for stop_key in stop_keys:
|
||||
# Extract session ID from stop key
|
||||
session_id = stop_key.decode('utf-8').split(':')[-1]
|
||||
session_key = self._get_session_key(session_id)
|
||||
|
||||
# If session doesn't exist but stop signal does, clean it up
|
||||
if not self.redis_client.exists(session_key):
|
||||
self.redis_client.delete(stop_key)
|
||||
print(f"Cleaned up orphaned stop signal for session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in cleanup loop: {e}")
|
||||
|
||||
time.sleep(300) # Sleep for 5 minutes
|
||||
|
||||
def list_active_sessions(self) -> List[Dict[str, Any]]:
|
||||
"""List all active sessions for admin purposes."""
|
||||
try:
|
||||
session_keys = self.redis_client.keys("dnsrecon:session:*")
|
||||
sessions = []
|
||||
|
||||
for session_key in session_keys:
|
||||
session_id = session_key.decode('utf-8').split(':')[-1]
|
||||
session_data = self._get_session_data(session_id)
|
||||
|
||||
if session_data:
|
||||
scanner = session_data.get('scanner')
|
||||
sessions.append({
|
||||
'session_id': session_id,
|
||||
'created_at': session_data.get('created_at'),
|
||||
'last_activity': session_data.get('last_activity'),
|
||||
'scanner_status': scanner.status if scanner else 'unknown',
|
||||
'current_target': scanner.current_target if scanner else None
|
||||
})
|
||||
|
||||
return sessions
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to list active sessions: {e}")
|
||||
return []
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get session manager statistics."""
|
||||
try:
|
||||
session_keys = self.redis_client.keys("dnsrecon:session:*")
|
||||
stop_keys = self.redis_client.keys("dnsrecon:stop:*")
|
||||
|
||||
active_sessions = len(session_keys)
|
||||
running_scans = 0
|
||||
|
||||
for session_key in session_keys:
|
||||
session_id = session_key.decode('utf-8').split(':')[-1]
|
||||
status = self.get_session_status_only(session_id)
|
||||
if status == 'running':
|
||||
running_scans += 1
|
||||
|
||||
return {
|
||||
'total_active_sessions': active_sessions,
|
||||
'running_scans': running_scans,
|
||||
'total_stop_signals': len(stop_keys)
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get statistics: {e}")
|
||||
return {
|
||||
'total_active_sessions': 0,
|
||||
'running_scans': 0,
|
||||
'total_stop_signals': 0
|
||||
}
|
||||
|
||||
# Global session manager instance
|
||||
session_manager = SessionManager(session_timeout_minutes=60)
|
||||
Reference in New Issue
Block a user