From c4e6a8998ae1ddcd2db8818fae92735cbd3854ff Mon Sep 17 00:00:00 2001
From: overcuriousity
Date: Sat, 20 Sep 2025 16:52:05 +0200
Subject: [PATCH] iteration on ws implementation

---
 app.py                            | 139 ++++++-
 core/graph_manager.py             |  55 ++-
 core/logger.py                    | 145 +++---
 core/scanner.py                   | 596 +++++++++++++++++++++---------
 core/session_manager.py           | 168 ++++++++-
 providers/base_provider.py        |  42 ++-
 providers/correlation_provider.py |  45 ++-
 providers/dns_provider.py         |  14 +-
 static/js/main.js                 | 310 ++++++++++++++--
 9 files changed, 1224 insertions(+), 290 deletions(-)

diff --git a/app.py b/app.py
index bc1f29d..893bb64 100644
--- a/app.py
+++ b/app.py
@@ -3,6 +3,7 @@
 """
 Flask application entry point for DNSRecon web interface.
 Provides REST API endpoints and serves the web interface with user session support.
+FIXED: Enhanced WebSocket integration with proper connection management.
 """

 import traceback
@@ -21,30 +22,38 @@ from decimal import Decimal

 app = Flask(__name__)
-socketio = SocketIO(app)
+socketio = SocketIO(app, cors_allowed_origins="*")
 app.config['SECRET_KEY'] = config.flask_secret_key
 app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours)


 def get_user_scanner():
     """
-    Retrieves the scanner for the current session, or creates a new one if none exists.
+    FIXED: Retrieves the scanner for the current session with proper socketio management.
     """
     current_flask_session_id = session.get('dnsrecon_session_id')

     if current_flask_session_id:
         existing_scanner = session_manager.get_session(current_flask_session_id)
         if existing_scanner:
+            # FIXED: Ensure socketio is properly maintained
+            existing_scanner.socketio = socketio
+            print(f"✓ Retrieved existing scanner for session {current_flask_session_id[:8]}... with socketio restored")
             return current_flask_session_id, existing_scanner

+    # FIXED: Register socketio connection when creating new session
     new_session_id = session_manager.create_session(socketio)
     new_scanner = session_manager.get_session(new_session_id)

     if not new_scanner:
         raise Exception("Failed to create new scanner session")

+    # FIXED: Ensure new scanner has socketio reference and register the connection
+    new_scanner.socketio = socketio
+    session_manager.register_socketio_connection(new_session_id, socketio)
+
     session['dnsrecon_session_id'] = new_session_id
     session.permanent = True

+    print(f"✓ Created new scanner for session {new_session_id[:8]}... with socketio registered")
     return new_session_id, new_scanner


@@ -57,7 +66,7 @@ def index():
 @app.route('/api/scan/start', methods=['POST'])
 def start_scan():
     """
-    Starts a new reconnaissance scan.
+    FIXED: Starts a new reconnaissance scan with proper socketio management.
""" try: data = request.get_json() @@ -81,9 +90,17 @@ def start_scan(): if not scanner: return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500 + # FIXED: Ensure scanner has socketio reference and is registered + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + print(f"๐Ÿš€ Starting scan for {target} with socketio enabled and registered") + success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target) if success: + # Update session with socketio-enabled scanner + session_manager.update_session_scanner(user_session_id, scanner) + return jsonify({ 'success': True, 'message': 'Reconnaissance scan started successfully', @@ -112,6 +129,10 @@ def stop_scan(): if not scanner.session_id: scanner.session_id = user_session_id + # FIXED: Ensure scanner has socketio reference + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + scanner.stop_scan() session_manager.set_stop_signal(user_session_id) session_manager.update_scanner_status(user_session_id, 'stopped') @@ -128,31 +149,83 @@ def stop_scan(): return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 +@socketio.on('connect') +def handle_connect(): + """ + FIXED: Handle WebSocket connection with proper session management. + """ + print(f'โœ“ WebSocket client connected: {request.sid}') + + # Try to restore existing session connection + current_flask_session_id = session.get('dnsrecon_session_id') + if current_flask_session_id: + # Register this socketio connection for the existing session + session_manager.register_socketio_connection(current_flask_session_id, socketio) + print(f'โœ“ Registered WebSocket for existing session: {current_flask_session_id[:8]}...') + + # Immediately send current status to new connection + get_scan_status() + + +@socketio.on('disconnect') +def handle_disconnect(): + """ + FIXED: Handle WebSocket disconnection gracefully. + """ + print(f'โœ— WebSocket client disconnected: {request.sid}') + + # Note: We don't immediately remove the socketio connection from session_manager + # because the user might reconnect. The cleanup will happen during session cleanup. + + @socketio.on('get_status') def get_scan_status(): - """Get current scan status.""" + """ + FIXED: Get current scan status and emit real-time update with proper error handling. 
+ """ try: user_session_id, scanner = get_user_scanner() if not scanner: status = { - 'status': 'idle', 'target_domain': None, 'current_depth': 0, - 'max_depth': 0, 'progress_percentage': 0.0, - 'user_session_id': user_session_id + 'status': 'idle', + 'target_domain': None, + 'current_depth': 0, + 'max_depth': 0, + 'progress_percentage': 0.0, + 'user_session_id': user_session_id, + 'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}} } + print(f"๐Ÿ“ก Emitting idle status for session {user_session_id[:8] if user_session_id else 'none'}...") else: if not scanner.session_id: scanner.session_id = user_session_id + + # FIXED: Ensure scanner has socketio reference for future updates + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + status = scanner.get_scan_status() status['user_session_id'] = user_session_id + + print(f"๐Ÿ“ก Emitting status update: {status['status']} - " + f"Nodes: {len(status.get('graph', {}).get('nodes', []))}, " + f"Edges: {len(status.get('graph', {}).get('edges', []))}") + + # Update session with socketio-enabled scanner + session_manager.update_session_scanner(user_session_id, scanner) socketio.emit('scan_update', status) except Exception as e: traceback.print_exc() - socketio.emit('scan_update', { - 'status': 'error', 'message': 'Failed to get status' - }) + error_status = { + 'status': 'error', + 'message': 'Failed to get status', + 'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}} + } + print(f"โš ๏ธ Error getting status, emitting error status") + socketio.emit('scan_update', error_status) @app.route('/api/graph', methods=['GET']) @@ -169,6 +242,10 @@ def get_graph_data(): if not scanner: return jsonify({'success': True, 'graph': empty_graph, 'user_session_id': user_session_id}) + # FIXED: Ensure scanner has socketio reference + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + graph_data = scanner.get_graph_data() or empty_graph return jsonify({'success': True, 'graph': graph_data, 'user_session_id': user_session_id}) @@ -195,6 +272,10 @@ def extract_from_large_entity(): if not scanner: return jsonify({'success': False, 'error': 'No active session found'}), 404 + # FIXED: Ensure scanner has socketio reference + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + success = scanner.extract_node_from_large_entity(large_entity_id, node_id) if success: @@ -215,6 +296,10 @@ def delete_graph_node(node_id): if not scanner: return jsonify({'success': False, 'error': 'No active session found'}), 404 + # FIXED: Ensure scanner has socketio reference + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + success = scanner.graph.remove_node(node_id) if success: @@ -240,6 +325,10 @@ def revert_graph_action(): if not scanner: return jsonify({'success': False, 'error': 'No active session found'}), 404 + # FIXED: Ensure scanner has socketio reference + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + action_type = data['type'] action_data = data['data'] @@ -284,6 +373,10 @@ def export_results(): if not scanner: return jsonify({'success': False, 'error': 'No active scanner session found'}), 404 + # FIXED: Ensure scanner has socketio reference + scanner.socketio = socketio + session_manager.register_socketio_connection(user_session_id, socketio) + # Get 
        try:
            results = export_manager.export_scan_results(scanner)
@@ -335,6 +428,10 @@ def export_targets():
         if not scanner:
             return jsonify({'success': False, 'error': 'No active scanner session found'}), 404

+        # FIXED: Ensure scanner has socketio reference
+        scanner.socketio = socketio
+        session_manager.register_socketio_connection(user_session_id, socketio)
+
         # Use export manager for targets export
         targets_txt = export_manager.export_targets_list(scanner)

@@ -365,6 +462,10 @@ def export_summary():
         if not scanner:
             return jsonify({'success': False, 'error': 'No active scanner session found'}), 404

+        # FIXED: Ensure scanner has socketio reference
+        scanner.socketio = socketio
+        session_manager.register_socketio_connection(user_session_id, socketio)
+
         # Use export manager for summary generation
         summary_txt = export_manager.generate_executive_summary(scanner)

@@ -397,6 +498,10 @@ def set_api_keys():
         user_session_id, scanner = get_user_scanner()
         session_config = scanner.config

+        # FIXED: Ensure scanner has socketio reference
+        scanner.socketio = socketio
+        session_manager.register_socketio_connection(user_session_id, socketio)
+
         updated_providers = []

         for provider_name, api_key in data.items():
@@ -429,6 +534,10 @@ def get_providers():
         user_session_id, scanner = get_user_scanner()
         base_provider_info = scanner.get_provider_info()

+        # FIXED: Ensure scanner has socketio reference
+        scanner.socketio = socketio
+        session_manager.register_socketio_connection(user_session_id, socketio)
+
         # Enhance provider info with API key source information
         enhanced_provider_info = {}

@@ -493,6 +602,10 @@ def configure_providers():
         user_session_id, scanner = get_user_scanner()
         session_config = scanner.config

+        # FIXED: Ensure scanner has socketio reference
+        scanner.socketio = socketio
+        session_manager.register_socketio_connection(user_session_id, socketio)
+
         updated_providers = []

         for provider_name, settings in data.items():
@@ -521,7 +634,6 @@ def configure_providers():
         return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500

-
 @app.errorhandler(404)
 def not_found(error):
     """Handle 404 errors."""
@@ -537,4 +649,9 @@ def internal_error(error):

 if __name__ == '__main__':
     config.load_from_env()
+    print("🚀 Starting DNSRecon with enhanced WebSocket support...")
+    print(f"   Host: {config.flask_host}")
+    print(f"   Port: {config.flask_port}")
+    print(f"   Debug: {config.flask_debug}")
+    print("   WebSocket: Enhanced connection management enabled")
     socketio.run(app, host=config.flask_host, port=config.flask_port, debug=config.flask_debug)
\ No newline at end of file
diff --git a/core/graph_manager.py b/core/graph_manager.py
index 4abe919..23648eb 100644
--- a/core/graph_manager.py
+++ b/core/graph_manager.py
@@ -4,8 +4,7 @@ Graph data model for DNSRecon using NetworkX.

 Manages in-memory graph storage with confidence scoring and forensic metadata.
 Now fully compatible with the unified ProviderResult data model.
-UPDATED: Fixed correlation exclusion keys to match actual attribute names.
-UPDATED: Removed export_json() method - now handled by ExportManager.
+FIXED: Added proper pickle support to prevent weakref serialization errors.
 """
 import re
 from datetime import datetime, timezone
@@ -33,6 +32,7 @@ class GraphManager:
     Thread-safe graph manager for DNSRecon infrastructure mapping.
     Uses NetworkX for in-memory graph storage with confidence scoring.
     Compatible with unified ProviderResult data model.
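+    Nodes and edges are flattened to plain dicts before pickling and the
+    DiGraph is rebuilt on load (see __getstate__/__setstate__ below); a
+    minimal round-trip sketch (illustrative):
+
+        import pickle
+        gm = GraphManager()
+        clone = pickle.loads(pickle.dumps(gm))  # graph survives the round trip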
+ FIXED: Added proper pickle support to handle NetworkX graph serialization. """ def __init__(self): @@ -40,6 +40,57 @@ class GraphManager: self.graph = nx.DiGraph() self.creation_time = datetime.now(timezone.utc).isoformat() self.last_modified = self.creation_time + + def __getstate__(self): + """Prepare GraphManager for pickling by converting NetworkX graph to serializable format.""" + state = self.__dict__.copy() + + # Convert NetworkX graph to a serializable format + if hasattr(self, 'graph') and self.graph: + # Extract all nodes with their data + nodes_data = {} + for node_id, attrs in self.graph.nodes(data=True): + nodes_data[node_id] = dict(attrs) + + # Extract all edges with their data + edges_data = [] + for source, target, attrs in self.graph.edges(data=True): + edges_data.append({ + 'source': source, + 'target': target, + 'attributes': dict(attrs) + }) + + # Replace the NetworkX graph with serializable data + state['_graph_nodes'] = nodes_data + state['_graph_edges'] = edges_data + del state['graph'] + + return state + + def __setstate__(self, state): + """Restore GraphManager after unpickling by reconstructing NetworkX graph.""" + # Restore basic attributes + self.__dict__.update(state) + + # Reconstruct NetworkX graph from serializable data + self.graph = nx.DiGraph() + + # Restore nodes + if hasattr(self, '_graph_nodes'): + for node_id, attrs in self._graph_nodes.items(): + self.graph.add_node(node_id, **attrs) + del self._graph_nodes + + # Restore edges + if hasattr(self, '_graph_edges'): + for edge_data in self._graph_edges: + self.graph.add_edge( + edge_data['source'], + edge_data['target'], + **edge_data['attributes'] + ) + del self._graph_edges def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None, description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool: diff --git a/core/logger.py b/core/logger.py index e774a2d..e65b22c 100644 --- a/core/logger.py +++ b/core/logger.py @@ -40,6 +40,7 @@ class ForensicLogger: """ Thread-safe forensic logging system for DNSRecon. Maintains detailed audit trail of all reconnaissance activities. + FIXED: Enhanced pickle support to prevent weakref issues in logging handlers. """ def __init__(self, session_id: str = ""): @@ -65,45 +66,74 @@ class ForensicLogger: 'target_domains': set() } - # Configure standard logger + # Configure standard logger with simple setup to avoid weakrefs self.logger = logging.getLogger(f'dnsrecon.{self.session_id}') self.logger.setLevel(logging.INFO) - # Create formatter for structured logging + # Create minimal formatter formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) - # Add console handler if not already present + # Add console handler only if not already present (avoid duplicate handlers) if not self.logger.handlers: console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) self.logger.addHandler(console_handler) def __getstate__(self): - """Prepare ForensicLogger for pickling by excluding unpicklable objects.""" + """ + FIXED: Prepare ForensicLogger for pickling by excluding problematic objects. 
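+
+        A minimal round-trip sketch of the intended behaviour (illustrative,
+        not a doctest):
+
+            import pickle
+            logger = ForensicLogger(session_id="example")
+            clone = pickle.loads(pickle.dumps(logger))
+            # __setstate__ rebuilds the lock and the logging handlers
+            assert clone.session_id == logger.session_id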
+ """ state = self.__dict__.copy() - # Remove the unpickleable 'logger' attribute - if 'logger' in state: - del state['logger'] - if 'lock' in state: - del state['lock'] + + # Remove potentially unpickleable attributes that may contain weakrefs + unpicklable_attrs = ['logger', 'lock'] + for attr in unpicklable_attrs: + if attr in state: + del state[attr] + + # Convert sets to lists for JSON serialization compatibility + if 'session_metadata' in state: + metadata = state['session_metadata'].copy() + if 'providers_used' in metadata and isinstance(metadata['providers_used'], set): + metadata['providers_used'] = list(metadata['providers_used']) + if 'target_domains' in metadata and isinstance(metadata['target_domains'], set): + metadata['target_domains'] = list(metadata['target_domains']) + state['session_metadata'] = metadata + return state def __setstate__(self, state): - """Restore ForensicLogger after unpickling by reconstructing logger.""" + """ + FIXED: Restore ForensicLogger after unpickling by reconstructing components. + """ self.__dict__.update(state) - # Re-initialize the 'logger' attribute + + # Re-initialize threading lock + self.lock = threading.Lock() + + # Re-initialize logger with minimal setup self.logger = logging.getLogger(f'dnsrecon.{self.session_id}') self.logger.setLevel(logging.INFO) + formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) + + # Only add handler if not already present if not self.logger.handlers: console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) self.logger.addHandler(console_handler) - self.lock = threading.Lock() + + # Convert lists back to sets if needed + if 'session_metadata' in self.__dict__: + metadata = self.session_metadata + if 'providers_used' in metadata and isinstance(metadata['providers_used'], list): + metadata['providers_used'] = set(metadata['providers_used']) + if 'target_domains' in metadata and isinstance(metadata['target_domains'], list): + metadata['target_domains'] = set(metadata['target_domains']) def _generate_session_id(self) -> str: """Generate unique session identifier.""" @@ -143,18 +173,23 @@ class ForensicLogger: discovery_context=discovery_context ) - self.api_requests.append(api_request) - self.session_metadata['total_requests'] += 1 - self.session_metadata['providers_used'].add(provider) + with self.lock: + self.api_requests.append(api_request) + self.session_metadata['total_requests'] += 1 + self.session_metadata['providers_used'].add(provider) + + if target_indicator: + self.session_metadata['target_domains'].add(target_indicator) - if target_indicator: - self.session_metadata['target_domains'].add(target_indicator) - - # Log to standard logger - if error: - self.logger.error(f"API Request Failed.") - else: - self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}") + # Log to standard logger with error handling + try: + if error: + self.logger.error(f"API Request Failed - {provider}: {url}") + else: + self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}") + except Exception: + # If logging fails, continue without breaking the application + pass def log_relationship_discovery(self, source_node: str, target_node: str, relationship_type: str, confidence_score: float, @@ -183,29 +218,44 @@ class ForensicLogger: discovery_method=discovery_method ) - self.relationships.append(relationship) - self.session_metadata['total_relationships'] += 1 + with self.lock: + self.relationships.append(relationship) + 
self.session_metadata['total_relationships'] += 1 - self.logger.info( - f"Relationship Discovered - {source_node} -> {target_node} " - f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}" - ) + # Log to standard logger with error handling + try: + self.logger.info( + f"Relationship Discovered - {source_node} -> {target_node} " + f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}" + ) + except Exception: + # If logging fails, continue without breaking the application + pass def log_scan_start(self, target_domain: str, recursion_depth: int, enabled_providers: List[str]) -> None: """Log the start of a reconnaissance scan.""" - self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}") - self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}") - - self.session_metadata['target_domains'].update(target_domain) + try: + self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}") + self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}") + + with self.lock: + self.session_metadata['target_domains'].add(target_domain) + except Exception: + pass def log_scan_complete(self) -> None: """Log the completion of a reconnaissance scan.""" - self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat() - self.session_metadata['providers_used'] = list(self.session_metadata['providers_used']) - self.session_metadata['target_domains'] = list(self.session_metadata['target_domains']) + with self.lock: + self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat() + # Convert sets to lists for serialization + self.session_metadata['providers_used'] = list(self.session_metadata['providers_used']) + self.session_metadata['target_domains'] = list(self.session_metadata['target_domains']) - self.logger.info(f"Scan Complete - Session: {self.session_id}") + try: + self.logger.info(f"Scan Complete - Session: {self.session_id}") + except Exception: + pass def export_audit_trail(self) -> Dict[str, Any]: """ @@ -214,12 +264,13 @@ class ForensicLogger: Returns: Dictionary containing complete session audit trail """ - return { - 'session_metadata': self.session_metadata.copy(), - 'api_requests': [asdict(req) for req in self.api_requests], - 'relationships': [asdict(rel) for rel in self.relationships], - 'export_timestamp': datetime.now(timezone.utc).isoformat() - } + with self.lock: + return { + 'session_metadata': self.session_metadata.copy(), + 'api_requests': [asdict(req) for req in self.api_requests], + 'relationships': [asdict(rel) for rel in self.relationships], + 'export_timestamp': datetime.now(timezone.utc).isoformat() + } def get_forensic_summary(self) -> Dict[str, Any]: """ @@ -229,7 +280,13 @@ class ForensicLogger: Dictionary containing summary statistics """ provider_stats = {} - for provider in self.session_metadata['providers_used']: + + # Ensure providers_used is a set for iteration + providers_used = self.session_metadata['providers_used'] + if isinstance(providers_used, list): + providers_used = set(providers_used) + + for provider in providers_used: provider_requests = [req for req in self.api_requests if req.provider == provider] provider_relationships = [rel for rel in self.relationships if rel.provider == provider] diff --git a/core/scanner.py b/core/scanner.py index 01a1889..f2a44f9 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -35,6 +35,7 @@ class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. 
UNIFIED: Combines comprehensive features with improved display formatting. + FIXED: Enhanced threading object initialization to prevent None references. """ def __init__(self, session_config=None, socketio=None): @@ -44,6 +45,11 @@ class Scanner: if session_config is None: from core.session_config import create_session_config session_config = create_session_config() + + # FIXED: Initialize all threading objects first + self._initialize_threading_objects() + + # Set socketio (but will be set to None for storage) self.socketio = socketio self.config = session_config @@ -53,17 +59,12 @@ class Scanner: self.current_target = None self.current_depth = 0 self.max_depth = 2 - self.stop_event = threading.Event() self.scan_thread = None self.session_id: Optional[str] = None # Will be set by session manager - self.task_queue = PriorityQueue() - self.target_retries = defaultdict(int) - self.scan_failed_due_to_retries = False self.initial_targets = set() # Thread-safe processing tracking (from Document 1) self.currently_processing = set() - self.processing_lock = threading.Lock() # Display-friendly processing list (from Document 2) self.currently_processing_display = [] @@ -81,9 +82,10 @@ class Scanner: self.max_workers = self.config.max_concurrent_requests self.executor = None - # Status logger thread with improved formatting - self.status_logger_thread = None - self.status_logger_stop_event = threading.Event() + # Initialize collections that will be recreated during unpickling + self.task_queue = PriorityQueue() + self.target_retries = defaultdict(int) + self.scan_failed_due_to_retries = False # Initialize providers with session config self._initialize_providers() @@ -99,12 +101,24 @@ class Scanner: traceback.print_exc() raise + def _initialize_threading_objects(self): + """ + FIXED: Initialize all threading objects with proper error handling. + This method can be called during both __init__ and __setstate__. + """ + self.stop_event = threading.Event() + self.processing_lock = threading.Lock() + self.status_logger_stop_event = threading.Event() + self.status_logger_thread = None + def _is_stop_requested(self) -> bool: """ Check if stop is requested using both local and Redis-based signals. This ensures reliable termination across process boundaries. + FIXED: Added None check for stop_event. """ - if self.stop_event.is_set(): + # FIXED: Ensure stop_event exists before checking + if hasattr(self, 'stop_event') and self.stop_event and self.stop_event.is_set(): return True if self.session_id: @@ -112,16 +126,24 @@ class Scanner: from core.session_manager import session_manager return session_manager.is_stop_requested(self.session_id) except Exception as e: - # Fall back to local event - return self.stop_event.is_set() + # Fall back to local event if it exists + if hasattr(self, 'stop_event') and self.stop_event: + return self.stop_event.is_set() + return False - return self.stop_event.is_set() + # Final fallback + if hasattr(self, 'stop_event') and self.stop_event: + return self.stop_event.is_set() + return False def _set_stop_signal(self) -> None: """ Set stop signal both locally and in Redis. + FIXED: Added None check for stop_event. 
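+
+        The cross-process half of the signal is a Redis flag; a condensed
+        sketch of the convention (key format taken from SessionManager,
+        the stored value assumed to be a simple truthy marker):
+
+            redis.StrictRedis(db=0).set(f"dnsrecon:stop:{session_id}", b"1")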
""" - self.stop_event.set() + # FIXED: Ensure stop_event exists before setting + if hasattr(self, 'stop_event') and self.stop_event: + self.stop_event.set() if self.session_id: try: @@ -162,17 +184,21 @@ class Scanner: """Restore object after unpickling by reconstructing threading objects.""" self.__dict__.update(state) - self.stop_event = threading.Event() + # FIXED: Ensure all threading objects are properly initialized + self._initialize_threading_objects() + + # Re-initialize other objects self.scan_thread = None self.executor = None - self.processing_lock = threading.Lock() self.task_queue = PriorityQueue() self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) self.logger = get_forensic_logger() - self.status_logger_thread = None - self.status_logger_stop_event = threading.Event() - self.socketio = None + # FIXED: Initialize socketio as None but preserve ability to set it + if not hasattr(self, 'socketio'): + self.socketio = None + + # Initialize missing attributes with defaults if not hasattr(self, 'providers') or not self.providers: self._initialize_providers() @@ -182,11 +208,36 @@ class Scanner: if not hasattr(self, 'currently_processing_display'): self.currently_processing_display = [] + if not hasattr(self, 'target_retries'): + self.target_retries = defaultdict(int) + + if not hasattr(self, 'scan_failed_due_to_retries'): + self.scan_failed_due_to_retries = False + + if not hasattr(self, 'initial_targets'): + self.initial_targets = set() + + # Ensure providers have stop events if hasattr(self, 'providers'): for provider in self.providers: - if hasattr(provider, 'set_stop_event'): + if hasattr(provider, 'set_stop_event') and self.stop_event: provider.set_stop_event(self.stop_event) + def _ensure_threading_objects_exist(self): + """ + FIXED: Utility method to ensure threading objects exist before use. + Call this before any method that might use threading objects. 
+ """ + if not hasattr(self, 'stop_event') or self.stop_event is None: + print("WARNING: Threading objects not initialized, recreating...") + self._initialize_threading_objects() + + if not hasattr(self, 'processing_lock') or self.processing_lock is None: + self.processing_lock = threading.Lock() + + if not hasattr(self, 'task_queue') or self.task_queue is None: + self.task_queue = PriorityQueue() + def _initialize_providers(self) -> None: """Initialize all available providers based on session configuration.""" self.providers = [] @@ -224,7 +275,9 @@ class Scanner: print(f" Available: {is_available}") if is_available: - provider.set_stop_event(self.stop_event) + # FIXED: Ensure stop_event exists before setting it + if hasattr(self, 'stop_event') and self.stop_event: + provider.set_stop_event(self.stop_event) if isinstance(provider, CorrelationProvider): provider.set_graph_manager(self.graph) self.providers.append(provider) @@ -254,15 +307,25 @@ class Scanner: BOLD = "\033[1m" last_status_str = "" - while not self.status_logger_stop_event.is_set(): + + # FIXED: Ensure threading objects exist + self._ensure_threading_objects_exist() + + while not (hasattr(self, 'status_logger_stop_event') and + self.status_logger_stop_event and + self.status_logger_stop_event.is_set()): try: - with self.processing_lock: - in_flight_tasks = list(self.currently_processing) - self.currently_processing_display = in_flight_tasks.copy() + # FIXED: Check if processing_lock exists before using + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + in_flight_tasks = list(self.currently_processing) + self.currently_processing_display = in_flight_tasks.copy() + else: + in_flight_tasks = list(getattr(self, 'currently_processing', [])) status_str = ( f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | " - f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | " + f"{CYAN}Queued: {self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0}{ENDC} | " f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | " f"{GREEN}Completed: {self.indicators_completed}{ENDC} | " f"Skipped: {self.tasks_skipped} | " @@ -290,22 +353,30 @@ class Scanner: time.sleep(2) def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool: + """ + FIXED: Enhanced start_scan with proper threading object initialization and socketio management. 
+ """ + # FIXED: Ensure threading objects exist before proceeding + self._ensure_threading_objects_exist() + if self.scan_thread and self.scan_thread.is_alive(): self.logger.logger.info("Stopping existing scan before starting new one") self._set_stop_signal() self.status = ScanStatus.STOPPED # Clean up processing state - with self.processing_lock: - self.currently_processing.clear() + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + self.currently_processing.clear() self.currently_processing_display = [] # Clear task queue - while not self.task_queue.empty(): - try: - self.task_queue.get_nowait() - except: - break + if hasattr(self, 'task_queue') and self.task_queue: + while not self.task_queue.empty(): + try: + self.task_queue.get_nowait() + except: + break # Shutdown executor if self.executor: @@ -322,14 +393,26 @@ class Scanner: self.logger.logger.warning("Previous scan thread did not terminate cleanly") self.status = ScanStatus.IDLE - self.stop_event.clear() + + # FIXED: Ensure stop_event exists before clearing + if hasattr(self, 'stop_event') and self.stop_event: + self.stop_event.clear() if self.session_id: from core.session_manager import session_manager session_manager.clear_stop_signal(self.session_id) + + # FIXED: Restore socketio connection if missing + if not hasattr(self, 'socketio') or not self.socketio: + registered_socketio = session_manager.get_socketio_connection(self.session_id) + if registered_socketio: + self.socketio = registered_socketio + print(f"โœ“ Restored socketio connection for scan start") - with self.processing_lock: - self.currently_processing.clear() + # FIXED: Safe cleanup with existence checks + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + self.currently_processing.clear() self.currently_processing_display = [] self.task_queue = PriorityQueue() @@ -397,7 +480,10 @@ class Scanner: ) self.scan_thread.start() - self.status_logger_stop_event.clear() + # FIXED: Ensure status_logger_stop_event exists before clearing + if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event: + self.status_logger_stop_event.clear() + self.status_logger_thread = threading.Thread( target=self._status_logger_thread, daemon=True, @@ -451,6 +537,13 @@ class Scanner: return 10 # Very low rate limit = very low priority def _execute_scan(self, target: str, max_depth: int) -> None: + """ + FIXED: Enhanced execute_scan with proper threading object handling. 
+ """ + # FIXED: Ensure threading objects exist + self._ensure_threading_objects_exist() + update_counter = 0 # Track updates for throttling + last_update_time = time.time() self.executor = ThreadPoolExecutor(max_workers=self.max_workers) processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping @@ -482,8 +575,13 @@ class Scanner: print(f"\n=== PHASE 1: Running non-correlation providers ===") while not self._is_stop_requested(): queue_empty = self.task_queue.empty() - with self.processing_lock: - no_active_processing = len(self.currently_processing) == 0 + + # FIXED: Safe processing lock usage + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + no_active_processing = len(self.currently_processing) == 0 + else: + no_active_processing = len(getattr(self, 'currently_processing', [])) == 0 if queue_empty and no_active_processing: consecutive_empty_iterations += 1 @@ -536,10 +634,23 @@ class Scanner: continue # Thread-safe processing state management - with self.processing_lock: + processing_key = (provider_name, target_item) + + # FIXED: Safe processing lock usage + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + if self._is_stop_requested(): + break + if processing_key in self.currently_processing: + self.tasks_skipped += 1 + self.indicators_completed += 1 + continue + self.currently_processing.add(processing_key) + else: if self._is_stop_requested(): break - processing_key = (provider_name, target_item) + if not hasattr(self, 'currently_processing'): + self.currently_processing = set() if processing_key in self.currently_processing: self.tasks_skipped += 1 self.indicators_completed += 1 @@ -558,7 +669,12 @@ class Scanner: if provider and not isinstance(provider, CorrelationProvider): new_targets, _, success = self._process_provider_task(provider, target_item, depth) - + update_counter += 1 + current_time = time.time() + if (update_counter % 5 == 0) or (current_time - last_update_time > 3.0): + self._update_session_state() + last_update_time = current_time + update_counter = 0 if self._is_stop_requested(): break @@ -603,9 +719,13 @@ class Scanner: self.indicators_completed += 1 finally: - with self.processing_lock: - processing_key = (provider_name, target_item) - self.currently_processing.discard(processing_key) + # FIXED: Safe processing lock usage for cleanup + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + self.currently_processing.discard(processing_key) + else: + if hasattr(self, 'currently_processing'): + self.currently_processing.discard(processing_key) # PHASE 2: Run correlations on all discovered nodes if not self._is_stop_requested(): @@ -618,8 +738,9 @@ class Scanner: self.logger.logger.error(f"Scan failed: {e}") finally: # Comprehensive cleanup (same as before) - with self.processing_lock: - self.currently_processing.clear() + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + self.currently_processing.clear() self.currently_processing_display = [] while not self.task_queue.empty(): @@ -635,7 +756,9 @@ class Scanner: else: self.status = ScanStatus.COMPLETED - self.status_logger_stop_event.set() + # FIXED: Safe stop event handling + if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event: + self.status_logger_stop_event.set() if self.status_logger_thread and self.status_logger_thread.is_alive(): self.status_logger_thread.join(timeout=2.0) @@ -689,8 +812,13 @@ 
class Scanner: while not self._is_stop_requested() and correlation_tasks: queue_empty = self.task_queue.empty() - with self.processing_lock: - no_active_processing = len(self.currently_processing) == 0 + + # FIXED: Safe processing check + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + no_active_processing = len(self.currently_processing) == 0 + else: + no_active_processing = len(getattr(self, 'currently_processing', [])) == 0 if queue_empty and no_active_processing: consecutive_empty_iterations += 1 @@ -722,10 +850,23 @@ class Scanner: correlation_tasks.remove(task_tuple) continue - with self.processing_lock: + processing_key = (provider_name, target_item) + + # FIXED: Safe processing management + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + if self._is_stop_requested(): + break + if processing_key in self.currently_processing: + self.tasks_skipped += 1 + self.indicators_completed += 1 + continue + self.currently_processing.add(processing_key) + else: if self._is_stop_requested(): break - processing_key = (provider_name, target_item) + if not hasattr(self, 'currently_processing'): + self.currently_processing = set() if processing_key in self.currently_processing: self.tasks_skipped += 1 self.indicators_completed += 1 @@ -754,51 +895,214 @@ class Scanner: correlation_tasks.remove(task_tuple) finally: - with self.processing_lock: - processing_key = (provider_name, target_item) - self.currently_processing.discard(processing_key) + # FIXED: Safe cleanup + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + self.currently_processing.discard(processing_key) + else: + if hasattr(self, 'currently_processing'): + self.currently_processing.discard(processing_key) print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}") + # Rest of the methods remain the same but with similar threading object safety checks... 
+    # The key methods that needed these fixes follow:
+
+    def stop_scan(self) -> bool:
+        """Request immediate scan termination with proper cleanup."""
+        try:
+            self.logger.logger.info("Scan termination requested by user")
+            self._set_stop_signal()
+            self.status = ScanStatus.STOPPED
+
+            # FIXED: Safe cleanup
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    self.currently_processing.clear()
+            self.currently_processing_display = []
+
+            self.task_queue = PriorityQueue()
+
+            if self.executor:
+                try:
+                    self.executor.shutdown(wait=False, cancel_futures=True)
+                except Exception:
+                    pass
+
+            self._update_session_state()
+            return True
+
+        except Exception as e:
+            self.logger.logger.error(f"Error during scan termination: {e}")
+            traceback.print_exc()
+            return False
+
+    def get_scan_status(self) -> Dict[str, Any]:
+        """Get current scan status with comprehensive graph data for real-time updates."""
+        try:
+            # FIXED: Safe processing state access
+            if hasattr(self, 'processing_lock') and self.processing_lock:
+                with self.processing_lock:
+                    currently_processing_count = len(self.currently_processing)
+                    currently_processing_list = list(self.currently_processing)
+            else:
+                currently_processing_count = len(getattr(self, 'currently_processing', []))
+                currently_processing_list = list(getattr(self, 'currently_processing', []))
+
+            # FIXED: Always include complete graph data for real-time updates
+            graph_data = self.get_graph_data()
+
+            return {
+                'status': self.status,
+                'target_domain': self.current_target,
+                'current_depth': self.current_depth,
+                'max_depth': self.max_depth,
+                'current_indicator': self.current_indicator,
+                'indicators_processed': self.indicators_processed,
+                'indicators_completed': self.indicators_completed,
+                'tasks_re_enqueued': self.tasks_re_enqueued,
+                'progress_percentage': self._calculate_progress(),
+                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
+                'enabled_providers': [provider.get_name() for provider in self.providers],
+                'graph': graph_data,  # FIXED: Always include complete graph data
+                'task_queue_size': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
+                'currently_processing_count': currently_processing_count,
+                'currently_processing': currently_processing_list[:5],
+                'tasks_in_queue': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
+                'tasks_completed': self.indicators_completed,
+                'tasks_skipped': self.tasks_skipped,
+                'tasks_rescheduled': self.tasks_re_enqueued,
+            }
+        except Exception:
+            traceback.print_exc()
+            return {
+                'status': 'error',
+                'message': 'Failed to get status',
+                'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
+            }
+
+    def _update_session_state(self) -> None:
+        """
+        FIXED: Update the scanner state in Redis and emit real-time WebSocket updates.
+        Enhanced with better error handling and socketio management.
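+
+        The emission fallback chain, condensed (names from the body below):
+
+            sock = self.socketio or session_manager.get_socketio_connection(self.session_id)
+            if sock:
+                sock.emit('scan_update', self.get_scan_status())
+            # Redis persistence via update_session_scanner() runs either way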
+ """ + if self.session_id: + try: + # Get current status for WebSocket emission + current_status = self.get_scan_status() + + # FIXED: Emit real-time update via WebSocket with better error handling + socketio_available = False + if hasattr(self, 'socketio') and self.socketio: + try: + print(f"๐Ÿ“ก EMITTING WebSocket update: {current_status.get('status', 'unknown')} - " + f"Nodes: {len(current_status.get('graph', {}).get('nodes', []))}, " + f"Edges: {len(current_status.get('graph', {}).get('edges', []))}") + + self.socketio.emit('scan_update', current_status) + print("โœ… WebSocket update emitted successfully") + socketio_available = True + + except Exception as ws_error: + print(f"โš ๏ธ WebSocket emission failed: {ws_error}") + # Try to get socketio from session manager + try: + from core.session_manager import session_manager + registered_socketio = session_manager.get_socketio_connection(self.session_id) + if registered_socketio: + print("๐Ÿ”„ Attempting to use registered socketio connection...") + registered_socketio.emit('scan_update', current_status) + self.socketio = registered_socketio # Update our reference + print("โœ… WebSocket update emitted via registered connection") + socketio_available = True + else: + print("โš ๏ธ No registered socketio connection found") + except Exception as fallback_error: + print(f"โš ๏ธ Fallback socketio emission also failed: {fallback_error}") + else: + # Try to restore socketio from session manager + try: + from core.session_manager import session_manager + registered_socketio = session_manager.get_socketio_connection(self.session_id) + if registered_socketio: + print(f"๐Ÿ”„ Restoring socketio connection for session {self.session_id}") + self.socketio = registered_socketio + self.socketio.emit('scan_update', current_status) + print("โœ… WebSocket update emitted via restored connection") + socketio_available = True + else: + print(f"โš ๏ธ No socketio connection available for session {self.session_id}") + except Exception as restore_error: + print(f"โš ๏ธ Failed to restore socketio connection: {restore_error}") + + if not socketio_available: + print(f"โš ๏ธ Real-time updates unavailable for session {self.session_id}") + + # Update session in Redis for persistence (always do this) + try: + from core.session_manager import session_manager + session_manager.update_session_scanner(self.session_id, self) + except Exception as redis_error: + print(f"โš ๏ธ Failed to update session in Redis: {redis_error}") + + except Exception as e: + print(f"โš ๏ธ Error updating session state: {e}") + import traceback + traceback.print_exc() + + def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: """ - Manages the entire process for a given target and provider. - This version is generalized to handle all relationships dynamically. + FIXED: Manages the entire process for a given target and provider with enhanced real-time updates. 
""" if self._is_stop_requested(): return set(), set(), False - + is_ip = _is_valid_ip(target) target_type = NodeType.IP if is_ip else NodeType.DOMAIN - + self.graph.add_node(target, target_type) self._initialize_provider_states(target) - + new_targets = set() provider_successful = True - + try: provider_result = self._execute_provider_query(provider, target, is_ip) - + if provider_result is None: provider_successful = False elif not self._is_stop_requested(): - # Pass all relationships to be processed discovered, is_large_entity = self._process_provider_result_unified( target, provider, provider_result, depth ) new_targets.update(discovered) - + + # FIXED: Emit real-time update after processing provider result + if discovered or provider_result.get_relationship_count() > 0: + # Ensure we have socketio connection for real-time updates + if self.session_id and (not hasattr(self, 'socketio') or not self.socketio): + try: + from core.session_manager import session_manager + registered_socketio = session_manager.get_socketio_connection(self.session_id) + if registered_socketio: + self.socketio = registered_socketio + print(f"๐Ÿ”„ Restored socketio connection during provider processing") + except Exception as restore_error: + print(f"โš ๏ธ Failed to restore socketio during provider processing: {restore_error}") + + self._update_session_state() + except Exception as e: provider_successful = False self._log_provider_error(target, provider.get_name(), str(e)) - + return new_targets, set(), provider_successful + + def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]: - """ - The "worker" function that directly communicates with the provider to fetch data. - """ + """The "worker" function that directly communicates with the provider to fetch data.""" provider_name = provider.get_name() start_time = datetime.now(timezone.utc) @@ -825,9 +1129,7 @@ class Scanner: def _create_large_entity_from_result(self, source_node: str, provider_name: str, provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]: - """ - Creates a large entity node, tags all member nodes, and returns its ID and members. - """ + """Creates a large entity node, tags all member nodes, and returns its ID and members.""" members = {rel.target_node for rel in provider_result.relationships if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)} @@ -860,7 +1162,7 @@ class Scanner: def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool: """ - Removes a node from a large entity, allowing it to be processed normally. + FIXED: Removes a node from a large entity with immediate real-time update. 
""" if not self.graph.graph.has_node(node_id): return False @@ -879,7 +1181,6 @@ class Scanner: for provider in eligible_providers: provider_name = provider.get_name() priority = self._get_priority(provider_name) - # Use current depth of the large entity if available, else 0 depth = 0 if self.graph.graph.has_node(large_entity_id): le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', []) @@ -890,6 +1191,19 @@ class Scanner: self.task_queue.put((time.time(), priority, (provider_name, node_id, depth))) self.total_tasks_ever_enqueued += 1 + # FIXED: Emit real-time update after extraction with socketio management + if self.session_id and (not hasattr(self, 'socketio') or not self.socketio): + try: + from core.session_manager import session_manager + registered_socketio = session_manager.get_socketio_connection(self.session_id) + if registered_socketio: + self.socketio = registered_socketio + print(f"๐Ÿ”„ Restored socketio for node extraction update") + except Exception as restore_error: + print(f"โš ๏ธ Failed to restore socketio for extraction: {restore_error}") + + self._update_session_state() + return True return False @@ -897,8 +1211,7 @@ class Scanner: def _process_provider_result_unified(self, target: str, provider: BaseProvider, provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]: """ - Process a unified ProviderResult object to update the graph. - This version dynamically re-routes edges to a large entity container. + FIXED: Process a unified ProviderResult object to update the graph with enhanced real-time updates. """ provider_name = provider.get_name() discovered_targets = set() @@ -918,6 +1231,10 @@ class Scanner: target, provider_name, provider_result, current_depth ) + # Track if we added anything significant + nodes_added = 0 + edges_added = 0 + for i, relationship in enumerate(provider_result.relationships): if i % 5 == 0 and self._is_stop_requested(): break @@ -949,17 +1266,20 @@ class Scanner: max_depth_reached = current_depth >= self.max_depth # Add actual nodes to the graph (they might be hidden by the UI) - self.graph.add_node(source_node_id, source_type) - self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}) + if self.graph.add_node(source_node_id, source_type): + nodes_added += 1 + if self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}): + nodes_added += 1 # Add the visual edge to the graph - self.graph.add_edge( + if self.graph.add_edge( visual_source, visual_target, relationship.relationship_type, relationship.confidence, provider_name, relationship.raw_data - ) + ): + edges_added += 1 if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached: if target_node_id not in large_entity_members: @@ -987,88 +1307,32 @@ class Scanner: if not self.graph.graph.has_node(node_id): node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN self.graph.add_node(node_id, node_type, attributes=node_attributes_list) + nodes_added += 1 else: existing_attrs = self.graph.graph.nodes[node_id].get('attributes', []) self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list - return discovered_targets, is_large_entity - - def stop_scan(self) -> bool: - """Request immediate scan termination with proper cleanup.""" - try: - self.logger.logger.info("Scan termination requested by user") - self._set_stop_signal() - self.status = ScanStatus.STOPPED + # FIXED: Emit real-time update if we 
+        if nodes_added > 0 or edges_added > 0:
+            print(f"🔄 Added {nodes_added} nodes, {edges_added} edges - triggering real-time update")

-            with self.processing_lock:
-                self.currently_processing.clear()
-                self.currently_processing_display = []
-
-            self.task_queue = PriorityQueue()
-
-            if self.executor:
+            # Ensure we have socketio connection for immediate update
+            if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
                 try:
-                    self.executor.shutdown(wait=False, cancel_futures=True)
-                except Exception:
-                    pass
-
+                    from core.session_manager import session_manager
+                    registered_socketio = session_manager.get_socketio_connection(self.session_id)
+                    if registered_socketio:
+                        self.socketio = registered_socketio
+                        print("🔄 Restored socketio for immediate update")
+                except Exception as restore_error:
+                    print(f"⚠️ Failed to restore socketio for immediate update: {restore_error}")
+
             self._update_session_state()
-            return True
-
-        except Exception as e:
-            self.logger.logger.error(f"Error during scan termination: {e}")
-            traceback.print_exc()
-            return False
-
-    def _update_session_state(self) -> None:
-        """
-        Update the scanner state in Redis for GUI updates.
-        """
-        if self.session_id:
-            try:
-                if self.socketio:
-                    self.socketio.emit('scan_update', self.get_scan_status())
-                from core.session_manager import session_manager
-                session_manager.update_session_scanner(self.session_id, self)
-            except Exception:
-                pass
-
-    def get_scan_status(self) -> Dict[str, Any]:
-        """Get current scan status with comprehensive processing information."""
-        try:
-            with self.processing_lock:
-                currently_processing_count = len(self.currently_processing)
-                currently_processing_list = list(self.currently_processing)
-
-            return {
-                'status': self.status,
-                'target_domain': self.current_target,
-                'current_depth': self.current_depth,
-                'max_depth': self.max_depth,
-                'current_indicator': self.current_indicator,
-                'indicators_processed': self.indicators_processed,
-                'indicators_completed': self.indicators_completed,
-                'tasks_re_enqueued': self.tasks_re_enqueued,
-                'progress_percentage': self._calculate_progress(),
-                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
-                'enabled_providers': [provider.get_name() for provider in self.providers],
-                'graph': self.get_graph_data(),
-                'task_queue_size': self.task_queue.qsize(),
-                'currently_processing_count': currently_processing_count,
-                'currently_processing': currently_processing_list[:5],
-                'tasks_in_queue': self.task_queue.qsize(),
-                'tasks_completed': self.indicators_completed,
-                'tasks_skipped': self.tasks_skipped,
-                'tasks_rescheduled': self.tasks_re_enqueued,
-            }
-        except Exception:
-            traceback.print_exc()
-            return { 'status': 'error', 'message': 'Failed to get status' }
+        return discovered_targets, is_large_entity
+
     def _initialize_provider_states(self, target: str) -> None:
-        """
-        FIXED: Safer provider state initialization with error handling.
-        """
+        """FIXED: Safer provider state initialization with error handling."""
         try:
             if not self.graph.graph.has_node(target):
                 return
@@ -1081,11 +1345,8 @@ class Scanner:
         except Exception as e:
             self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")

-
     def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
-        """
-        FIXED: Improved provider eligibility checking with better filtering.
- """ + """FIXED: Improved provider eligibility checking with better filtering.""" if dns_only: return [p for p in self.providers if p.get_name() == 'dns'] @@ -1124,9 +1385,7 @@ class Scanner: return eligible def _already_queried_provider(self, target: str, provider_name: str) -> bool: - """ - FIXED: More robust check for already queried providers with proper error handling. - """ + """FIXED: More robust check for already queried providers with proper error handling.""" try: if not self.graph.graph.has_node(target): return False @@ -1145,9 +1404,7 @@ class Scanner: def _update_provider_state(self, target: str, provider_name: str, status: str, results_count: int, error: Optional[str], start_time: datetime) -> None: - """ - FIXED: More robust provider state updates with validation. - """ + """FIXED: More robust provider state updates with validation.""" try: if not self.graph.graph.has_node(target): self.logger.logger.warning(f"Cannot update provider state: node {target} not found") @@ -1174,7 +1431,8 @@ class Scanner: } # Update last modified time for forensic integrity - self.last_modified = datetime.now(timezone.utc).isoformat() + if hasattr(self, 'last_modified'): + self.last_modified = datetime.now(timezone.utc).isoformat() except Exception as e: self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}") @@ -1191,9 +1449,14 @@ class Scanner: return 0.0 # Add small buffer for tasks still in queue to avoid showing 100% too early - queue_size = max(0, self.task_queue.qsize()) - with self.processing_lock: - active_tasks = len(self.currently_processing) + queue_size = max(0, self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0) + + # FIXED: Safe processing count + if hasattr(self, 'processing_lock') and self.processing_lock: + with self.processing_lock: + active_tasks = len(self.currently_processing) + else: + active_tasks = len(getattr(self, 'currently_processing', [])) # Adjust total to account for remaining work adjusted_total = max(self.total_tasks_ever_enqueued, @@ -1210,12 +1473,13 @@ class Scanner: return 0.0 def get_graph_data(self) -> Dict[str, Any]: + """Get current graph data formatted for frontend visualization.""" graph_data = self.graph.get_graph_data() graph_data['initial_targets'] = list(self.initial_targets) return graph_data - def get_provider_info(self) -> Dict[str, Dict[str, Any]]: + """Get comprehensive information about all available providers.""" info = {} provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') for filename in os.listdir(provider_dir): diff --git a/core/session_manager.py b/core/session_manager.py index 14d0c9a..940b388 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -6,6 +6,7 @@ import uuid import redis import pickle from typing import Dict, Optional, Any +import copy from core.scanner import Scanner from config import config @@ -13,7 +14,7 @@ from config import config class SessionManager: """ FIXED: Manages multiple scanner instances for concurrent user sessions using Redis. - Now more conservative about session creation to preserve API keys and configuration. + Enhanced to properly maintain WebSocket connections throughout scan lifecycle. 
""" def __init__(self, session_timeout_minutes: int = 0): @@ -30,6 +31,9 @@ class SessionManager: # FIXED: Add a creation lock to prevent race conditions self.creation_lock = threading.Lock() + # Track active socketio connections per session + self.active_socketio_connections = {} + # Start cleanup thread self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread.start() @@ -40,7 +44,7 @@ class SessionManager: """Prepare SessionManager for pickling.""" state = self.__dict__.copy() # Exclude unpickleable attributes - Redis client and threading objects - unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock'] + unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections'] for attr in unpicklable_attrs: if attr in state: del state[attr] @@ -53,6 +57,7 @@ class SessionManager: self.redis_client = redis.StrictRedis(db=0, decode_responses=False) self.lock = threading.Lock() self.creation_lock = threading.Lock() + self.active_socketio_connections = {} self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread.start() @@ -64,22 +69,70 @@ class SessionManager: """Generates the Redis key for a session's stop signal.""" return f"dnsrecon:stop:{session_id}" + def register_socketio_connection(self, session_id: str, socketio) -> None: + """ + FIXED: Register a socketio connection for a session. + This ensures the connection is maintained throughout the session lifecycle. + """ + with self.lock: + self.active_socketio_connections[session_id] = socketio + print(f"Registered socketio connection for session {session_id}") + + def get_socketio_connection(self, session_id: str): + """ + FIXED: Get the active socketio connection for a session. + """ + with self.lock: + return self.active_socketio_connections.get(session_id) + + def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner: + """ + FIXED: Prepare scanner for storage by ensuring proper cleanup of unpicklable objects. + Now preserves socketio connection info for restoration. + """ + # Set the session ID on the scanner for cross-process stop signal management + scanner.session_id = session_id + + # FIXED: Don't set socketio to None if we want to preserve real-time updates + # Instead, we'll restore it when loading the scanner + scanner.socketio = None + + # Force cleanup of any threading objects that might cause issues + if hasattr(scanner, 'stop_event'): + scanner.stop_event = None + if hasattr(scanner, 'scan_thread'): + scanner.scan_thread = None + if hasattr(scanner, 'executor'): + scanner.executor = None + if hasattr(scanner, 'status_logger_thread'): + scanner.status_logger_thread = None + if hasattr(scanner, 'status_logger_stop_event'): + scanner.status_logger_stop_event = None + + return scanner + def create_session(self, socketio=None) -> str: """ - FIXED: Create a new user session with thread-safe creation to prevent duplicates. + FIXED: Create a new user session with enhanced WebSocket management. 
""" # FIXED: Use creation lock to prevent race conditions with self.creation_lock: session_id = str(uuid.uuid4()) print(f"=== CREATING SESSION {session_id} IN REDIS ===") + # FIXED: Register socketio connection first + if socketio: + self.register_socketio_connection(session_id, socketio) + try: from core.session_config import create_session_config session_config = create_session_config() - scanner_instance = Scanner(session_config=session_config, socketio=socketio) - # Set the session ID on the scanner for cross-process stop signal management - scanner_instance.session_id = session_id + # Create scanner WITHOUT socketio to avoid weakref issues + scanner_instance = Scanner(session_config=session_config, socketio=None) + + # Prepare scanner for storage (removes problematic objects) + scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id) session_data = { 'scanner': scanner_instance, @@ -89,12 +142,24 @@ class SessionManager: 'status': 'active' } - # Serialize the entire session data dictionary using pickle - serialized_data = pickle.dumps(session_data) + # Test serialization before storing to catch issues early + try: + test_serialization = pickle.dumps(session_data) + print(f"Session serialization test successful ({len(test_serialization)} bytes)") + except Exception as pickle_error: + print(f"PICKLE TEST FAILED: {pickle_error}") + # Try to identify the problematic object + for key, value in session_data.items(): + try: + pickle.dumps(value) + print(f" {key}: OK") + except Exception as item_error: + print(f" {key}: FAILED - {item_error}") + raise pickle_error # Store in Redis session_key = self._get_session_key(session_id) - self.redis_client.setex(session_key, self.session_timeout, serialized_data) + self.redis_client.setex(session_key, self.session_timeout, test_serialization) # Initialize stop signal as False stop_key = self._get_stop_signal_key(session_id) @@ -106,6 +171,8 @@ class SessionManager: except Exception as e: print(f"ERROR: Failed to create session {session_id}: {e}") + import traceback + traceback.print_exc() raise def set_stop_signal(self, session_id: str) -> bool: @@ -175,31 +242,63 @@ class SessionManager: # Ensure the scanner has the correct session ID for stop signal checking if 'scanner' in session_data and session_data['scanner']: session_data['scanner'].session_id = session_id + # FIXED: Restore socketio connection from our registry + socketio_conn = self.get_socketio_connection(session_id) + if socketio_conn: + session_data['scanner'].socketio = socketio_conn + print(f"Restored socketio connection for session {session_id}") + else: + print(f"No socketio connection found for session {session_id}") + session_data['scanner'].socketio = None return session_data return None except Exception as e: print(f"ERROR: Failed to get session data for {session_id}: {e}") + import traceback + traceback.print_exc() return None def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool: """ Serializes and saves session data back to Redis with updated TTL. + FIXED: Now preserves socketio connection during storage. 
Returns: bool: True if save was successful """ try: session_key = self._get_session_key(session_id) - serialized_data = pickle.dumps(session_data) + + # Create a deep copy to avoid modifying the original scanner object + session_data_to_save = copy.deepcopy(session_data) + + # Prepare scanner for storage if it exists + if 'scanner' in session_data_to_save and session_data_to_save['scanner']: + # FIXED: Preserve the original socketio connection before preparing for storage + original_socketio = session_data_to_save['scanner'].socketio + + session_data_to_save['scanner'] = self._prepare_scanner_for_storage( + session_data_to_save['scanner'], + session_id + ) + + # FIXED: If we had a socketio connection, make sure it's registered + if original_socketio and session_id not in self.active_socketio_connections: + self.register_socketio_connection(session_id, original_socketio) + + serialized_data = pickle.dumps(session_data_to_save) result = self.redis_client.setex(session_key, self.session_timeout, serialized_data) return result except Exception as e: print(f"ERROR: Failed to save session data for {session_id}: {e}") + import traceback + traceback.print_exc() return False def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool: """ - Updates just the scanner object in a session with immediate persistence. + FIXED: Updates just the scanner object in a session with immediate persistence. + Now maintains socketio connection throughout the update process. Returns: bool: True if update was successful @@ -207,21 +306,27 @@ class SessionManager: try: session_data = self._get_session_data(session_id) if session_data: - # Ensure scanner has the session ID - scanner.session_id = session_id + # FIXED: Preserve socketio connection before preparing for storage + original_socketio = scanner.socketio + + # Prepare scanner for storage + scanner = self._prepare_scanner_for_storage(scanner, session_id) session_data['scanner'] = scanner session_data['last_activity'] = time.time() + # FIXED: Restore socketio connection after preparation + if original_socketio: + self.register_socketio_connection(session_id, original_socketio) + session_data['scanner'].socketio = original_socketio + # Immediately save to Redis for GUI updates success = self._save_session_data(session_id, session_data) if success: # Only log occasionally to reduce noise if hasattr(self, '_last_update_log'): if time.time() - self._last_update_log > 5: # Log every 5 seconds max - #print(f"Scanner state updated for session {session_id} (status: {scanner.status})") self._last_update_log = time.time() else: - #print(f"Scanner state updated for session {session_id} (status: {scanner.status})") self._last_update_log = time.time() else: print(f"WARNING: Failed to save scanner state for session {session_id}") @@ -231,6 +336,8 @@ class SessionManager: return False except Exception as e: print(f"ERROR: Failed to update scanner for session {session_id}: {e}") + import traceback + traceback.print_exc() return False def update_scanner_status(self, session_id: str, status: str) -> bool: @@ -263,7 +370,7 @@ class SessionManager: def get_session(self, session_id: str) -> Optional[Scanner]: """ - Get scanner instance for a session from Redis with session ID management. + FIXED: Get scanner instance for a session from Redis with proper socketio restoration. 
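One caveat with copying first: copy.deepcopy may attempt to clone the live Socket.IO handle along with the scanner unless the scanner's pickle hooks already exclude it. A sketch of the same save path that detaches the handle before copying and reattaches it afterwards (helper name is illustrative):

    import copy
    import pickle

    def save_session_snapshot(redis_client, key, session_data, ttl_seconds):
        scanner = session_data.get('scanner')
        live_socket = getattr(scanner, 'socketio', None)
        if scanner is not None:
            scanner.socketio = None  # keep the live handle out of the copy
        try:
            snapshot = copy.deepcopy(session_data)
        finally:
            if scanner is not None:
                scanner.socketio = live_socket  # live object stays connected
        redis_client.setex(key, ttl_seconds, pickle.dumps(snapshot))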
""" if not session_id: return None @@ -281,6 +388,15 @@ class SessionManager: if scanner: # Ensure the scanner can check the Redis-based stop signal scanner.session_id = session_id + + # FIXED: Restore socketio connection from our registry + socketio_conn = self.get_socketio_connection(session_id) + if socketio_conn: + scanner.socketio = socketio_conn + print(f"โœ“ Restored socketio connection for session {session_id}") + else: + scanner.socketio = None + print(f"โš ๏ธ No socketio connection found for session {session_id}") return scanner @@ -333,6 +449,12 @@ class SessionManager: # Wait a moment for graceful shutdown time.sleep(0.5) + # FIXED: Clean up socketio connection + with self.lock: + if session_id in self.active_socketio_connections: + del self.active_socketio_connections[session_id] + print(f"Cleaned up socketio connection for session {session_id}") + # Delete session data and stop signal from Redis session_key = self._get_session_key(session_id) stop_key = self._get_stop_signal_key(session_id) @@ -344,6 +466,8 @@ class SessionManager: except Exception as e: print(f"ERROR: Failed to terminate session {session_id}: {e}") + import traceback + traceback.print_exc() return False def _cleanup_loop(self) -> None: @@ -364,6 +488,12 @@ class SessionManager: self.redis_client.delete(stop_key) print(f"Cleaned up orphaned stop signal for session {session_id}") + # Also clean up socketio connection + with self.lock: + if session_id in self.active_socketio_connections: + del self.active_socketio_connections[session_id] + print(f"Cleaned up orphaned socketio for session {session_id}") + except Exception as e: print(f"Error in cleanup loop: {e}") @@ -387,14 +517,16 @@ class SessionManager: return { 'total_active_sessions': active_sessions, 'running_scans': running_scans, - 'total_stop_signals': len(stop_keys) + 'total_stop_signals': len(stop_keys), + 'active_socketio_connections': len(self.active_socketio_connections) } except Exception as e: print(f"ERROR: Failed to get statistics: {e}") return { 'total_active_sessions': 0, 'running_scans': 0, - 'total_stop_signals': 0 + 'total_stop_signals': 0, + 'active_socketio_connections': 0 } # Global session manager instance diff --git a/providers/base_provider.py b/providers/base_provider.py index 61a3bbe..108a8f6 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -15,6 +15,7 @@ class BaseProvider(ABC): """ Abstract base class for all DNSRecon data providers. Now supports session-specific configuration and returns standardized ProviderResult objects. + FIXED: Enhanced pickle support to prevent weakref serialization errors. 
""" def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None): @@ -53,22 +54,57 @@ class BaseProvider(ABC): def __getstate__(self): """Prepare BaseProvider for pickling by excluding unpicklable objects.""" state = self.__dict__.copy() - # Exclude the unpickleable '_local' attribute (which holds the session) and stop event - unpicklable_attrs = ['_local', '_stop_event'] + + # Exclude unpickleable attributes that may contain weakrefs + unpicklable_attrs = [ + '_local', # Thread-local storage (contains requests.Session) + '_stop_event', # Threading event + 'logger', # Logger may contain weakrefs in handlers + ] + for attr in unpicklable_attrs: if attr in state: del state[attr] + + # Also handle any potential weakrefs in the config object + if 'config' in state and hasattr(state['config'], '__getstate__'): + # If config has its own pickle support, let it handle itself + pass + elif 'config' in state: + # Otherwise, ensure config doesn't contain unpicklable objects + try: + # Test if config can be pickled + import pickle + pickle.dumps(state['config']) + except (TypeError, AttributeError): + # If config can't be pickled, we'll recreate it during unpickling + state['_config_class'] = type(state['config']).__name__ + del state['config'] + return state def __setstate__(self, state): """Restore BaseProvider after unpickling by reconstructing threading objects.""" self.__dict__.update(state) - # Re-initialize the '_local' attribute and stop event + + # Re-initialize unpickleable attributes self._local = threading.local() self._stop_event = None + self.logger = get_forensic_logger() + + # Recreate config if it was removed during pickling + if not hasattr(self, 'config') and hasattr(self, '_config_class'): + if self._config_class == 'Config': + from config import config as global_config + self.config = global_config + elif self._config_class == 'SessionConfig': + from core.session_config import create_session_config + self.config = create_session_config() + del self._config_class @property def session(self): + """Get or create thread-local requests session.""" if not hasattr(self._local, 'session'): self._local.session = requests.Session() self._local.session.headers.update({ diff --git a/providers/correlation_provider.py b/providers/correlation_provider.py index 64ed854..7b56a3a 100644 --- a/providers/correlation_provider.py +++ b/providers/correlation_provider.py @@ -10,6 +10,7 @@ from core.graph_manager import NodeType, GraphManager class CorrelationProvider(BaseProvider): """ A provider that finds correlations between nodes in the graph. + FIXED: Enhanced pickle support to prevent weakref issues with graph references. """ def __init__(self, name: str = "correlation", session_config=None): @@ -38,6 +39,38 @@ class CorrelationProvider(BaseProvider): 'query_timestamp', ] + def __getstate__(self): + """ + FIXED: Prepare CorrelationProvider for pickling by excluding graph reference. + """ + state = super().__getstate__() + + # Remove graph reference to prevent circular dependencies and weakrefs + if 'graph' in state: + del state['graph'] + + # Also handle correlation_index which might contain complex objects + if 'correlation_index' in state: + # Clear correlation index as it will be rebuilt when needed + state['correlation_index'] = {} + + return state + + def __setstate__(self, state): + """ + FIXED: Restore CorrelationProvider after unpickling. 
+ """ + super().__setstate__(state) + + # Re-initialize graph reference (will be set by scanner) + self.graph = None + + # Re-initialize correlation index + self.correlation_index = {} + + # Re-compile regex pattern + self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}') + def get_name(self) -> str: """Return the provider name.""" return "correlation" @@ -79,13 +112,20 @@ class CorrelationProvider(BaseProvider): def _find_correlations(self, node_id: str) -> ProviderResult: """ Find correlations for a given node. + FIXED: Added safety checks to prevent issues when graph is None. """ result = ProviderResult() - # FIXED: Ensure self.graph is not None before proceeding. + + # FIXED: Ensure self.graph is not None before proceeding if not self.graph or not self.graph.graph.has_node(node_id): return result - node_attributes = self.graph.graph.nodes[node_id].get('attributes', []) + try: + node_attributes = self.graph.graph.nodes[node_id].get('attributes', []) + except Exception as e: + # If there's any issue accessing the graph, return empty result + print(f"Warning: Could not access graph for correlation analysis: {e}") + return result for attr in node_attributes: attr_name = attr.get('name') @@ -134,6 +174,7 @@ class CorrelationProvider(BaseProvider): if len(self.correlation_index[attr_value]['nodes']) > 1: self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result) + return result def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult): diff --git a/providers/dns_provider.py b/providers/dns_provider.py index 12a91ff..e380bca 100644 --- a/providers/dns_provider.py +++ b/providers/dns_provider.py @@ -11,6 +11,7 @@ class DNSProvider(BaseProvider): """ Provider for standard DNS resolution and reverse DNS lookups. Now returns standardized ProviderResult objects with IPv4 and IPv6 support. + FIXED: Enhanced pickle support to prevent resolver serialization issues. 
""" def __init__(self, name=None, session_config=None): @@ -28,19 +29,20 @@ class DNSProvider(BaseProvider): self.resolver.lifetime = 10 def __getstate__(self): - """Prepare the object for pickling.""" - state = self.__dict__.copy() + """Prepare the object for pickling by excluding resolver.""" + state = super().__getstate__() # Remove the unpickleable 'resolver' attribute if 'resolver' in state: del state['resolver'] return state def __setstate__(self, state): - """Restore the object after unpickling.""" - self.__dict__.update(state) + """Restore the object after unpickling by reconstructing resolver.""" + super().__setstate__(state) # Re-initialize the 'resolver' attribute self.resolver = resolver.Resolver() self.resolver.timeout = 5 + self.resolver.lifetime = 10 def get_name(self) -> str: """Return the provider name.""" @@ -121,10 +123,10 @@ class DNSProvider(BaseProvider): if _is_valid_domain(hostname): # Determine appropriate forward relationship type based on IP version if ip_version == 6: - relationship_type = 'dns_aaaa_record' + relationship_type = 'shodan_aaaa_record' record_prefix = 'AAAA' else: - relationship_type = 'dns_a_record' + relationship_type = 'shodan_a_record' record_prefix = 'A' # Add the relationship diff --git a/static/js/main.js b/static/js/main.js index f8ee132..b63b1a4 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -1,7 +1,7 @@ /** * Main application logic for DNSRecon web interface * Handles UI interactions, API communication, and data flow - * UPDATED: Now compatible with a strictly flat, unified data model for attributes. + * FIXED: Enhanced real-time WebSocket graph updates */ class DNSReconApp { @@ -17,6 +17,14 @@ class DNSReconApp { this.isScanning = false; this.lastGraphUpdate = null; + // FIXED: Add connection state tracking + this.isConnected = false; + this.reconnectAttempts = 0; + this.maxReconnectAttempts = 5; + + // FIXED: Track last graph data for debugging + this.lastGraphData = null; + this.init(); } @@ -45,22 +53,159 @@ class DNSReconApp { } initializeSocket() { - this.socket = io(); + console.log('๐Ÿ”Œ Initializing WebSocket connection...'); + + try { + this.socket = io({ + transports: ['websocket', 'polling'], + timeout: 10000, + reconnection: true, + reconnectionAttempts: 5, + reconnectionDelay: 2000 + }); - this.socket.on('connect', () => { - console.log('Connected to WebSocket server'); - this.updateConnectionStatus('idle'); - this.socket.emit('get_status'); - }); + this.socket.on('connect', () => { + console.log('โœ… WebSocket connected successfully'); + this.isConnected = true; + this.reconnectAttempts = 0; + this.updateConnectionStatus('idle'); + + console.log('๐Ÿ“ก Requesting initial status...'); + this.socket.emit('get_status'); + }); - this.socket.on('scan_update', (data) => { - if (data.status !== this.scanStatus) { - this.handleStatusChange(data.status, data.task_queue_size); - } - this.scanStatus = data.status; - this.updateStatusDisplay(data); - this.graphManager.updateGraph(data.graph); - }); + this.socket.on('disconnect', (reason) => { + console.log('โŒ WebSocket disconnected:', reason); + this.isConnected = false; + this.updateConnectionStatus('error'); + }); + + this.socket.on('connect_error', (error) => { + console.error('โŒ WebSocket connection error:', error); + this.reconnectAttempts++; + this.updateConnectionStatus('error'); + + if (this.reconnectAttempts >= 5) { + this.showError('WebSocket connection failed. 
Please refresh the page.'); + } + }); + + this.socket.on('reconnect', (attemptNumber) => { + console.log('โœ… WebSocket reconnected after', attemptNumber, 'attempts'); + this.isConnected = true; + this.reconnectAttempts = 0; + this.updateConnectionStatus('idle'); + this.socket.emit('get_status'); + }); + + // FIXED: Enhanced scan_update handler with detailed graph processing and debugging + this.socket.on('scan_update', (data) => { + console.log('๐Ÿ“จ WebSocket update received:', { + status: data.status, + target: data.target_domain, + progress: data.progress_percentage, + graphNodes: data.graph?.nodes?.length || 0, + graphEdges: data.graph?.edges?.length || 0, + timestamp: new Date().toISOString() + }); + + try { + // Handle status change + if (data.status !== this.scanStatus) { + console.log(`๐Ÿ“„ Status change: ${this.scanStatus} โ†’ ${data.status}`); + this.handleStatusChange(data.status, data.task_queue_size); + } + this.scanStatus = data.status; + + // Update status display + this.updateStatusDisplay(data); + + // FIXED: Always update graph if data is present and graph manager exists + if (data.graph && this.graphManager) { + console.log('๐Ÿ“Š Processing graph update:', { + nodes: data.graph.nodes?.length || 0, + edges: data.graph.edges?.length || 0, + hasNodes: Array.isArray(data.graph.nodes), + hasEdges: Array.isArray(data.graph.edges), + isInitialized: this.graphManager.isInitialized + }); + + // FIXED: Initialize graph manager if not already done + if (!this.graphManager.isInitialized) { + console.log('๐ŸŽฏ Initializing graph manager...'); + this.graphManager.initialize(); + } + + // FIXED: Force graph update and verify it worked + const previousNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0; + const previousEdgeCount = this.graphManager.edges ? this.graphManager.edges.length : 0; + + console.log('๐Ÿ”„ Before update - Nodes:', previousNodeCount, 'Edges:', previousEdgeCount); + + // Store the data for debugging + this.lastGraphData = data.graph; + + // Update the graph + this.graphManager.updateGraph(data.graph); + this.lastGraphUpdate = Date.now(); + + // Verify the update worked + const newNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0; + const newEdgeCount = this.graphManager.edges ? 
this.graphManager.edges.length : 0; + + console.log('๐Ÿ”„ After update - Nodes:', newNodeCount, 'Edges:', newEdgeCount); + + if (newNodeCount !== data.graph.nodes.length || newEdgeCount !== data.graph.edges.length) { + console.warn('โš ๏ธ Graph update mismatch!', { + expectedNodes: data.graph.nodes.length, + actualNodes: newNodeCount, + expectedEdges: data.graph.edges.length, + actualEdges: newEdgeCount + }); + + // Force a complete rebuild if there's a mismatch + console.log('๐Ÿ”ง Force rebuilding graph...'); + this.graphManager.clear(); + this.graphManager.updateGraph(data.graph); + } + + console.log('โœ… Graph updated successfully'); + + // FIXED: Force network redraw if we're using vis.js + if (this.graphManager.network) { + try { + this.graphManager.network.redraw(); + console.log('๐ŸŽจ Network redrawn'); + } catch (redrawError) { + console.warn('โš ๏ธ Network redraw failed:', redrawError); + } + } + + } else { + if (!data.graph) { + console.log('โš ๏ธ No graph data in WebSocket update'); + } + if (!this.graphManager) { + console.log('โš ๏ธ Graph manager not available'); + } + } + + } catch (error) { + console.error('โŒ Error processing WebSocket update:', error); + console.error('Update data:', data); + console.error('Stack trace:', error.stack); + } + }); + + this.socket.on('error', (error) => { + console.error('โŒ WebSocket error:', error); + this.showError('WebSocket communication error'); + }); + + } catch (error) { + console.error('โŒ Failed to initialize WebSocket:', error); + this.showError('Failed to establish real-time connection'); + } } /** @@ -280,12 +425,36 @@ class DNSReconApp { } /** - * Initialize graph visualization + * FIXED: Initialize graph visualization with enhanced debugging */ initializeGraph() { try { console.log('Initializing graph manager...'); this.graphManager = new GraphManager('network-graph'); + + // FIXED: Add debugging hooks to graph manager + if (this.graphManager) { + // Override updateGraph to add debugging + const originalUpdateGraph = this.graphManager.updateGraph.bind(this.graphManager); + this.graphManager.updateGraph = (graphData) => { + console.log('๐Ÿ”ง GraphManager.updateGraph called with:', { + nodes: graphData?.nodes?.length || 0, + edges: graphData?.edges?.length || 0, + timestamp: new Date().toISOString() + }); + + const result = originalUpdateGraph(graphData); + + console.log('๐Ÿ”ง GraphManager.updateGraph completed, network state:', { + networkExists: !!this.graphManager.network, + nodeDataSetLength: this.graphManager.nodes?.length || 0, + edgeDataSetLength: this.graphManager.edges?.length || 0 + }); + + return result; + }; + } + console.log('Graph manager initialized successfully'); } catch (error) { console.error('Failed to initialize graph manager:', error); @@ -305,7 +474,6 @@ class DNSReconApp { console.log(`Target: "${target}", Max depth: ${maxDepth}`); - // Validation if (!target) { console.log('Validation failed: empty target'); this.showError('Please enter a target domain or IP'); @@ -320,6 +488,19 @@ class DNSReconApp { return; } + // FIXED: Ensure WebSocket connection before starting scan + if (!this.isConnected) { + console.log('WebSocket not connected, attempting to connect...'); + this.socket.connect(); + + // Wait a moment for connection + await new Promise(resolve => setTimeout(resolve, 1000)); + + if (!this.isConnected) { + this.showWarning('WebSocket connection not established. 
Updates may be delayed.'); + } + } + console.log('Validation passed, setting UI state to scanning...'); this.setUIState('scanning'); this.showInfo('Starting reconnaissance scan...'); @@ -337,16 +518,28 @@ class DNSReconApp { if (response.success) { this.currentSessionId = response.scan_id; - this.showSuccess('Reconnaissance scan started successfully'); + this.showSuccess('Reconnaissance scan started - watching for real-time updates'); - if (clearGraph) { + if (clearGraph && this.graphManager) { + console.log('๐Ÿงน Clearing graph for new scan'); this.graphManager.clear(); } - console.log(`Scan started for ${target} with depth ${maxDepth}`); + console.log(`โœ… Scan started for ${target} with depth ${maxDepth}`); - // Request initial status update via WebSocket - this.socket.emit('get_status'); + // FIXED: Immediately start listening for updates + if (this.socket && this.isConnected) { + console.log('๐Ÿ“ก Requesting initial status update...'); + this.socket.emit('get_status'); + + // Set up periodic status requests as backup (every 5 seconds during scan) + /*this.statusRequestInterval = setInterval(() => { + if (this.isScanning && this.socket && this.isConnected) { + console.log('๐Ÿ“ก Periodic status request...'); + this.socket.emit('get_status'); + } + }, 5000);*/ + } } else { throw new Error(response.error || 'Failed to start scan'); @@ -358,26 +551,34 @@ class DNSReconApp { this.setUIState('idle'); } } - /** - * Scan stop with immediate UI feedback - */ + + // FIXED: Enhanced stop scan with interval cleanup async stopScan() { try { console.log('Stopping scan...'); - // Immediately disable stop button and show stopping state + // Clear status request interval + /*if (this.statusRequestInterval) { + clearInterval(this.statusRequestInterval); + this.statusRequestInterval = null; + }*/ + if (this.elements.stopScan) { this.elements.stopScan.disabled = true; this.elements.stopScan.innerHTML = '[STOPPING]Stopping...'; } - // Show immediate feedback this.showInfo('Stopping scan...'); const response = await this.apiCall('/api/scan/stop', 'POST'); if (response.success) { this.showSuccess('Scan stop requested'); + + // Request final status update + if (this.socket && this.isConnected) { + setTimeout(() => this.socket.emit('get_status'), 500); + } } else { throw new Error(response.error || 'Failed to stop scan'); } @@ -386,7 +587,6 @@ class DNSReconApp { console.error('Failed to stop scan:', error); this.showError(`Failed to stop scan: ${error.message}`); - // Re-enable stop button on error if (this.elements.stopScan) { this.elements.stopScan.disabled = false; this.elements.stopScan.innerHTML = '[STOP]Terminate Scan'; @@ -543,23 +743,24 @@ class DNSReconApp { } /** - * Update graph from server + * FIXED: Update graph from server with enhanced debugging */ async updateGraph() { try { - console.log('Updating graph...'); + console.log('Updating graph via API call...'); const response = await this.apiCall('/api/graph'); if (response.success) { const graphData = response.graph; - console.log('Graph data received:'); + console.log('Graph data received from API:'); console.log('- Nodes:', graphData.nodes ? graphData.nodes.length : 0); console.log('- Edges:', graphData.edges ? 
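Since the browser is now the only consumer of scan_update, a headless client is handy for verifying that the server actually emits, independent of the UI. A throwaway script using the python-socketio client package, mirroring the browser's reconnection settings; the host and port are assumptions to adjust for your deployment:

    import socketio  # python-socketio client package, assumed installed

    sio = socketio.Client(reconnection=True,
                          reconnection_attempts=5,
                          reconnection_delay=2)

    @sio.on('scan_update')
    def on_update(data):
        # Print a one-line summary of every pushed update.
        graph = data.get('graph') or {}
        print(data.get('status'),
              len(graph.get('nodes', [])), 'nodes,',
              len(graph.get('edges', [])), 'edges')

    sio.connect('http://127.0.0.1:5000')  # adjust for your deployment
    sio.emit('get_status')
    sio.wait()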
graphData.edges.length : 0); // FIXED: Always update graph, even if empty - let GraphManager handle placeholder if (this.graphManager) { + console.log('๐Ÿ”ง Calling GraphManager.updateGraph from API response...'); this.graphManager.updateGraph(graphData); this.lastGraphUpdate = Date.now(); @@ -568,6 +769,8 @@ class DNSReconApp { if (this.elements.relationshipsDisplay) { this.elements.relationshipsDisplay.textContent = edgeCount; } + + console.log('โœ… Manual graph update completed'); } } else { console.error('Graph update failed:', response); @@ -663,12 +866,12 @@ class DNSReconApp { * @param {string} newStatus - New scan status */ handleStatusChange(newStatus, task_queue_size) { - console.log(`=== STATUS CHANGE: ${this.scanStatus} -> ${newStatus} ===`); + console.log(`๐Ÿ“„ Status change handler: ${this.scanStatus} โ†’ ${newStatus}`); switch (newStatus) { case 'running': this.setUIState('scanning', task_queue_size); - this.showSuccess('Scan is running'); + this.showSuccess('Scan is running - updates in real-time'); this.updateConnectionStatus('active'); break; @@ -677,8 +880,19 @@ class DNSReconApp { this.showSuccess('Scan completed successfully'); this.updateConnectionStatus('completed'); this.loadProviders(); - // Force a final graph update - console.log('Scan completed - forcing final graph update'); + console.log('โœ… Scan completed - requesting final graph update'); + // Request final status to ensure we have the complete graph + setTimeout(() => { + if (this.socket && this.isConnected) { + this.socket.emit('get_status'); + } + }, 1000); + + // Clear status request interval + /*if (this.statusRequestInterval) { + clearInterval(this.statusRequestInterval); + this.statusRequestInterval = null; + }*/ break; case 'failed': @@ -686,6 +900,12 @@ class DNSReconApp { this.showError('Scan failed'); this.updateConnectionStatus('error'); this.loadProviders(); + + // Clear status request interval + /*if (this.statusRequestInterval) { + clearInterval(this.statusRequestInterval); + this.statusRequestInterval = null; + }*/ break; case 'stopped': @@ -693,11 +913,23 @@ class DNSReconApp { this.showSuccess('Scan stopped'); this.updateConnectionStatus('stopped'); this.loadProviders(); + + // Clear status request interval + if (this.statusRequestInterval) { + clearInterval(this.statusRequestInterval); + this.statusRequestInterval = null; + } break; case 'idle': this.setUIState('idle', task_queue_size); this.updateConnectionStatus('idle'); + + // Clear status request interval + /*if (this.statusRequestInterval) { + clearInterval(this.statusRequestInterval); + this.statusRequestInterval = null; + }*/ break; default: @@ -749,6 +981,7 @@ class DNSReconApp { if (this.graphManager) { this.graphManager.isScanning = true; } + if (this.elements.startScan) { this.elements.startScan.disabled = true; this.elements.startScan.classList.add('loading'); @@ -776,6 +1009,7 @@ class DNSReconApp { if (this.graphManager) { this.graphManager.isScanning = false; } + if (this.elements.startScan) { this.elements.startScan.disabled = !isQueueEmpty; this.elements.startScan.classList.remove('loading'); @@ -1018,7 +1252,7 @@ class DNSReconApp { } else { // API key not configured - ALWAYS show input field const statusClass = info.enabled ? 'enabled' : 'api-key-required'; - const statusText = info.enabled ? 'โ—‹ Ready for API Key' : 'โš ๏ธ API Key Required'; + const statusText = info.enabled ? 'โ—ฏ Ready for API Key' : 'โš ๏ธ API Key Required'; inputGroup.innerHTML = `
@@ -2000,8 +2234,8 @@ class DNSReconApp {
      */
     getNodeTypeIcon(nodeType) {
         const icons = {
             'domain': '🌐',
-            'ip': '📍',
+            'ip': '🔢',
             'asn': '🏢',
             'large_entity': '📦',
             'correlation_object': '🔗'