iteration on ws implementation
core/graph_manager.py
@@ -4,8 +4,7 @@
 Graph data model for DNSRecon using NetworkX.
 Manages in-memory graph storage with confidence scoring and forensic metadata.
 Now fully compatible with the unified ProviderResult data model.
-UPDATED: Fixed correlation exclusion keys to match actual attribute names.
-UPDATED: Removed export_json() method - now handled by ExportManager.
+FIXED: Added proper pickle support to prevent weakref serialization errors.
 """
 import re
 from datetime import datetime, timezone
@@ -33,6 +32,7 @@ class GraphManager:
     Thread-safe graph manager for DNSRecon infrastructure mapping.
     Uses NetworkX for in-memory graph storage with confidence scoring.
     Compatible with unified ProviderResult data model.
+    FIXED: Added proper pickle support to handle NetworkX graph serialization.
     """

     def __init__(self):
@@ -40,6 +40,57 @@ class GraphManager:
         self.graph = nx.DiGraph()
         self.creation_time = datetime.now(timezone.utc).isoformat()
         self.last_modified = self.creation_time

+    def __getstate__(self):
+        """Prepare GraphManager for pickling by converting NetworkX graph to serializable format."""
+        state = self.__dict__.copy()
+
+        # Convert NetworkX graph to a serializable format
+        if hasattr(self, 'graph') and self.graph:
+            # Extract all nodes with their data
+            nodes_data = {}
+            for node_id, attrs in self.graph.nodes(data=True):
+                nodes_data[node_id] = dict(attrs)
+
+            # Extract all edges with their data
+            edges_data = []
+            for source, target, attrs in self.graph.edges(data=True):
+                edges_data.append({
+                    'source': source,
+                    'target': target,
+                    'attributes': dict(attrs)
+                })
+
+            # Replace the NetworkX graph with serializable data
+            state['_graph_nodes'] = nodes_data
+            state['_graph_edges'] = edges_data
+            del state['graph']
+
+        return state
+
+    def __setstate__(self, state):
+        """Restore GraphManager after unpickling by reconstructing NetworkX graph."""
+        # Restore basic attributes
+        self.__dict__.update(state)
+
+        # Reconstruct NetworkX graph from serializable data
+        self.graph = nx.DiGraph()
+
+        # Restore nodes
+        if hasattr(self, '_graph_nodes'):
+            for node_id, attrs in self._graph_nodes.items():
+                self.graph.add_node(node_id, **attrs)
+            del self._graph_nodes
+
+        # Restore edges
+        if hasattr(self, '_graph_edges'):
+            for edge_data in self._graph_edges:
+                self.graph.add_edge(
+                    edge_data['source'],
+                    edge_data['target'],
+                    **edge_data['attributes']
+                )
+            del self._graph_edges
+
     def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None,
                  description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:
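Note: a bare nx.DiGraph usually pickles on its own; the explicit snapshot above guarantees that only plain dicts and lists reach the pickle stream even if unpicklable objects later creep into the graph instance. A self-contained sketch of the same round-trip, using only networkx and pickle (the class name here is illustrative, not part of the codebase):

```python
import pickle
import networkx as nx

class PicklableGraph:
    """Sketch of the __getstate__/__setstate__ pattern used above."""

    def __init__(self):
        self.graph = nx.DiGraph()

    def __getstate__(self):
        state = self.__dict__.copy()
        # Snapshot nodes/edges as plain dicts and lists, then drop the graph
        state['_graph_nodes'] = {n: dict(a) for n, a in self.graph.nodes(data=True)}
        state['_graph_edges'] = [
            {'source': s, 'target': t, 'attributes': dict(a)}
            for s, t, a in self.graph.edges(data=True)
        ]
        del state['graph']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Rebuild the graph from the plain snapshot
        self.graph = nx.DiGraph()
        for node_id, attrs in self._graph_nodes.items():
            self.graph.add_node(node_id, **attrs)
        for e in self._graph_edges:
            self.graph.add_edge(e['source'], e['target'], **e['attributes'])
        del self._graph_nodes, self._graph_edges

g = PicklableGraph()
g.graph.add_edge('example.com', '1.2.3.4', relationship_type='dns_a_record')
restored = pickle.loads(pickle.dumps(g))
assert restored.graph.has_edge('example.com', '1.2.3.4')
```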
core/logger.py
@@ -40,6 +40,7 @@ class ForensicLogger:
     """
     Thread-safe forensic logging system for DNSRecon.
     Maintains detailed audit trail of all reconnaissance activities.
+    FIXED: Enhanced pickle support to prevent weakref issues in logging handlers.
     """

     def __init__(self, session_id: str = ""):
@@ -65,45 +66,74 @@ class ForensicLogger:
             'target_domains': set()
         }

-        # Configure standard logger
+        # Configure standard logger with simple setup to avoid weakrefs
         self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
         self.logger.setLevel(logging.INFO)

-        # Create formatter for structured logging
+        # Create minimal formatter
         formatter = logging.Formatter(
             '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
         )

-        # Add console handler if not already present
+        # Add console handler only if not already present (avoid duplicate handlers)
         if not self.logger.handlers:
             console_handler = logging.StreamHandler()
             console_handler.setFormatter(formatter)
             self.logger.addHandler(console_handler)

     def __getstate__(self):
-        """Prepare ForensicLogger for pickling by excluding unpicklable objects."""
+        """
+        FIXED: Prepare ForensicLogger for pickling by excluding problematic objects.
+        """
         state = self.__dict__.copy()
-        # Remove the unpickleable 'logger' attribute
-        if 'logger' in state:
-            del state['logger']
-        if 'lock' in state:
-            del state['lock']
+
+        # Remove potentially unpickleable attributes that may contain weakrefs
+        unpicklable_attrs = ['logger', 'lock']
+        for attr in unpicklable_attrs:
+            if attr in state:
+                del state[attr]
+
+        # Convert sets to lists for JSON serialization compatibility
+        if 'session_metadata' in state:
+            metadata = state['session_metadata'].copy()
+            if 'providers_used' in metadata and isinstance(metadata['providers_used'], set):
+                metadata['providers_used'] = list(metadata['providers_used'])
+            if 'target_domains' in metadata and isinstance(metadata['target_domains'], set):
+                metadata['target_domains'] = list(metadata['target_domains'])
+            state['session_metadata'] = metadata

         return state

     def __setstate__(self, state):
-        """Restore ForensicLogger after unpickling by reconstructing logger."""
+        """
+        FIXED: Restore ForensicLogger after unpickling by reconstructing components.
+        """
         self.__dict__.update(state)
-        # Re-initialize the 'logger' attribute
+
+        # Re-initialize threading lock
+        self.lock = threading.Lock()
+
+        # Re-initialize logger with minimal setup
         self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
         self.logger.setLevel(logging.INFO)

         formatter = logging.Formatter(
             '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
         )

         # Only add handler if not already present
         if not self.logger.handlers:
             console_handler = logging.StreamHandler()
             console_handler.setFormatter(formatter)
             self.logger.addHandler(console_handler)
-        self.lock = threading.Lock()
+
+        # Convert lists back to sets if needed
+        if 'session_metadata' in self.__dict__:
+            metadata = self.session_metadata
+            if 'providers_used' in metadata and isinstance(metadata['providers_used'], list):
+                metadata['providers_used'] = set(metadata['providers_used'])
+            if 'target_domains' in metadata and isinstance(metadata['target_domains'], list):
+                metadata['target_domains'] = set(metadata['target_domains'])

     def _generate_session_id(self) -> str:
         """Generate unique session identifier."""
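For context: logging handlers hold stream and lock objects, and `threading.Lock` cannot be pickled at all (`pickle.dumps(threading.Lock())` raises `TypeError: cannot pickle '_thread.lock' object`), which is why both attributes are dropped in `__getstate__` and rebuilt in `__setstate__`. A minimal sketch of the same drop-and-rebuild pattern, with illustrative names:

```python
import pickle
import threading

class Holder:
    def __init__(self):
        self.lock = threading.Lock()  # unpicklable: dropped before pickling
        self.data = {'providers_used': {'dns', 'crtsh'}}

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('lock', None)  # drop the lock
        # mirror the set -> list conversion done above
        state['data'] = dict(self.data,
                             providers_used=list(self.data['providers_used']))
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.lock = threading.Lock()  # rebuild a fresh lock
        self.data['providers_used'] = set(self.data['providers_used'])

h = pickle.loads(pickle.dumps(Holder()))
assert isinstance(h.data['providers_used'], set)
```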
@@ -143,18 +173,23 @@ class ForensicLogger:
             discovery_context=discovery_context
         )

-        self.api_requests.append(api_request)
-        self.session_metadata['total_requests'] += 1
-        self.session_metadata['providers_used'].add(provider)
-
-        if target_indicator:
-            self.session_metadata['target_domains'].add(target_indicator)
+        with self.lock:
+            self.api_requests.append(api_request)
+            self.session_metadata['total_requests'] += 1
+            self.session_metadata['providers_used'].add(provider)
+
+            if target_indicator:
+                self.session_metadata['target_domains'].add(target_indicator)

-        # Log to standard logger
-        if error:
-            self.logger.error(f"API Request Failed.")
-        else:
-            self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
+        # Log to standard logger with error handling
+        try:
+            if error:
+                self.logger.error(f"API Request Failed - {provider}: {url}")
+            else:
+                self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
+        except Exception:
+            # If logging fails, continue without breaking the application
+            pass
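The lock matters here because `+=` on the shared counter and the set insertions are separate operations; grouping them under one lock keeps readers from observing a request counted whose provider is not yet recorded. A minimal sketch of the pattern, with illustrative names:

```python
import threading

class Tally:
    """Group related mutations under one lock so a reader never sees
    a request counted but its provider missing (illustrative names)."""

    def __init__(self):
        self.lock = threading.Lock()
        self.total_requests = 0
        self.providers_used = set()

    def record(self, provider: str) -> None:
        with self.lock:
            self.total_requests += 1
            self.providers_used.add(provider)

    def snapshot(self) -> tuple:
        with self.lock:
            return self.total_requests, set(self.providers_used)

t = Tally()
threads = [threading.Thread(target=t.record, args=('dns',)) for _ in range(8)]
for th in threads:
    th.start()
for th in threads:
    th.join()
assert t.snapshot() == (8, {'dns'})
```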
     def log_relationship_discovery(self, source_node: str, target_node: str,
                                    relationship_type: str, confidence_score: float,
@@ -183,29 +218,44 @@ class ForensicLogger:
             discovery_method=discovery_method
         )

-        self.relationships.append(relationship)
-        self.session_metadata['total_relationships'] += 1
+        with self.lock:
+            self.relationships.append(relationship)
+            self.session_metadata['total_relationships'] += 1

-        self.logger.info(
-            f"Relationship Discovered - {source_node} -> {target_node} "
-            f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
-        )
+        # Log to standard logger with error handling
+        try:
+            self.logger.info(
+                f"Relationship Discovered - {source_node} -> {target_node} "
+                f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
+            )
+        except Exception:
+            # If logging fails, continue without breaking the application
+            pass

     def log_scan_start(self, target_domain: str, recursion_depth: int,
                        enabled_providers: List[str]) -> None:
         """Log the start of a reconnaissance scan."""
-        self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
-        self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
-
-        self.session_metadata['target_domains'].update(target_domain)
+        try:
+            self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
+            self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
+
+            with self.lock:
+                self.session_metadata['target_domains'].add(target_domain)
+        except Exception:
+            pass

     def log_scan_complete(self) -> None:
         """Log the completion of a reconnaissance scan."""
-        self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
-        self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
-        self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
+        with self.lock:
+            self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
+            # Convert sets to lists for serialization
+            self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
+            self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])

-        self.logger.info(f"Scan Complete - Session: {self.session_id}")
+        try:
+            self.logger.info(f"Scan Complete - Session: {self.session_id}")
+        except Exception:
+            pass

     def export_audit_trail(self) -> Dict[str, Any]:
         """
@@ -214,12 +264,13 @@ class ForensicLogger:
         Returns:
             Dictionary containing complete session audit trail
         """
-        return {
-            'session_metadata': self.session_metadata.copy(),
-            'api_requests': [asdict(req) for req in self.api_requests],
-            'relationships': [asdict(rel) for rel in self.relationships],
-            'export_timestamp': datetime.now(timezone.utc).isoformat()
-        }
+        with self.lock:
+            return {
+                'session_metadata': self.session_metadata.copy(),
+                'api_requests': [asdict(req) for req in self.api_requests],
+                'relationships': [asdict(rel) for rel in self.relationships],
+                'export_timestamp': datetime.now(timezone.utc).isoformat()
+            }

     def get_forensic_summary(self) -> Dict[str, Any]:
         """
@@ -229,7 +280,13 @@ class ForensicLogger:
         Dictionary containing summary statistics
         """
         provider_stats = {}
-        for provider in self.session_metadata['providers_used']:
+
+        # Ensure providers_used is a set for iteration
+        providers_used = self.session_metadata['providers_used']
+        if isinstance(providers_used, list):
+            providers_used = set(providers_used)
+
+        for provider in providers_used:
             provider_requests = [req for req in self.api_requests if req.provider == provider]
             provider_relationships = [rel for rel in self.relationships if rel.provider == provider]
core/scanner.py
@@ -35,6 +35,7 @@ class Scanner:
     """
     Main scanning orchestrator for DNSRecon passive reconnaissance.
    UNIFIED: Combines comprehensive features with improved display formatting.
+    FIXED: Enhanced threading object initialization to prevent None references.
     """

     def __init__(self, session_config=None, socketio=None):
@@ -44,6 +45,11 @@
         if session_config is None:
             from core.session_config import create_session_config
             session_config = create_session_config()

+        # FIXED: Initialize all threading objects first
+        self._initialize_threading_objects()
+
+        # Set socketio (but will be set to None for storage)
+        self.socketio = socketio

         self.config = session_config
@@ -53,17 +59,12 @@
         self.current_target = None
         self.current_depth = 0
         self.max_depth = 2
-        self.stop_event = threading.Event()
         self.scan_thread = None
         self.session_id: Optional[str] = None  # Will be set by session manager
-        self.task_queue = PriorityQueue()
-        self.target_retries = defaultdict(int)
-        self.scan_failed_due_to_retries = False
         self.initial_targets = set()

         # Thread-safe processing tracking (from Document 1)
         self.currently_processing = set()
-        self.processing_lock = threading.Lock()
         # Display-friendly processing list (from Document 2)
         self.currently_processing_display = []
@@ -81,9 +82,10 @@
         self.max_workers = self.config.max_concurrent_requests
         self.executor = None

-        # Status logger thread with improved formatting
-        self.status_logger_thread = None
-        self.status_logger_stop_event = threading.Event()
+        # Initialize collections that will be recreated during unpickling
+        self.task_queue = PriorityQueue()
+        self.target_retries = defaultdict(int)
+        self.scan_failed_due_to_retries = False

         # Initialize providers with session config
         self._initialize_providers()
@@ -99,12 +101,24 @@
             traceback.print_exc()
             raise

+    def _initialize_threading_objects(self):
+        """
+        FIXED: Initialize all threading objects with proper error handling.
+        This method can be called during both __init__ and __setstate__.
+        """
+        self.stop_event = threading.Event()
+        self.processing_lock = threading.Lock()
+        self.status_logger_stop_event = threading.Event()
+        self.status_logger_thread = None
+
     def _is_stop_requested(self) -> bool:
         """
         Check if stop is requested using both local and Redis-based signals.
         This ensures reliable termination across process boundaries.
+        FIXED: Added None check for stop_event.
         """
-        if self.stop_event.is_set():
+        # FIXED: Ensure stop_event exists before checking
+        if hasattr(self, 'stop_event') and self.stop_event and self.stop_event.is_set():
             return True

         if self.session_id:
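The point of the helper is that `__init__` and `__setstate__` share a single code path for everything pickle cannot carry (events, locks, queues — `queue.PriorityQueue` holds a mutex and is itself unpicklable). A minimal sketch of that shape, with illustrative names:

```python
import pickle
import threading
from queue import PriorityQueue

class Worker:
    def __init__(self):
        self._init_threading()
        self.results = []  # plain data survives pickling as-is

    def _init_threading(self):
        # Single source of truth for everything pickle cannot carry
        self.stop_event = threading.Event()
        self.processing_lock = threading.Lock()
        self.task_queue = PriorityQueue()

    def __getstate__(self):
        state = self.__dict__.copy()
        for attr in ('stop_event', 'processing_lock', 'task_queue'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._init_threading()  # same helper as __init__

w = pickle.loads(pickle.dumps(Worker()))
assert not w.stop_event.is_set() and w.task_queue.empty()
```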
@@ -112,16 +126,24 @@
             try:
                 from core.session_manager import session_manager
                 return session_manager.is_stop_requested(self.session_id)
             except Exception as e:
-                # Fall back to local event
-                return self.stop_event.is_set()
+                # Fall back to local event if it exists
+                if hasattr(self, 'stop_event') and self.stop_event:
+                    return self.stop_event.is_set()
+                return False

-        return self.stop_event.is_set()
+        # Final fallback
+        if hasattr(self, 'stop_event') and self.stop_event:
+            return self.stop_event.is_set()
+        return False

     def _set_stop_signal(self) -> None:
         """
         Set stop signal both locally and in Redis.
+        FIXED: Added None check for stop_event.
         """
-        self.stop_event.set()
+        # FIXED: Ensure stop_event exists before setting
+        if hasattr(self, 'stop_event') and self.stop_event:
+            self.stop_event.set()

         if self.session_id:
             try:
@@ -162,17 +184,21 @@
         """Restore object after unpickling by reconstructing threading objects."""
         self.__dict__.update(state)

-        self.stop_event = threading.Event()
+        # FIXED: Ensure all threading objects are properly initialized
+        self._initialize_threading_objects()

         # Re-initialize other objects
         self.scan_thread = None
         self.executor = None
-        self.processing_lock = threading.Lock()
         self.task_queue = PriorityQueue()
         self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
         self.logger = get_forensic_logger()
-        self.status_logger_thread = None
-        self.status_logger_stop_event = threading.Event()
-        self.socketio = None
+
+        # FIXED: Initialize socketio as None but preserve ability to set it
+        if not hasattr(self, 'socketio'):
+            self.socketio = None

         # Initialize missing attributes with defaults
         if not hasattr(self, 'providers') or not self.providers:
             self._initialize_providers()
@@ -182,11 +208,36 @@
         if not hasattr(self, 'currently_processing_display'):
             self.currently_processing_display = []

+        if not hasattr(self, 'target_retries'):
+            self.target_retries = defaultdict(int)
+
+        if not hasattr(self, 'scan_failed_due_to_retries'):
+            self.scan_failed_due_to_retries = False
+
+        if not hasattr(self, 'initial_targets'):
+            self.initial_targets = set()
+
         # Ensure providers have stop events
         if hasattr(self, 'providers'):
             for provider in self.providers:
-                if hasattr(provider, 'set_stop_event'):
+                if hasattr(provider, 'set_stop_event') and self.stop_event:
                     provider.set_stop_event(self.stop_event)

+    def _ensure_threading_objects_exist(self):
+        """
+        FIXED: Utility method to ensure threading objects exist before use.
+        Call this before any method that might use threading objects.
+        """
+        if not hasattr(self, 'stop_event') or self.stop_event is None:
+            print("WARNING: Threading objects not initialized, recreating...")
+            self._initialize_threading_objects()
+
+        if not hasattr(self, 'processing_lock') or self.processing_lock is None:
+            self.processing_lock = threading.Lock()
+
+        if not hasattr(self, 'task_queue') or self.task_queue is None:
+            self.task_queue = PriorityQueue()
+
     def _initialize_providers(self) -> None:
         """Initialize all available providers based on session configuration."""
         self.providers = []
@@ -224,7 +275,9 @@
             print(f"  Available: {is_available}")

             if is_available:
-                provider.set_stop_event(self.stop_event)
+                # FIXED: Ensure stop_event exists before setting it
+                if hasattr(self, 'stop_event') and self.stop_event:
+                    provider.set_stop_event(self.stop_event)
                 if isinstance(provider, CorrelationProvider):
                     provider.set_graph_manager(self.graph)
                 self.providers.append(provider)
@@ -254,15 +307,25 @@
         BOLD = "\033[1m"

         last_status_str = ""
-        while not self.status_logger_stop_event.is_set():
+
+        # FIXED: Ensure threading objects exist
+        self._ensure_threading_objects_exist()
+
+        while not (hasattr(self, 'status_logger_stop_event') and
+                   self.status_logger_stop_event and
+                   self.status_logger_stop_event.is_set()):
             try:
-                with self.processing_lock:
-                    in_flight_tasks = list(self.currently_processing)
-                    self.currently_processing_display = in_flight_tasks.copy()
+                # FIXED: Check if processing_lock exists before using
+                if hasattr(self, 'processing_lock') and self.processing_lock:
+                    with self.processing_lock:
+                        in_flight_tasks = list(self.currently_processing)
+                        self.currently_processing_display = in_flight_tasks.copy()
+                else:
+                    in_flight_tasks = list(getattr(self, 'currently_processing', []))

                 status_str = (
                     f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
-                    f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
+                    f"{CYAN}Queued: {self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0}{ENDC} | "
                     f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
                     f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
                     f"Skipped: {self.tasks_skipped} | "
@@ -290,22 +353,30 @@
             time.sleep(2)

     def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
         """
         FIXED: Enhanced start_scan with proper threading object initialization and socketio management.
         """
+        # FIXED: Ensure threading objects exist before proceeding
+        self._ensure_threading_objects_exist()
+
         if self.scan_thread and self.scan_thread.is_alive():
             self.logger.logger.info("Stopping existing scan before starting new one")
             self._set_stop_signal()
             self.status = ScanStatus.STOPPED

             # Clean up processing state
-            with self.processing_lock:
-                self.currently_processing.clear()
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    self.currently_processing.clear()
             self.currently_processing_display = []

             # Clear task queue
-            while not self.task_queue.empty():
-                try:
-                    self.task_queue.get_nowait()
-                except:
-                    break
+            if hasattr(self, 'task_queue') and self.task_queue:
+                while not self.task_queue.empty():
+                    try:
+                        self.task_queue.get_nowait()
+                    except:
+                        break

             # Shutdown executor
             if self.executor:
@@ -322,14 +393,26 @@
                 self.logger.logger.warning("Previous scan thread did not terminate cleanly")

         self.status = ScanStatus.IDLE
-        self.stop_event.clear()
+
+        # FIXED: Ensure stop_event exists before clearing
+        if hasattr(self, 'stop_event') and self.stop_event:
+            self.stop_event.clear()
+
         if self.session_id:
             from core.session_manager import session_manager
             session_manager.clear_stop_signal(self.session_id)

+            # FIXED: Restore socketio connection if missing
+            if not hasattr(self, 'socketio') or not self.socketio:
+                registered_socketio = session_manager.get_socketio_connection(self.session_id)
+                if registered_socketio:
+                    self.socketio = registered_socketio
+                    print(f"✓ Restored socketio connection for scan start")
+
-        with self.processing_lock:
-            self.currently_processing.clear()
+        # FIXED: Safe cleanup with existence checks
+        if hasattr(self, 'processing_lock') and self.processing_lock:
+            with self.processing_lock:
+                self.currently_processing.clear()
         self.currently_processing_display = []

         self.task_queue = PriorityQueue()
@@ -397,7 +480,10 @@
         )
         self.scan_thread.start()

-        self.status_logger_stop_event.clear()
+        # FIXED: Ensure status_logger_stop_event exists before clearing
+        if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
+            self.status_logger_stop_event.clear()
+
         self.status_logger_thread = threading.Thread(
             target=self._status_logger_thread,
             daemon=True,
@@ -451,6 +537,13 @@
             return 10  # Very low rate limit = very low priority

     def _execute_scan(self, target: str, max_depth: int) -> None:
         """
         FIXED: Enhanced execute_scan with proper threading object handling.
         """
+        # FIXED: Ensure threading objects exist
+        self._ensure_threading_objects_exist()
+
+        update_counter = 0  # Track updates for throttling
+        last_update_time = time.time()
         self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
         processed_tasks = set()  # FIXED: Now includes depth to avoid incorrect skipping
@@ -482,8 +575,13 @@
         print(f"\n=== PHASE 1: Running non-correlation providers ===")
         while not self._is_stop_requested():
             queue_empty = self.task_queue.empty()
-            with self.processing_lock:
-                no_active_processing = len(self.currently_processing) == 0
+
+            # FIXED: Safe processing lock usage
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    no_active_processing = len(self.currently_processing) == 0
+            else:
+                no_active_processing = len(getattr(self, 'currently_processing', [])) == 0

             if queue_empty and no_active_processing:
                 consecutive_empty_iterations += 1
@@ -536,10 +634,23 @@
                 continue

             # Thread-safe processing state management
-            with self.processing_lock:
-                processing_key = (provider_name, target_item)
+
+            # FIXED: Safe processing lock usage
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    if self._is_stop_requested():
+                        break
+                    if processing_key in self.currently_processing:
+                        self.tasks_skipped += 1
+                        self.indicators_completed += 1
+                        continue
+                    self.currently_processing.add(processing_key)
+            else:
+                if self._is_stop_requested():
+                    break
+                processing_key = (provider_name, target_item)
+                if not hasattr(self, 'currently_processing'):
+                    self.currently_processing = set()
+                if processing_key in self.currently_processing:
+                    self.tasks_skipped += 1
+                    self.indicators_completed += 1
@@ -558,7 +669,12 @@

             if provider and not isinstance(provider, CorrelationProvider):
                 new_targets, _, success = self._process_provider_task(provider, target_item, depth)
+
+                update_counter += 1
+                current_time = time.time()
+                if (update_counter % 5 == 0) or (current_time - last_update_time > 3.0):
+                    self._update_session_state()
+                    last_update_time = current_time
+                    update_counter = 0
             if self._is_stop_requested():
                 break
@@ -603,9 +719,13 @@
                     self.indicators_completed += 1

             finally:
-                with self.processing_lock:
-                    processing_key = (provider_name, target_item)
-                    self.currently_processing.discard(processing_key)
+                # FIXED: Safe processing lock usage for cleanup
+                if hasattr(self, 'processing_lock') and self.processing_lock:
+                    with self.processing_lock:
+                        self.currently_processing.discard(processing_key)
+                else:
+                    if hasattr(self, 'currently_processing'):
+                        self.currently_processing.discard(processing_key)

         # PHASE 2: Run correlations on all discovered nodes
         if not self._is_stop_requested():
@@ -618,8 +738,9 @@
             self.logger.logger.error(f"Scan failed: {e}")
         finally:
             # Comprehensive cleanup (same as before)
-            with self.processing_lock:
-                self.currently_processing.clear()
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    self.currently_processing.clear()
             self.currently_processing_display = []

             while not self.task_queue.empty():
@@ -635,7 +756,9 @@
             else:
                 self.status = ScanStatus.COMPLETED

-            self.status_logger_stop_event.set()
+            # FIXED: Safe stop event handling
+            if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
+                self.status_logger_stop_event.set()
             if self.status_logger_thread and self.status_logger_thread.is_alive():
                 self.status_logger_thread.join(timeout=2.0)
@@ -689,8 +812,13 @@

         while not self._is_stop_requested() and correlation_tasks:
             queue_empty = self.task_queue.empty()
-            with self.processing_lock:
-                no_active_processing = len(self.currently_processing) == 0
+
+            # FIXED: Safe processing check
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    no_active_processing = len(self.currently_processing) == 0
+            else:
+                no_active_processing = len(getattr(self, 'currently_processing', [])) == 0

             if queue_empty and no_active_processing:
                 consecutive_empty_iterations += 1
@@ -722,10 +850,23 @@
                 correlation_tasks.remove(task_tuple)
                 continue

-            with self.processing_lock:
-                processing_key = (provider_name, target_item)
+
+            # FIXED: Safe processing management
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    if self._is_stop_requested():
+                        break
+                    if processing_key in self.currently_processing:
+                        self.tasks_skipped += 1
+                        self.indicators_completed += 1
+                        continue
+                    self.currently_processing.add(processing_key)
+            else:
+                if self._is_stop_requested():
+                    break
+                processing_key = (provider_name, target_item)
+                if not hasattr(self, 'currently_processing'):
+                    self.currently_processing = set()
+                if processing_key in self.currently_processing:
+                    self.tasks_skipped += 1
+                    self.indicators_completed += 1
@@ -754,51 +895,214 @@
                 correlation_tasks.remove(task_tuple)

             finally:
-                with self.processing_lock:
-                    processing_key = (provider_name, target_item)
-                    self.currently_processing.discard(processing_key)
+                # FIXED: Safe cleanup
+                if hasattr(self, 'processing_lock') and self.processing_lock:
+                    with self.processing_lock:
+                        self.currently_processing.discard(processing_key)
+                else:
+                    if hasattr(self, 'currently_processing'):
+                        self.currently_processing.discard(processing_key)

         print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")

+    # Rest of the methods remain the same but with similar threading object safety checks...
+    # I'll include the key methods that need fixes:
+
+    def stop_scan(self) -> bool:
+        """Request immediate scan termination with proper cleanup."""
+        try:
+            self.logger.logger.info("Scan termination requested by user")
+            self._set_stop_signal()
+            self.status = ScanStatus.STOPPED
+
+            # FIXED: Safe cleanup
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    self.currently_processing.clear()
+            self.currently_processing_display = []
+
+            self.task_queue = PriorityQueue()
+
+            if self.executor:
+                try:
+                    self.executor.shutdown(wait=False, cancel_futures=True)
+                except Exception:
+                    pass
+
+            self._update_session_state()
+            return True
+
+        except Exception as e:
+            self.logger.logger.error(f"Error during scan termination: {e}")
+            traceback.print_exc()
+            return False
+
+    def get_scan_status(self) -> Dict[str, Any]:
+        """Get current scan status with comprehensive graph data for real-time updates."""
+        try:
+            # FIXED: Safe processing state access
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    currently_processing_count = len(self.currently_processing)
+                    currently_processing_list = list(self.currently_processing)
+            else:
+                currently_processing_count = len(getattr(self, 'currently_processing', []))
+                currently_processing_list = list(getattr(self, 'currently_processing', []))
+
+            # FIXED: Always include complete graph data for real-time updates
+            graph_data = self.get_graph_data()
+
+            return {
+                'status': self.status,
+                'target_domain': self.current_target,
+                'current_depth': self.current_depth,
+                'max_depth': self.max_depth,
+                'current_indicator': self.current_indicator,
+                'indicators_processed': self.indicators_processed,
+                'indicators_completed': self.indicators_completed,
+                'tasks_re_enqueued': self.tasks_re_enqueued,
+                'progress_percentage': self._calculate_progress(),
+                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
+                'enabled_providers': [provider.get_name() for provider in self.providers],
+                'graph': graph_data,  # FIXED: Always include complete graph data
+                'task_queue_size': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
+                'currently_processing_count': currently_processing_count,
+                'currently_processing': currently_processing_list[:5],
+                'tasks_in_queue': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
+                'tasks_completed': self.indicators_completed,
+                'tasks_skipped': self.tasks_skipped,
+                'tasks_rescheduled': self.tasks_re_enqueued,
+            }
+        except Exception as e:
+            traceback.print_exc()
+            return {
+                'status': 'error',
+                'message': 'Failed to get status',
+                'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
+            }
+
+    def _update_session_state(self) -> None:
+        """
+        FIXED: Update the scanner state in Redis and emit real-time WebSocket updates.
+        Enhanced with better error handling and socketio management.
+        """
+        if self.session_id:
+            try:
+                # Get current status for WebSocket emission
+                current_status = self.get_scan_status()
+
+                # FIXED: Emit real-time update via WebSocket with better error handling
+                socketio_available = False
+                if hasattr(self, 'socketio') and self.socketio:
+                    try:
+                        print(f"📡 EMITTING WebSocket update: {current_status.get('status', 'unknown')} - "
+                              f"Nodes: {len(current_status.get('graph', {}).get('nodes', []))}, "
+                              f"Edges: {len(current_status.get('graph', {}).get('edges', []))}")
+
+                        self.socketio.emit('scan_update', current_status)
+                        print("✅ WebSocket update emitted successfully")
+                        socketio_available = True
+
+                    except Exception as ws_error:
+                        print(f"⚠️ WebSocket emission failed: {ws_error}")
+                        # Try to get socketio from session manager
+                        try:
+                            from core.session_manager import session_manager
+                            registered_socketio = session_manager.get_socketio_connection(self.session_id)
+                            if registered_socketio:
+                                print("🔄 Attempting to use registered socketio connection...")
+                                registered_socketio.emit('scan_update', current_status)
+                                self.socketio = registered_socketio  # Update our reference
+                                print("✅ WebSocket update emitted via registered connection")
+                                socketio_available = True
+                            else:
+                                print("⚠️ No registered socketio connection found")
+                        except Exception as fallback_error:
+                            print(f"⚠️ Fallback socketio emission also failed: {fallback_error}")
+                else:
+                    # Try to restore socketio from session manager
+                    try:
+                        from core.session_manager import session_manager
+                        registered_socketio = session_manager.get_socketio_connection(self.session_id)
+                        if registered_socketio:
+                            print(f"🔄 Restoring socketio connection for session {self.session_id}")
+                            self.socketio = registered_socketio
+                            self.socketio.emit('scan_update', current_status)
+                            print("✅ WebSocket update emitted via restored connection")
+                            socketio_available = True
+                        else:
+                            print(f"⚠️ No socketio connection available for session {self.session_id}")
+                    except Exception as restore_error:
+                        print(f"⚠️ Failed to restore socketio connection: {restore_error}")
+
+                if not socketio_available:
+                    print(f"⚠️ Real-time updates unavailable for session {self.session_id}")
+
+                # Update session in Redis for persistence (always do this)
+                try:
+                    from core.session_manager import session_manager
+                    session_manager.update_session_scanner(self.session_id, self)
+                except Exception as redis_error:
+                    print(f"⚠️ Failed to update session in Redis: {redis_error}")
+
+            except Exception as e:
+                print(f"⚠️ Error updating session state: {e}")
+                import traceback
+                traceback.print_exc()
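The emit-then-fall-back-to-registry sequence above also appears in `_process_provider_task`, `extract_node_from_large_entity`, and `_process_provider_result_unified`; consolidated, it could read roughly as below. This is a sketch, not code from the commit — `get_socketio_connection` is the registry method added in `core/session_manager.py`, everything else is illustrative:

```python
def _emit_update(self, payload) -> bool:
    """Try the cached socketio first, then the session registry; never raise."""
    from core.session_manager import session_manager

    # Both lookups are cheap; each candidate is tried in order.
    candidates = (getattr(self, 'socketio', None),
                  session_manager.get_socketio_connection(self.session_id))
    for candidate in candidates:
        if not candidate:
            continue
        try:
            candidate.emit('scan_update', payload)
            self.socketio = candidate  # cache whichever connection worked
            return True
        except Exception:
            continue
    return False
```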
     def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
         """
-        Manages the entire process for a given target and provider.
-        This version is generalized to handle all relationships dynamically.
+        FIXED: Manages the entire process for a given target and provider with enhanced real-time updates.
         """
         if self._is_stop_requested():
             return set(), set(), False

         is_ip = _is_valid_ip(target)
         target_type = NodeType.IP if is_ip else NodeType.DOMAIN

         self.graph.add_node(target, target_type)
         self._initialize_provider_states(target)

         new_targets = set()
         provider_successful = True

         try:
             provider_result = self._execute_provider_query(provider, target, is_ip)

             if provider_result is None:
                 provider_successful = False
             elif not self._is_stop_requested():
                 # Pass all relationships to be processed
                 discovered, is_large_entity = self._process_provider_result_unified(
                     target, provider, provider_result, depth
                 )
                 new_targets.update(discovered)

+                # FIXED: Emit real-time update after processing provider result
+                if discovered or provider_result.get_relationship_count() > 0:
+                    # Ensure we have socketio connection for real-time updates
+                    if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
+                        try:
+                            from core.session_manager import session_manager
+                            registered_socketio = session_manager.get_socketio_connection(self.session_id)
+                            if registered_socketio:
+                                self.socketio = registered_socketio
+                                print(f"🔄 Restored socketio connection during provider processing")
+                        except Exception as restore_error:
+                            print(f"⚠️ Failed to restore socketio during provider processing: {restore_error}")
+
+                    self._update_session_state()

         except Exception as e:
             provider_successful = False
             self._log_provider_error(target, provider.get_name(), str(e))

         return new_targets, set(), provider_successful

     def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
-        """
-        The "worker" function that directly communicates with the provider to fetch data.
-        """
+        """The "worker" function that directly communicates with the provider to fetch data."""
         provider_name = provider.get_name()
         start_time = datetime.now(timezone.utc)

@@ -825,9 +1129,7 @@

     def _create_large_entity_from_result(self, source_node: str, provider_name: str,
                                          provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
-        """
-        Creates a large entity node, tags all member nodes, and returns its ID and members.
-        """
+        """Creates a large entity node, tags all member nodes, and returns its ID and members."""
         members = {rel.target_node for rel in provider_result.relationships
                    if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}

@@ -860,7 +1162,7 @@

     def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
         """
-        Removes a node from a large entity, allowing it to be processed normally.
+        FIXED: Removes a node from a large entity with immediate real-time update.
         """
         if not self.graph.graph.has_node(node_id):
             return False
@@ -879,7 +1181,6 @@
         for provider in eligible_providers:
             provider_name = provider.get_name()
             priority = self._get_priority(provider_name)
-            # Use current depth of the large entity if available, else 0
             depth = 0
             if self.graph.graph.has_node(large_entity_id):
                 le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
@@ -890,6 +1191,19 @@
             self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
             self.total_tasks_ever_enqueued += 1

+            # FIXED: Emit real-time update after extraction with socketio management
+            if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
+                try:
+                    from core.session_manager import session_manager
+                    registered_socketio = session_manager.get_socketio_connection(self.session_id)
+                    if registered_socketio:
+                        self.socketio = registered_socketio
+                        print(f"🔄 Restored socketio for node extraction update")
+                except Exception as restore_error:
+                    print(f"⚠️ Failed to restore socketio for extraction: {restore_error}")
+
             self._update_session_state()

             return True

         return False
@@ -897,8 +1211,7 @@
     def _process_provider_result_unified(self, target: str, provider: BaseProvider,
                                          provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
         """
-        Process a unified ProviderResult object to update the graph.
-        This version dynamically re-routes edges to a large entity container.
+        FIXED: Process a unified ProviderResult object to update the graph with enhanced real-time updates.
         """
         provider_name = provider.get_name()
         discovered_targets = set()
@@ -918,6 +1231,10 @@
             target, provider_name, provider_result, current_depth
         )

+        # Track if we added anything significant
+        nodes_added = 0
+        edges_added = 0
+
         for i, relationship in enumerate(provider_result.relationships):
             if i % 5 == 0 and self._is_stop_requested():
                 break
@@ -949,17 +1266,20 @@
             max_depth_reached = current_depth >= self.max_depth

             # Add actual nodes to the graph (they might be hidden by the UI)
-            self.graph.add_node(source_node_id, source_type)
-            self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached})
+            if self.graph.add_node(source_node_id, source_type):
+                nodes_added += 1
+            if self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}):
+                nodes_added += 1

             # Add the visual edge to the graph
-            self.graph.add_edge(
+            if self.graph.add_edge(
                 visual_source, visual_target,
                 relationship.relationship_type,
                 relationship.confidence,
                 provider_name,
                 relationship.raw_data
-            )
+            ):
+                edges_added += 1

             if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
                 if target_node_id not in large_entity_members:
@@ -987,88 +1307,32 @@
             if not self.graph.graph.has_node(node_id):
                 node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
                 self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
+                nodes_added += 1
             else:
                 existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
                 self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list

-        return discovered_targets, is_large_entity
-
-    def stop_scan(self) -> bool:
-        """Request immediate scan termination with proper cleanup."""
-        try:
-            self.logger.logger.info("Scan termination requested by user")
-            self._set_stop_signal()
-            self.status = ScanStatus.STOPPED
-
-            with self.processing_lock:
-                self.currently_processing.clear()
-            self.currently_processing_display = []
-
-            self.task_queue = PriorityQueue()
-
-            if self.executor:
-                try:
-                    self.executor.shutdown(wait=False, cancel_futures=True)
-                except Exception:
-                    pass
-
-            self._update_session_state()
-            return True
-
-        except Exception as e:
-            self.logger.logger.error(f"Error during scan termination: {e}")
-            traceback.print_exc()
-            return False
-
-    def _update_session_state(self) -> None:
-        """
-        Update the scanner state in Redis for GUI updates.
-        """
-        if self.session_id:
-            try:
-                if self.socketio:
-                    self.socketio.emit('scan_update', self.get_scan_status())
-                from core.session_manager import session_manager
-                session_manager.update_session_scanner(self.session_id, self)
-            except Exception:
-                pass
-
-    def get_scan_status(self) -> Dict[str, Any]:
-        """Get current scan status with comprehensive processing information."""
-        try:
-            with self.processing_lock:
-                currently_processing_count = len(self.currently_processing)
-                currently_processing_list = list(self.currently_processing)
-
-            return {
-                'status': self.status,
-                'target_domain': self.current_target,
-                'current_depth': self.current_depth,
-                'max_depth': self.max_depth,
-                'current_indicator': self.current_indicator,
-                'indicators_processed': self.indicators_processed,
-                'indicators_completed': self.indicators_completed,
-                'tasks_re_enqueued': self.tasks_re_enqueued,
-                'progress_percentage': self._calculate_progress(),
-                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
-                'enabled_providers': [provider.get_name() for provider in self.providers],
-                'graph': self.get_graph_data(),
-                'task_queue_size': self.task_queue.qsize(),
-                'currently_processing_count': currently_processing_count,
-                'currently_processing': currently_processing_list[:5],
-                'tasks_in_queue': self.task_queue.qsize(),
-                'tasks_completed': self.indicators_completed,
-                'tasks_skipped': self.tasks_skipped,
-                'tasks_rescheduled': self.tasks_re_enqueued,
-            }
-        except Exception:
-            traceback.print_exc()
-            return { 'status': 'error', 'message': 'Failed to get status' }
+        # FIXED: Emit real-time update if we added anything significant
+        if nodes_added > 0 or edges_added > 0:
+            print(f"🔄 Added {nodes_added} nodes, {edges_added} edges - triggering real-time update")
+
+            # Ensure we have socketio connection for immediate update
+            if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
+                try:
+                    from core.session_manager import session_manager
+                    registered_socketio = session_manager.get_socketio_connection(self.session_id)
+                    if registered_socketio:
+                        self.socketio = registered_socketio
+                        print(f"🔄 Restored socketio for immediate update")
+                except Exception as restore_error:
+                    print(f"⚠️ Failed to restore socketio for immediate update: {restore_error}")
+
+            self._update_session_state()
+
+        return discovered_targets, is_large_entity

     def _initialize_provider_states(self, target: str) -> None:
-        """
-        FIXED: Safer provider state initialization with error handling.
-        """
+        """FIXED: Safer provider state initialization with error handling."""
         try:
             if not self.graph.graph.has_node(target):
                 return
@@ -1081,11 +1345,8 @@
         except Exception as e:
             self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")

     def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
-        """
-        FIXED: Improved provider eligibility checking with better filtering.
-        """
+        """FIXED: Improved provider eligibility checking with better filtering."""
         if dns_only:
             return [p for p in self.providers if p.get_name() == 'dns']

@@ -1124,9 +1385,7 @@
         return eligible

     def _already_queried_provider(self, target: str, provider_name: str) -> bool:
-        """
-        FIXED: More robust check for already queried providers with proper error handling.
-        """
+        """FIXED: More robust check for already queried providers with proper error handling."""
         try:
             if not self.graph.graph.has_node(target):
                 return False
@@ -1145,9 +1404,7 @@

     def _update_provider_state(self, target: str, provider_name: str, status: str,
                                results_count: int, error: Optional[str], start_time: datetime) -> None:
-        """
-        FIXED: More robust provider state updates with validation.
-        """
+        """FIXED: More robust provider state updates with validation."""
         try:
             if not self.graph.graph.has_node(target):
                 self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
@@ -1174,7 +1431,8 @@
             }

             # Update last modified time for forensic integrity
-            self.last_modified = datetime.now(timezone.utc).isoformat()
+            if hasattr(self, 'last_modified'):
+                self.last_modified = datetime.now(timezone.utc).isoformat()

         except Exception as e:
             self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
@@ -1191,9 +1449,14 @@
             return 0.0

         # Add small buffer for tasks still in queue to avoid showing 100% too early
-        queue_size = max(0, self.task_queue.qsize())
-        with self.processing_lock:
-            active_tasks = len(self.currently_processing)
+        queue_size = max(0, self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0)
+
+        # FIXED: Safe processing count
+        if hasattr(self, 'processing_lock') and self.processing_lock:
+            with self.processing_lock:
+                active_tasks = len(self.currently_processing)
+        else:
+            active_tasks = len(getattr(self, 'currently_processing', []))

         # Adjust total to account for remaining work
         adjusted_total = max(self.total_tasks_ever_enqueued,
@@ -1210,12 +1473,13 @@
         return 0.0

     def get_graph_data(self) -> Dict[str, Any]:
         """Get current graph data formatted for frontend visualization."""
         graph_data = self.graph.get_graph_data()
         graph_data['initial_targets'] = list(self.initial_targets)
         return graph_data
+

     def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
         """Get comprehensive information about all available providers."""
         info = {}
         provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
         for filename in os.listdir(provider_dir):
core/session_manager.py
@@ -6,6 +6,7 @@ import uuid
 import redis
 import pickle
 from typing import Dict, Optional, Any
+import copy

 from core.scanner import Scanner
 from config import config
@@ -13,7 +14,7 @@ from config import config
 class SessionManager:
     """
     FIXED: Manages multiple scanner instances for concurrent user sessions using Redis.
-    Now more conservative about session creation to preserve API keys and configuration.
+    Enhanced to properly maintain WebSocket connections throughout scan lifecycle.
     """

     def __init__(self, session_timeout_minutes: int = 0):
@@ -30,6 +31,9 @@
         # FIXED: Add a creation lock to prevent race conditions
         self.creation_lock = threading.Lock()

+        # Track active socketio connections per session
+        self.active_socketio_connections = {}
+
         # Start cleanup thread
         self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
         self.cleanup_thread.start()
@@ -40,7 +44,7 @@
         """Prepare SessionManager for pickling."""
         state = self.__dict__.copy()
         # Exclude unpickleable attributes - Redis client and threading objects
-        unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock']
+        unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections']
         for attr in unpicklable_attrs:
             if attr in state:
                 del state[attr]
@@ -53,6 +57,7 @@
         self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
         self.lock = threading.Lock()
         self.creation_lock = threading.Lock()
+        self.active_socketio_connections = {}
         self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
         self.cleanup_thread.start()
@@ -64,22 +69,70 @@
         """Generates the Redis key for a session's stop signal."""
         return f"dnsrecon:stop:{session_id}"

+    def register_socketio_connection(self, session_id: str, socketio) -> None:
+        """
+        FIXED: Register a socketio connection for a session.
+        This ensures the connection is maintained throughout the session lifecycle.
+        """
+        with self.lock:
+            self.active_socketio_connections[session_id] = socketio
+        print(f"Registered socketio connection for session {session_id}")
+
+    def get_socketio_connection(self, session_id: str):
+        """
+        FIXED: Get the active socketio connection for a session.
+        """
+        with self.lock:
+            return self.active_socketio_connections.get(session_id)
+
+    def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner:
+        """
+        FIXED: Prepare scanner for storage by ensuring proper cleanup of unpicklable objects.
+        Now preserves socketio connection info for restoration.
+        """
+        # Set the session ID on the scanner for cross-process stop signal management
+        scanner.session_id = session_id
+
+        # FIXED: Don't set socketio to None if we want to preserve real-time updates
+        # Instead, we'll restore it when loading the scanner
+        scanner.socketio = None
+
+        # Force cleanup of any threading objects that might cause issues
+        if hasattr(scanner, 'stop_event'):
+            scanner.stop_event = None
+        if hasattr(scanner, 'scan_thread'):
+            scanner.scan_thread = None
+        if hasattr(scanner, 'executor'):
+            scanner.executor = None
+        if hasattr(scanner, 'status_logger_thread'):
+            scanner.status_logger_thread = None
+        if hasattr(scanner, 'status_logger_stop_event'):
+            scanner.status_logger_stop_event = None
+
+        return scanner
+
     def create_session(self, socketio=None) -> str:
         """
-        FIXED: Create a new user session with thread-safe creation to prevent duplicates.
+        FIXED: Create a new user session with enhanced WebSocket management.
         """
         # FIXED: Use creation lock to prevent race conditions
         with self.creation_lock:
             session_id = str(uuid.uuid4())
             print(f"=== CREATING SESSION {session_id} IN REDIS ===")

+            # FIXED: Register socketio connection first
+            if socketio:
+                self.register_socketio_connection(session_id, socketio)
+
             try:
                 from core.session_config import create_session_config
                 session_config = create_session_config()
-                scanner_instance = Scanner(session_config=session_config, socketio=socketio)
-
-                # Set the session ID on the scanner for cross-process stop signal management
-                scanner_instance.session_id = session_id
+                # Create scanner WITHOUT socketio to avoid weakref issues
+                scanner_instance = Scanner(session_config=session_config, socketio=None)
+
+                # Prepare scanner for storage (removes problematic objects)
+                scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id)

                 session_data = {
                     'scanner': scanner_instance,
@@ -89,12 +142,24 @@ class SessionManager:
|
||||
'status': 'active'
|
||||
}
|
||||
|
||||
# Serialize the entire session data dictionary using pickle
|
||||
serialized_data = pickle.dumps(session_data)
|
||||
# Test serialization before storing to catch issues early
|
||||
try:
|
||||
test_serialization = pickle.dumps(session_data)
|
||||
print(f"Session serialization test successful ({len(test_serialization)} bytes)")
|
||||
except Exception as pickle_error:
|
||||
print(f"PICKLE TEST FAILED: {pickle_error}")
|
||||
# Try to identify the problematic object
|
||||
for key, value in session_data.items():
|
||||
try:
|
||||
pickle.dumps(value)
|
||||
print(f" {key}: OK")
|
||||
except Exception as item_error:
|
||||
print(f" {key}: FAILED - {item_error}")
|
||||
raise pickle_error
|
||||
|
||||
# Store in Redis
|
||||
session_key = self._get_session_key(session_id)
|
||||
self.redis_client.setex(session_key, self.session_timeout, serialized_data)
|
||||
self.redis_client.setex(session_key, self.session_timeout, test_serialization)
|
||||
|
||||
# Initialize stop signal as False
|
||||
stop_key = self._get_stop_signal_key(session_id)
|
||||
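The per-key probe in the except branch generalizes into a small helper; a sketch (not part of the commit) that collects every offending value rather than only printing them:

import pickle
from typing import Any, Dict

def find_unpicklable(data: Dict[str, Any]) -> Dict[str, str]:
    """Map each key whose value fails to pickle to the error it raised."""
    failures: Dict[str, str] = {}
    for key, value in data.items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # pickle raises TypeError, AttributeError, ...
            failures[key] = str(exc)
    return failures

Note also that the new code passes test_serialization to setex, so the payload is pickled exactly once rather than once for the test and again for the write.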
@@ -106,6 +171,8 @@ class SessionManager:

        except Exception as e:
            print(f"ERROR: Failed to create session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            raise

    def set_stop_signal(self, session_id: str) -> bool:
@@ -175,31 +242,63 @@ class SessionManager:
                # Ensure the scanner has the correct session ID for stop signal checking
                if 'scanner' in session_data and session_data['scanner']:
                    session_data['scanner'].session_id = session_id
                    # FIXED: Restore socketio connection from our registry
                    socketio_conn = self.get_socketio_connection(session_id)
                    if socketio_conn:
                        session_data['scanner'].socketio = socketio_conn
                        print(f"Restored socketio connection for session {session_id}")
                    else:
                        print(f"No socketio connection found for session {session_id}")
                        session_data['scanner'].socketio = None
                return session_data
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Serializes and saves session data back to Redis with updated TTL.
        FIXED: Now preserves socketio connection during storage.

        Returns:
            bool: True if save was successful
        """
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = pickle.dumps(session_data)

            # Create a deep copy to avoid modifying the original scanner object
            session_data_to_save = copy.deepcopy(session_data)

            # Prepare scanner for storage if it exists
            if 'scanner' in session_data_to_save and session_data_to_save['scanner']:
                # FIXED: Preserve the original socketio connection before preparing for storage
                original_socketio = session_data_to_save['scanner'].socketio

                session_data_to_save['scanner'] = self._prepare_scanner_for_storage(
                    session_data_to_save['scanner'],
                    session_id
                )

                # FIXED: If we had a socketio connection, make sure it's registered
                if original_socketio and session_id not in self.active_socketio_connections:
                    self.register_socketio_connection(session_id, original_socketio)

            serialized_data = pickle.dumps(session_data_to_save)
            result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
            return result
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

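The deep copy is what lets _prepare_scanner_for_storage mutate freely: the live scanner keeps its executor and threads while only the copy is stripped and serialized. A toy illustration with plain dicts:

import copy

live = {"scanner": {"status": "running", "executor": "<live object>"}}

snapshot = copy.deepcopy(live)          # detach the copy from live state
snapshot["scanner"]["executor"] = None  # strip what must not be stored

assert live["scanner"]["executor"] == "<live object>"  # original untouched

One caveat worth knowing: copy.deepcopy falls back on the pickle protocol (__reduce_ex__ and __getstate__) for objects it does not special-case, so the scanner's own pickle support is doing double duty here.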
    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """
        Updates just the scanner object in a session with immediate persistence.
        FIXED: Updates just the scanner object in a session with immediate persistence.
        Now maintains socketio connection throughout the update process.

        Returns:
            bool: True if update was successful
@@ -207,21 +306,27 @@ class SessionManager:
        try:
            session_data = self._get_session_data(session_id)
            if session_data:
                # Ensure scanner has the session ID
                scanner.session_id = session_id
                # FIXED: Preserve socketio connection before preparing for storage
                original_socketio = scanner.socketio

                # Prepare scanner for storage
                scanner = self._prepare_scanner_for_storage(scanner, session_id)
                session_data['scanner'] = scanner
                session_data['last_activity'] = time.time()

                # FIXED: Restore socketio connection after preparation
                if original_socketio:
                    self.register_socketio_connection(session_id, original_socketio)
                    session_data['scanner'].socketio = original_socketio

                # Immediately save to Redis for GUI updates
                success = self._save_session_data(session_id, session_data)
                if success:
                    # Only log occasionally to reduce noise
                    if hasattr(self, '_last_update_log'):
                        if time.time() - self._last_update_log > 5:  # Log every 5 seconds max
                            #print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
                            self._last_update_log = time.time()
                    else:
                        #print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
                        self._last_update_log = time.time()
                else:
                    print(f"WARNING: Failed to save scanner state for session {session_id}")
@@ -231,6 +336,8 @@ class SessionManager:
            return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

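The nested hasattr/else bookkeeping above is a rate limit on log output; the same idea reads more directly as a small helper (a sketch, using time.monotonic so clock adjustments cannot distort the interval):

import time

class Throttle:
    """Allow an action at most once per `interval` seconds."""

    def __init__(self, interval: float = 5.0) -> None:
        self.interval = interval
        self._last = float("-inf")

    def ready(self) -> bool:
        now = time.monotonic()
        if now - self._last >= self.interval:
            self._last = now
            return True
        return False

Usage would be `if throttle.ready(): print(...)`, which replaces both duplicated branches with one call.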
    def update_scanner_status(self, session_id: str, status: str) -> bool:
@@ -263,7 +370,7 @@ class SessionManager:

    def get_session(self, session_id: str) -> Optional[Scanner]:
        """
        Get scanner instance for a session from Redis with session ID management.
        FIXED: Get scanner instance for a session from Redis with proper socketio restoration.
        """
        if not session_id:
            return None
@@ -281,6 +388,15 @@ class SessionManager:
        if scanner:
            # Ensure the scanner can check the Redis-based stop signal
            scanner.session_id = session_id

            # FIXED: Restore socketio connection from our registry
            socketio_conn = self.get_socketio_connection(session_id)
            if socketio_conn:
                scanner.socketio = socketio_conn
                print(f"✓ Restored socketio connection for session {session_id}")
            else:
                scanner.socketio = None
                print(f"⚠️ No socketio connection found for session {session_id}")

        return scanner

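Taken together, loading a session is a three-step restore: unpickle the payload, stamp the session ID back on for stop-signal checks, and re-attach the process-local connection. The shape of it, with assumed names (registry as sketched earlier):

import pickle

def load_scanner(redis_client, session_key: str, session_id: str, registry):
    """Sketch: unpickle a stored scanner and re-attach process-local state."""
    raw = redis_client.get(session_key)
    if raw is None:
        return None
    scanner = pickle.loads(raw)
    scanner.session_id = session_id                 # cross-process stop checks
    scanner.socketio = registry.lookup(session_id)  # None in a fresh process
    return scanner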
@@ -333,6 +449,12 @@ class SessionManager:
            # Wait a moment for graceful shutdown
            time.sleep(0.5)

            # FIXED: Clean up socketio connection
            with self.lock:
                if session_id in self.active_socketio_connections:
                    del self.active_socketio_connections[session_id]
                    print(f"Cleaned up socketio connection for session {session_id}")

            # Delete session data and stop signal from Redis
            session_key = self._get_session_key(session_id)
            stop_key = self._get_stop_signal_key(session_id)
@@ -344,6 +466,8 @@ class SessionManager:

        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _cleanup_loop(self) -> None:
@@ -364,6 +488,12 @@ class SessionManager:
                        self.redis_client.delete(stop_key)
                        print(f"Cleaned up orphaned stop signal for session {session_id}")

                        # Also clean up socketio connection
                        with self.lock:
                            if session_id in self.active_socketio_connections:
                                del self.active_socketio_connections[session_id]
                                print(f"Cleaned up orphaned socketio for session {session_id}")

            except Exception as e:
                print(f"Error in cleanup loop: {e}")

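Since session payloads expire via setex, a stop key or registry entry can outlive its session; the loop reconciles the two. The core orphan check, sketched (the session key format here is an assumption, mirroring the stop-key format shown earlier):

def find_orphaned_sessions(redis_client, known_session_ids):
    """Sketch: session IDs whose Redis payload has already expired."""
    orphans = []
    for session_id in list(known_session_ids):
        if not redis_client.exists(f"dnsrecon:session:{session_id}"):
            orphans.append(session_id)
    return orphans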
@@ -387,14 +517,16 @@ class SessionManager:
            return {
                'total_active_sessions': active_sessions,
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys)
                'total_stop_signals': len(stop_keys),
                'active_socketio_connections': len(self.active_socketio_connections)
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'running_scans': 0,
                'total_stop_signals': 0
                'total_stop_signals': 0,
                'active_socketio_connections': 0
            }

# Global session manager instance
