From 03c52abd1b4585ef5d8bf8beaee0dfbbec3c7b71 Mon Sep 17 00:00:00 2001 From: overcuriousity Date: Fri, 12 Sep 2025 23:54:06 +0200 Subject: [PATCH] it --- app.py | 148 ++++++++++++---- config.py | 2 +- core/scanner.py | 326 ++++++++++++++++++++++++++-------- core/session_manager.py | 333 ++++++++++++++++++++++++++++++----- providers/base_provider.py | 130 ++++++++++---- providers/crtsh_provider.py | 24 +-- providers/shodan_provider.py | 6 +- static/js/main.js | 73 ++++++-- 8 files changed, 819 insertions(+), 223 deletions(-) diff --git a/app.py b/app.py index cc42a21..d50ef1b 100644 --- a/app.py +++ b/app.py @@ -20,28 +20,20 @@ app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2) # 2 hour session def get_user_scanner(): """ - Get or create scanner instance for current user session with enhanced debugging. - - Returns: - Tuple of (session_id, scanner_instance) + Enhanced user scanner retrieval with better error handling and debugging. """ # Get current Flask session info for debugging current_flask_session_id = session.get('dnsrecon_session_id') client_ip = request.remote_addr user_agent = request.headers.get('User-Agent', '')[:100] # Truncate for logging - print("=== SESSION DEBUG ===") - print(f"Client IP: {client_ip}") - print(f"User Agent: {user_agent}") - print(f"Flask Session ID: {current_flask_session_id}") - print(f"Flask Session Keys: {list(session.keys())}") - # Try to get existing session if current_flask_session_id: existing_scanner = session_manager.get_session(current_flask_session_id) if existing_scanner: - print(f"Using existing session: {current_flask_session_id}") print(f"Scanner status: {existing_scanner.status}") + # Ensure session ID is set + existing_scanner.session_id = current_flask_session_id return current_flask_session_id, existing_scanner else: print(f"Session {current_flask_session_id} not found in session manager") @@ -51,17 +43,23 @@ def get_user_scanner(): new_session_id = session_manager.create_session() new_scanner = session_manager.get_session(new_session_id) + if not new_scanner: + print(f"ERROR: Failed to retrieve newly created session {new_session_id}") + raise Exception("Failed to create new scanner session") + # Store in Flask session session['dnsrecon_session_id'] = new_session_id session.permanent = True + # Ensure session ID is set on scanner + new_scanner.session_id = new_session_id + print(f"Created new session: {new_session_id}") print(f"New scanner status: {new_scanner.status}") print("=== END SESSION DEBUG ===") return new_session_id, new_scanner - @app.route('/') def index(): """Serve the main web interface.""" @@ -71,8 +69,7 @@ def index(): @app.route('/api/scan/start', methods=['POST']) def start_scan(): """ - Start a new reconnaissance scan for the current user session. - Enhanced with better error handling and debugging. + Start a new reconnaissance scan with immediate GUI feedback. """ print("=== API: /api/scan/start called ===") @@ -111,9 +108,13 @@ def start_scan(): print("Validation passed, getting user scanner...") - # Get user-specific scanner with enhanced debugging + # Get user-specific scanner user_session_id, scanner = get_user_scanner() - scanner.user_session_id = user_session_id + + # Ensure session ID is properly set + if not scanner.session_id: + scanner.session_id = user_session_id + print(f"Using session: {user_session_id}") print(f"Scanner object ID: {id(scanner)}") @@ -121,6 +122,8 @@ def start_scan(): print(f"Calling start_scan on scanner {id(scanner)}...") success = scanner.start_scan(target_domain, max_depth, clear_graph=clear_graph) + # Immediately update session state regardless of success + session_manager.update_session_scanner(user_session_id, scanner) if success: scan_session_id = scanner.logger.session_id @@ -130,6 +133,7 @@ def start_scan(): 'message': 'Scan started successfully', 'scan_id': scan_session_id, 'user_session_id': user_session_id, + 'scanner_status': scanner.status, 'debug_info': { 'scanner_object_id': id(scanner), 'scanner_status': scanner.status @@ -160,9 +164,10 @@ def start_scan(): 'error': f'Internal server error: {str(e)}' }), 500 + @app.route('/api/scan/stop', methods=['POST']) def stop_scan(): - """Stop the current scan for the user session.""" + """Stop the current scan with immediate GUI feedback.""" print("=== API: /api/scan/stop called ===") try: @@ -170,19 +175,37 @@ def stop_scan(): user_session_id, scanner = get_user_scanner() print(f"Stopping scan for session: {user_session_id}") + if not scanner: + return jsonify({ + 'success': False, + 'error': 'No scanner found for session' + }), 404 + + # Ensure session ID is set + if not scanner.session_id: + scanner.session_id = user_session_id + + # Use the enhanced stop mechanism success = scanner.stop_scan() - if success: - return jsonify({ - 'success': True, - 'message': 'Scan stop requested', - 'user_session_id': user_session_id - }) - else: - return jsonify({ - 'success': True, - 'message': 'No active scan to stop for this session' - }) + # Also set the Redis stop signal directly for extra reliability + session_manager.set_stop_signal(user_session_id) + + # Force immediate status update + session_manager.update_scanner_status(user_session_id, 'stopped') + + # Update the full scanner state + session_manager.update_session_scanner(user_session_id, scanner) + + print(f"Stop scan completed. Success: {success}, Scanner status: {scanner.status}") + + return jsonify({ + 'success': True, + 'message': 'Scan stop requested - termination initiated', + 'user_session_id': user_session_id, + 'scanner_status': scanner.status, + 'stop_method': 'enhanced_cross_process' + }) except Exception as e: print(f"ERROR: Exception in stop_scan endpoint: {e}") @@ -195,14 +218,44 @@ def stop_scan(): @app.route('/api/scan/status', methods=['GET']) def get_scan_status(): - """Get current scan status and progress for the user session.""" + """Get current scan status with enhanced error handling.""" try: # Get user-specific scanner user_session_id, scanner = get_user_scanner() + if not scanner: + # Return default idle status if no scanner + return jsonify({ + 'success': True, + 'status': { + 'status': 'idle', + 'target_domain': None, + 'current_depth': 0, + 'max_depth': 0, + 'current_indicator': '', + 'total_indicators_found': 0, + 'indicators_processed': 0, + 'progress_percentage': 0.0, + 'enabled_providers': [], + 'graph_statistics': {}, + 'user_session_id': user_session_id + } + }) + + # Ensure session ID is set + if not scanner.session_id: + scanner.session_id = user_session_id + status = scanner.get_scan_status() status['user_session_id'] = user_session_id + # Additional debug info + status['debug_info'] = { + 'scanner_object_id': id(scanner), + 'session_id_set': bool(scanner.session_id), + 'has_scan_thread': bool(scanner.scan_thread and scanner.scan_thread.is_alive()) + } + return jsonify({ 'success': True, 'status': status @@ -213,17 +266,42 @@ def get_scan_status(): traceback.print_exc() return jsonify({ 'success': False, - 'error': f'Internal server error: {str(e)}' + 'error': f'Internal server error: {str(e)}', + 'fallback_status': { + 'status': 'error', + 'target_domain': None, + 'current_depth': 0, + 'max_depth': 0, + 'progress_percentage': 0.0 + } }), 500 + @app.route('/api/graph', methods=['GET']) def get_graph_data(): - """Get current graph data for visualization for the user session.""" + """Get current graph data with enhanced error handling.""" try: # Get user-specific scanner user_session_id, scanner = get_user_scanner() + if not scanner: + # Return empty graph if no scanner + return jsonify({ + 'success': True, + 'graph': { + 'nodes': [], + 'edges': [], + 'statistics': { + 'node_count': 0, + 'edge_count': 0, + 'creation_time': datetime.now(timezone.utc).isoformat(), + 'last_modified': datetime.now(timezone.utc).isoformat() + } + }, + 'user_session_id': user_session_id + }) + graph_data = scanner.get_graph_data() return jsonify({ 'success': True, @@ -236,10 +314,16 @@ def get_graph_data(): traceback.print_exc() return jsonify({ 'success': False, - 'error': f'Internal server error: {str(e)}' + 'error': f'Internal server error: {str(e)}', + 'fallback_graph': { + 'nodes': [], + 'edges': [], + 'statistics': {'node_count': 0, 'edge_count': 0} + } }), 500 + @app.route('/api/export', methods=['GET']) def export_results(): """Export complete scan results as downloadable JSON for the user session.""" diff --git a/config.py b/config.py index 497a9a9..df4a06c 100644 --- a/config.py +++ b/config.py @@ -19,7 +19,7 @@ class Config: # Default settings self.default_recursion_depth = 2 - self.default_timeout = 30 + self.default_timeout = 10 self.max_concurrent_requests = 5 self.large_entity_threshold = 100 diff --git a/core/scanner.py b/core/scanner.py index 9ec0eda..3b9f9f3 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -2,8 +2,9 @@ import threading import traceback +import time from typing import List, Set, Dict, Any, Tuple -from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError +from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future from collections import defaultdict from datetime import datetime, timezone @@ -27,7 +28,7 @@ class ScanStatus: class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. - REFACTORED: Simplified recursion with forensic provider state tracking. + Enhanced with reliable cross-process termination capabilities. """ def __init__(self, session_config=None): @@ -49,6 +50,7 @@ class Scanner: self.max_depth = 2 self.stop_event = threading.Event() self.scan_thread = None + self.session_id = None # Will be set by session manager # Scanning progress tracking self.total_indicators_found = 0 @@ -82,6 +84,42 @@ class Scanner: traceback.print_exc() raise + def _is_stop_requested(self) -> bool: + """ + Check if stop is requested using both local and Redis-based signals. + This ensures reliable termination across process boundaries. + """ + # Check local threading event first (fastest) + if self.stop_event.is_set(): + return True + + # Check Redis-based stop signal if session ID is available + if self.session_id: + try: + from core.session_manager import session_manager + return session_manager.is_stop_requested(self.session_id) + except Exception as e: + print(f"Error checking Redis stop signal: {e}") + # Fall back to local event + return self.stop_event.is_set() + + return False + + def _set_stop_signal(self) -> None: + """ + Set stop signal both locally and in Redis. + """ + # Set local event + self.stop_event.set() + + # Set Redis signal if session ID is available + if self.session_id: + try: + from core.session_manager import session_manager + session_manager.set_stop_signal(self.session_id) + except Exception as e: + print(f"Error setting Redis stop signal: {e}") + def __getstate__(self): """Prepare object for pickling by excluding unpicklable attributes.""" state = self.__dict__.copy() @@ -159,8 +197,9 @@ class Scanner: print("Session configuration updated") def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool: - """Start a new reconnaissance scan with forensic tracking.""" + """Start a new reconnaissance scan with immediate GUI feedback.""" print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") + print(f"Session ID: {self.session_id}") print(f"Initial scanner status: {self.status}") # Clean up previous scan thread if needed @@ -172,11 +211,13 @@ class Scanner: if self.scan_thread.is_alive(): print("ERROR: The previous scan thread is unresponsive and could not be stopped.") self.status = ScanStatus.FAILED + self._update_session_state() return False print("Previous scan thread terminated successfully.") # Reset state for new scan self.status = ScanStatus.IDLE + self._update_session_state() # Update GUI immediately print("Scanner state is now clean for a new scan.") try: @@ -191,11 +232,20 @@ class Scanner: self.current_target = target_domain.lower().strip() self.max_depth = max_depth self.current_depth = 0 + + # Clear both local and Redis stop signals self.stop_event.clear() + if self.session_id: + from core.session_manager import session_manager + session_manager.clear_stop_signal(self.session_id) + self.total_indicators_found = 0 self.indicators_processed = 0 self.current_indicator = self.current_target + # Update GUI with scan preparation + self._update_session_state() + # Start new forensic session print(f"Starting new forensic session for scanner {id(self)}...") self.logger = new_session() @@ -216,16 +266,20 @@ class Scanner: print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") traceback.print_exc() self.status = ScanStatus.FAILED + self._update_session_state() # Update failed status immediately return False def _execute_scan(self, target_domain: str, max_depth: int) -> None: - """Execute the reconnaissance scan with simplified recursion and forensic tracking.""" + """Execute the reconnaissance scan with frequent state updates for GUI.""" print(f"_execute_scan started for {target_domain} with depth {max_depth}") self.executor = ThreadPoolExecutor(max_workers=self.max_workers) processed_targets = set() try: self.status = ScanStatus.RUNNING + # Immediate status update for GUI + self._update_session_state() + enabled_providers = [provider.get_name() for provider in self.providers] self.logger.log_scan_start(target_domain, max_depth, enabled_providers) self.graph.add_node(target_domain, NodeType.DOMAIN) @@ -235,11 +289,13 @@ class Scanner: all_discovered_targets = {target_domain} for depth in range(max_depth + 1): - if self.stop_event.is_set(): + if self._is_stop_requested(): print(f"Stop requested at depth {depth}") break self.current_depth = depth + self._update_session_state() + targets_to_process = current_level_targets - processed_targets if not targets_to_process: print("No new targets to process at this level.") @@ -247,8 +303,9 @@ class Scanner: print(f"Processing depth level {depth} with {len(targets_to_process)} new targets") self.total_indicators_found += len(targets_to_process) + self._update_session_state() - target_results = self._process_targets_concurrent_forensic( + target_results = self._process_targets_sequential_with_stop_checks( targets_to_process, processed_targets, all_discovered_targets, depth ) processed_targets.update(targets_to_process) @@ -256,31 +313,57 @@ class Scanner: next_level_targets = set() for _target, new_targets in target_results: all_discovered_targets.update(new_targets) + # This is the critical change: add all new targets to the next level next_level_targets.update(new_targets) - current_level_targets = next_level_targets + # Filter out already processed targets before the next iteration + current_level_targets = next_level_targets - processed_targets + + if self._is_stop_requested(): + print(f"Stop requested after processing depth {depth}") + break except Exception as e: print(f"ERROR: Scan execution failed with error: {e}") traceback.print_exc() self.status = ScanStatus.FAILED + self._update_session_state() # Update failed status immediately self.logger.logger.error(f"Scan failed: {e}") finally: - if self.stop_event.is_set(): + if self._is_stop_requested(): self.status = ScanStatus.STOPPED else: self.status = ScanStatus.COMPLETED + + # Final status update for GUI + self._update_session_state() + self.logger.log_scan_complete() - self.executor.shutdown(wait=False, cancel_futures=True) + if self.executor: + self.executor.shutdown(wait=False, cancel_futures=True) stats = self.graph.get_statistics() print("Final scan statistics:") print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}") print(f" - Total edges: {stats['basic_metrics']['total_edges']}") print(f" - Targets processed: {len(processed_targets)}") + def _update_session_state(self) -> None: + """ + Update the scanner state in Redis for GUI updates. + This ensures the web interface sees real-time updates. + """ + if self.session_id: + try: + from core.session_manager import session_manager + success = session_manager.update_session_scanner(self.session_id, self) + if not success: + print(f"WARNING: Failed to update session state for {self.session_id}") + except Exception as e: + print(f"ERROR: Failed to update session state: {e}") + def _initialize_provider_states(self, target: str) -> None: """Initialize provider states for forensic tracking.""" - if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node() + if not self.graph.graph.has_node(target): return node_data = self.graph.graph.nodes[target] @@ -292,7 +375,6 @@ class Scanner: def _should_recurse_on_target(self, target: str, processed_targets: Set[str], all_discovered: Set[str]) -> bool: """ Simplified recursion logic: only recurse on valid IPs and domains that haven't been processed. - FORENSIC PRINCIPLE: Clear, simple rules for what gets recursed. """ # Don't recurse on already processed targets if target in processed_targets: @@ -318,51 +400,129 @@ class Scanner: return True return False - def _process_targets_concurrent_forensic(self, targets: Set[str], processed_targets: Set[str], - all_discovered: Set[str], current_depth: int) -> List[Tuple[str, Set[str]]]: - """Process multiple targets concurrently with forensic provider state tracking.""" + def _process_targets_sequential_with_stop_checks(self, targets: Set[str], processed_targets: Set[str], + all_discovered: Set[str], current_depth: int) -> List[Tuple[str, Set[str]]]: + """ + Process targets with controlled concurrency for both responsiveness and proper completion. + Balances termination responsiveness with avoiding race conditions. + """ results = [] targets_to_process = targets - processed_targets if not targets_to_process: return results - print(f"Processing {len(targets_to_process)} targets concurrently with forensic tracking") + print(f"Processing {len(targets_to_process)} targets with controlled concurrency") - future_to_target = { - self.executor.submit(self._query_providers_forensic, target, current_depth): target - for target in targets_to_process - } - - for future in as_completed(future_to_target): - if self.stop_event.is_set(): - future.cancel() - continue - target = future_to_target[future] - try: - new_targets = future.result() - results.append((target, new_targets)) - self.indicators_processed += 1 - print(f"Completed processing target: {target} (found {len(new_targets)} new targets)") - except (Exception, CancelledError) as e: - print(f"Error processing target {target}: {e}") - self._log_target_processing_error(target, str(e)) + target_list = list(targets_to_process) + active_futures: Dict[Future, str] = {} + target_index = 0 + last_gui_update = time.time() - # Add this block to save the state to Redis - from core.session_manager import session_manager - if hasattr(self, 'user_session_id'): - session_manager.update_session_scanner(self.user_session_id, self) + while target_index < len(target_list) or active_futures: + # Check stop signal before any new work + if self._is_stop_requested(): + print("Stop requested - canceling active futures and exiting") + for future in list(active_futures.keys()): + future.cancel() + break + + # Submit new futures up to max_workers limit (controlled concurrency) + while len(active_futures) < self.max_workers and target_index < len(target_list): + if self._is_stop_requested(): + break + + target = target_list[target_index] + self.current_indicator = target + print(f"Submitting target {target_index + 1}/{len(target_list)}: {target}") + + future = self.executor.submit(self._query_providers_forensic, target, current_depth) + active_futures[future] = target + target_index += 1 + + # Update GUI periodically + current_time = time.time() + if target_index % 2 == 0 or (current_time - last_gui_update) > 2.0: + self._update_session_state() + last_gui_update = current_time + + # Wait for at least one future to complete (but don't wait forever) + if active_futures: + try: + # Wait for the first completion with reasonable timeout + completed_future = next(as_completed(active_futures.keys(), timeout=15.0)) + + target = active_futures[completed_future] + try: + new_targets = completed_future.result() + results.append((target, new_targets)) + self.indicators_processed += 1 + print(f"Completed processing target: {target} (found {len(new_targets)} new targets)") + + # Update GUI after each completion + current_time = time.time() + if (current_time - last_gui_update) > 1.0: + self._update_session_state() + last_gui_update = current_time + + except Exception as e: + print(f"Error processing target {target}: {e}") + self._log_target_processing_error(target, str(e)) + + # Remove the completed future + del active_futures[completed_future] + + except StopIteration: + # No futures completed within timeout - check stop signal and continue + if self._is_stop_requested(): + print("Stop requested during timeout - canceling futures") + for future in list(active_futures.keys()): + future.cancel() + break + # Continue loop to wait for completions + + except Exception as e: + # as_completed timeout or other error + if self._is_stop_requested(): + print("Stop requested during future waiting") + for future in list(active_futures.keys()): + future.cancel() + break + + # Check if any futures are actually done (in case of timeout exception) + completed_futures = [f for f in active_futures.keys() if f.done()] + for completed_future in completed_futures: + target = active_futures[completed_future] + try: + new_targets = completed_future.result() + results.append((target, new_targets)) + self.indicators_processed += 1 + print(f"Completed processing target: {target} (found {len(new_targets)} new targets)") + except Exception as ex: + print(f"Error processing target {target}: {ex}") + self._log_target_processing_error(target, str(ex)) + + del active_futures[completed_future] + + print(f"Completed processing all targets at depth {current_depth}") + + # Final state update + self._update_session_state() return results def _query_providers_forensic(self, target: str, current_depth: int) -> Set[str]: """ - Query providers for a target with forensic state tracking and simplified recursion. - REFACTORED: Simplified logic with complete forensic audit trail. + Query providers for a target with enhanced stop signal checking. """ is_ip = _is_valid_ip(target) target_type = NodeType.IP if is_ip else NodeType.DOMAIN print(f"Querying providers for {target_type.value}: {target} at depth {current_depth}") + # Early stop check + if self._is_stop_requested(): + print(f"Stop requested before querying providers for {target}") + return set() + # Initialize node and provider states self.graph.add_node(target, target_type) self._initialize_provider_states(target) @@ -377,34 +537,27 @@ class Scanner: self._log_no_eligible_providers(target, is_ip) return new_targets - # Query each eligible provider with forensic tracking - with ThreadPoolExecutor(max_workers=len(eligible_providers)) as provider_executor: - future_to_provider = { - provider_executor.submit(self._query_single_provider_forensic, provider, target, is_ip, current_depth): provider - for provider in eligible_providers - } + # Query each eligible provider sequentially with stop checks + for provider in eligible_providers: + if self._is_stop_requested(): + print(f"Stop requested while querying providers for {target}") + break - for future in as_completed(future_to_provider): - if self.stop_event.is_set(): - future.cancel() - continue - - provider = future_to_provider[future] - try: - provider_results = future.result() - if provider_results: - discovered_targets = self._process_provider_results_forensic( - target, provider, provider_results, target_metadata, current_depth - ) - new_targets.update(discovered_targets) - except (Exception, CancelledError) as e: - self._log_provider_error(target, provider.get_name(), str(e)) + try: + provider_results = self._query_single_provider_forensic(provider, target, is_ip, current_depth) + if provider_results and not self._is_stop_requested(): + discovered_targets = self._process_provider_results_forensic( + target, provider, provider_results, target_metadata, current_depth + ) + new_targets.update(discovered_targets) + except Exception as e: + self._log_provider_error(target, provider.get_name(), str(e)) + # Update node metadata for node_id, metadata_dict in target_metadata.items(): if self.graph.graph.has_node(node_id): node_is_ip = _is_valid_ip(node_id) node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN - # This call updates the existing node with the new metadata self.graph.add_node(node_id, node_type_to_add, metadata=metadata_dict) return new_targets @@ -428,7 +581,7 @@ class Scanner: def _already_queried_provider(self, target: str, provider_name: str) -> bool: """Check if we already queried a provider for a target.""" - if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node() + if not self.graph.graph.has_node(target): return False node_data = self.graph.graph.nodes[target] @@ -436,10 +589,15 @@ class Scanner: return provider_name in provider_states def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> List: - """Query a single provider with complete forensic logging.""" + """Query a single provider with stop signal checking.""" provider_name = provider.get_name() start_time = datetime.now(timezone.utc) + # Check stop signal before querying + if self._is_stop_requested(): + print(f"Stop requested before querying {provider_name} for {target}") + return [] + print(f"Querying {provider_name} for {target}") # Log attempt @@ -452,6 +610,11 @@ class Scanner: else: results = provider.query_domain(target) + # Check stop signal after querying + if self._is_stop_requested(): + print(f"Stop requested after querying {provider_name} for {target}") + return [] + # Track successful state self._update_provider_state(target, provider_name, 'success', len(results), None, start_time) @@ -467,7 +630,7 @@ class Scanner: def _update_provider_state(self, target: str, provider_name: str, status: str, results_count: int, error: str, start_time: datetime) -> None: """Update provider state in node metadata for forensic tracking.""" - if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node() + if not self.graph.graph.has_node(target): return node_data = self.graph.graph.nodes[target] @@ -489,10 +652,15 @@ class Scanner: def _process_provider_results_forensic(self, target: str, provider, results: List, target_metadata: Dict, current_depth: int) -> Set[str]: - """Process provider results with large entity detection and forensic logging.""" + """Process provider results with large entity detection and stop signal checking.""" provider_name = provider.get_name() discovered_targets = set() + # Check for stop signal before processing results + if self._is_stop_requested(): + print(f"Stop requested before processing results from {provider_name} for {target}") + return discovered_targets + # Check for large entity threshold per provider if len(results) > self.config.large_entity_threshold: print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}") @@ -503,8 +671,10 @@ class Scanner: # Process each relationship dns_records_to_create = {} - for source, rel_target, rel_type, confidence, raw_data in results: - if self.stop_event.is_set(): + for i, (source, rel_target, rel_type, confidence, raw_data) in enumerate(results): + # Check stop signal periodically during result processing + if i % 10 == 0 and self._is_stop_requested(): + print(f"Stop requested while processing results from {provider_name} for {target}") break # Enhanced forensic logging for each relationship @@ -539,7 +709,7 @@ class Scanner: print(f"Added domain relationship: {source} -> {rel_target} ({rel_type.relationship_name})") discovered_targets.add(rel_target) - # *** NEW: Enrich the newly discovered domain *** + # Enrich the newly discovered domain self._collect_node_metadata_forensic(rel_target, provider_name, rel_type, source, raw_data, target_metadata[rel_target]) else: @@ -691,25 +861,24 @@ class Scanner: self.logger.logger.warning(f"No eligible providers for {target_type}: {target}") def stop_scan(self) -> bool: - """Request immediate scan termination with forensic logging.""" + """Request immediate scan termination with immediate GUI feedback.""" try: - if not self.scan_thread or not self.scan_thread.is_alive(): - print("No active scan thread to stop.") - if self.status == ScanStatus.RUNNING: - self.status = ScanStatus.STOPPED - return False - print("=== INITIATING IMMEDIATE SCAN TERMINATION ===") self.logger.logger.info("Scan termination requested by user") + # Set both local and Redis stop signals + self._set_stop_signal() self.status = ScanStatus.STOPPED - self.stop_event.set() + + # Immediately update GUI with stopped status + self._update_session_state() + # Cancel executor futures if running if self.executor: print("Shutting down executor with immediate cancellation...") self.executor.shutdown(wait=False, cancel_futures=True) - print("Termination signal sent. The scan thread will stop shortly.") + print("Termination signals sent. The scan will stop as soon as possible.") return True except Exception as e: @@ -774,7 +943,8 @@ class Scanner: 'final_status': self.status, 'total_indicators_processed': self.indicators_processed, 'enabled_providers': list(provider_stats.keys()), - 'forensic_note': 'Refactored scanner with simplified recursion and forensic tracking' + 'session_id': self.session_id, + 'forensic_note': 'Enhanced scanner with reliable cross-process termination' }, 'graph_data': graph_data, 'forensic_audit': audit_trail, diff --git a/core/session_manager.py b/core/session_manager.py index d74c397..1ec0757 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -5,7 +5,7 @@ import time import uuid import redis import pickle -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, List from core.scanner import Scanner @@ -16,7 +16,7 @@ from core.scanner import Scanner class SessionManager: """ Manages multiple scanner instances for concurrent user sessions using Redis. - This allows session state to be shared across multiple Gunicorn worker processes. + Enhanced with reliable cross-process stop signal management and immediate state updates. """ def __init__(self, session_timeout_minutes: int = 60): @@ -57,6 +57,10 @@ class SessionManager: """Generates the Redis key for a session.""" return f"dnsrecon:session:{session_id}" + def _get_stop_signal_key(self, session_id: str) -> str: + """Generates the Redis key for a session's stop signal.""" + return f"dnsrecon:stop:{session_id}" + def create_session(self) -> str: """ Create a new user session and store it in Redis. @@ -69,6 +73,9 @@ class SessionManager: session_config = create_session_config() scanner_instance = Scanner(session_config=session_config) + # Set the session ID on the scanner for cross-process stop signal management + scanner_instance.session_id = session_id + session_data = { 'scanner': scanner_instance, 'config': session_config, @@ -84,38 +91,166 @@ class SessionManager: session_key = self._get_session_key(session_id) self.redis_client.setex(session_key, self.session_timeout, serialized_data) - print(f"Session {session_id} stored in Redis") + # Initialize stop signal as False + stop_key = self._get_stop_signal_key(session_id) + self.redis_client.setex(stop_key, self.session_timeout, b'0') + + print(f"Session {session_id} stored in Redis with stop signal initialized") return session_id except Exception as e: print(f"ERROR: Failed to create session {session_id}: {e}") raise + def set_stop_signal(self, session_id: str) -> bool: + """ + Set the stop signal for a session (cross-process safe). + + Args: + session_id: Session identifier + + Returns: + bool: True if signal was set successfully + """ + try: + stop_key = self._get_stop_signal_key(session_id) + # Set stop signal to '1' with the same TTL as the session + self.redis_client.setex(stop_key, self.session_timeout, b'1') + print(f"Stop signal set for session {session_id}") + return True + except Exception as e: + print(f"ERROR: Failed to set stop signal for session {session_id}: {e}") + return False + + def is_stop_requested(self, session_id: str) -> bool: + """ + Check if stop is requested for a session (cross-process safe). + + Args: + session_id: Session identifier + + Returns: + bool: True if stop is requested + """ + try: + stop_key = self._get_stop_signal_key(session_id) + value = self.redis_client.get(stop_key) + return value == b'1' if value is not None else False + except Exception as e: + print(f"ERROR: Failed to check stop signal for session {session_id}: {e}") + return False + + def clear_stop_signal(self, session_id: str) -> bool: + """ + Clear the stop signal for a session. + + Args: + session_id: Session identifier + + Returns: + bool: True if signal was cleared successfully + """ + try: + stop_key = self._get_stop_signal_key(session_id) + self.redis_client.setex(stop_key, self.session_timeout, b'0') + print(f"Stop signal cleared for session {session_id}") + return True + except Exception as e: + print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}") + return False + def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]: """Retrieves and deserializes session data from Redis.""" - session_key = self._get_session_key(session_id) - serialized_data = self.redis_client.get(session_key) - if serialized_data: - return pickle.loads(serialized_data) - return None + try: + session_key = self._get_session_key(session_id) + serialized_data = self.redis_client.get(session_key) + if serialized_data: + session_data = pickle.loads(serialized_data) + # Ensure the scanner has the correct session ID for stop signal checking + if 'scanner' in session_data and session_data['scanner']: + session_data['scanner'].session_id = session_id + return session_data + return None + except Exception as e: + print(f"ERROR: Failed to get session data for {session_id}: {e}") + return None - def _save_session_data(self, session_id: str, session_data: Dict[str, Any]): - """Serializes and saves session data back to Redis with updated TTL.""" - session_key = self._get_session_key(session_id) - serialized_data = pickle.dumps(session_data) - self.redis_client.setex(session_key, self.session_timeout, serialized_data) + def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool: + """ + Serializes and saves session data back to Redis with updated TTL. + + Returns: + bool: True if save was successful + """ + try: + session_key = self._get_session_key(session_id) + serialized_data = pickle.dumps(session_data) + result = self.redis_client.setex(session_key, self.session_timeout, serialized_data) + return result + except Exception as e: + print(f"ERROR: Failed to save session data for {session_id}: {e}") + return False - def update_session_scanner(self, session_id: str, scanner: 'Scanner'): - """Updates just the scanner object in a session.""" - session_data = self._get_session_data(session_id) - if session_data: - session_data['scanner'] = scanner - # We don't need to update last_activity here, as that's for user interaction - self._save_session_data(session_id, session_data) + def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool: + """ + Updates just the scanner object in a session with immediate persistence. + + Returns: + bool: True if update was successful + """ + try: + session_data = self._get_session_data(session_id) + if session_data: + # Ensure scanner has the session ID + scanner.session_id = session_id + session_data['scanner'] = scanner + session_data['last_activity'] = time.time() + + # Immediately save to Redis for GUI updates + success = self._save_session_data(session_id, session_data) + if success: + print(f"Scanner state updated for session {session_id} (status: {scanner.status})") + else: + print(f"WARNING: Failed to save scanner state for session {session_id}") + return success + else: + print(f"WARNING: Session {session_id} not found for scanner update") + return False + except Exception as e: + print(f"ERROR: Failed to update scanner for session {session_id}: {e}") + return False + + def update_scanner_status(self, session_id: str, status: str) -> bool: + """ + Quickly update just the scanner status for immediate GUI feedback. + + Args: + session_id: Session identifier + status: New scanner status + + Returns: + bool: True if update was successful + """ + try: + session_data = self._get_session_data(session_id) + if session_data and 'scanner' in session_data: + session_data['scanner'].status = status + session_data['last_activity'] = time.time() + + success = self._save_session_data(session_id, session_data) + if success: + print(f"Scanner status updated to '{status}' for session {session_id}") + else: + print(f"WARNING: Failed to save status update for session {session_id}") + return success + return False + except Exception as e: + print(f"ERROR: Failed to update scanner status for session {session_id}: {e}") + return False def get_session(self, session_id: str) -> Optional[Scanner]: """ - Get scanner instance for a session from Redis. + Get scanner instance for a session from Redis with enhanced session ID management. """ if not session_id: return None @@ -129,37 +264,151 @@ class SessionManager: session_data['last_activity'] = time.time() self._save_session_data(session_id, session_data) - return session_data.get('scanner') + scanner = session_data.get('scanner') + if scanner: + # Ensure the scanner can check the Redis-based stop signal + scanner.session_id = session_id + print(f"Retrieved scanner for session {session_id} (status: {scanner.status})") + + return scanner + + def get_session_status_only(self, session_id: str) -> Optional[str]: + """ + Get just the scanner status without full session retrieval (for performance). + + Args: + session_id: Session identifier + + Returns: + Scanner status string or None if not found + """ + try: + session_data = self._get_session_data(session_id) + if session_data and 'scanner' in session_data: + return session_data['scanner'].status + return None + except Exception as e: + print(f"ERROR: Failed to get session status for {session_id}: {e}") + return None def terminate_session(self, session_id: str) -> bool: """ - Terminate a specific session in Redis. + Terminate a specific session in Redis with reliable stop signal and immediate status update. """ - session_data = self._get_session_data(session_id) - if not session_data: - return False + print(f"=== TERMINATING SESSION {session_id} ===") + + try: + # First, set the stop signal + self.set_stop_signal(session_id) + + # Update scanner status to stopped immediately for GUI feedback + self.update_scanner_status(session_id, 'stopped') + + session_data = self._get_session_data(session_id) + if not session_data: + print(f"Session {session_id} not found") + return False - scanner = session_data.get('scanner') - if scanner and scanner.status == 'running': - scanner.stop_scan() - print(f"Stopped scan for session: {session_id}") - - # Delete from Redis - session_key = self._get_session_key(session_id) - self.redis_client.delete(session_key) - - print(f"Terminated and removed session from Redis: {session_id}") - return True + scanner = session_data.get('scanner') + if scanner and scanner.status == 'running': + print(f"Stopping scan for session: {session_id}") + # The scanner will check the Redis stop signal + scanner.stop_scan() + + # Update the scanner state immediately + self.update_session_scanner(session_id, scanner) + + # Wait a moment for graceful shutdown + time.sleep(0.5) + + # Delete session data and stop signal from Redis + session_key = self._get_session_key(session_id) + stop_key = self._get_stop_signal_key(session_id) + self.redis_client.delete(session_key) + self.redis_client.delete(stop_key) + + print(f"Terminated and removed session from Redis: {session_id}") + return True + + except Exception as e: + print(f"ERROR: Failed to terminate session {session_id}: {e}") + return False def _cleanup_loop(self) -> None: """ - Background thread to cleanup inactive sessions. - Redis's TTL (setex) handles most of this automatically. This loop is a failsafe. + Background thread to cleanup inactive sessions and orphaned stop signals. """ while True: - # Redis handles expiration automatically, so this loop can be simplified or removed - # For now, we'll keep it as a failsafe check for non-expiring keys if any get created by mistake - time.sleep(300) # Sleep for 5 minutes + try: + # Clean up orphaned stop signals + stop_keys = self.redis_client.keys("dnsrecon:stop:*") + for stop_key in stop_keys: + # Extract session ID from stop key + session_id = stop_key.decode('utf-8').split(':')[-1] + session_key = self._get_session_key(session_id) + + # If session doesn't exist but stop signal does, clean it up + if not self.redis_client.exists(session_key): + self.redis_client.delete(stop_key) + print(f"Cleaned up orphaned stop signal for session {session_id}") + + except Exception as e: + print(f"Error in cleanup loop: {e}") + + time.sleep(300) # Sleep for 5 minutes + + def list_active_sessions(self) -> List[Dict[str, Any]]: + """List all active sessions for admin purposes.""" + try: + session_keys = self.redis_client.keys("dnsrecon:session:*") + sessions = [] + + for session_key in session_keys: + session_id = session_key.decode('utf-8').split(':')[-1] + session_data = self._get_session_data(session_id) + + if session_data: + scanner = session_data.get('scanner') + sessions.append({ + 'session_id': session_id, + 'created_at': session_data.get('created_at'), + 'last_activity': session_data.get('last_activity'), + 'scanner_status': scanner.status if scanner else 'unknown', + 'current_target': scanner.current_target if scanner else None + }) + + return sessions + except Exception as e: + print(f"ERROR: Failed to list active sessions: {e}") + return [] + + def get_statistics(self) -> Dict[str, Any]: + """Get session manager statistics.""" + try: + session_keys = self.redis_client.keys("dnsrecon:session:*") + stop_keys = self.redis_client.keys("dnsrecon:stop:*") + + active_sessions = len(session_keys) + running_scans = 0 + + for session_key in session_keys: + session_id = session_key.decode('utf-8').split(':')[-1] + status = self.get_session_status_only(session_id) + if status == 'running': + running_scans += 1 + + return { + 'total_active_sessions': active_sessions, + 'running_scans': running_scans, + 'total_stop_signals': len(stop_keys) + } + except Exception as e: + print(f"ERROR: Failed to get statistics: {e}") + return { + 'total_active_sessions': 0, + 'running_scans': 0, + 'total_stop_signals': 0 + } # Global session manager instance session_manager = SessionManager(session_timeout_minutes=60) \ No newline at end of file diff --git a/providers/base_provider.py b/providers/base_provider.py index 6051339..5bb4ccd 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -163,11 +163,11 @@ class BaseProvider(ABC): target_indicator: str = "", max_retries: int = 3) -> Optional[requests.Response]: """ - Make a rate-limited HTTP request with forensic logging and retry logic. - Now supports cancellation via stop_event from scanner. + Make a rate-limited HTTP request with aggressive stop signal handling. + Terminates immediately when stop is requested, including during retries. """ # Check for cancellation before starting - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + if self._is_stop_requested(): print(f"Request cancelled before start: {url}") return None @@ -188,21 +188,24 @@ class BaseProvider(ABC): response.headers = cached_data['headers'] return response - for attempt in range(max_retries + 1): - # Check for cancellation before each attempt - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + # Determine effective max_retries based on stop signal + effective_max_retries = 0 if self._is_stop_requested() else max_retries + last_exception = None + + for attempt in range(effective_max_retries + 1): + # AGGRESSIVE: Check for cancellation before each attempt + if self._is_stop_requested(): print(f"Request cancelled during attempt {attempt + 1}: {url}") return None - # Apply rate limiting (but reduce wait time if cancellation is requested) - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): - break - - self.rate_limiter.wait_if_needed() + # Apply rate limiting with cancellation awareness + if not self._wait_with_cancellation_check(): + print(f"Request cancelled during rate limiting: {url}") + return None - # Check again after rate limiting - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): - print(f"Request cancelled after rate limiting: {url}") + # AGGRESSIVE: Final check before making HTTP request + if self._is_stop_requested(): + print(f"Request cancelled before HTTP call: {url}") return None start_time = time.time() @@ -219,10 +222,11 @@ class BaseProvider(ABC): print(f"Making {method} request to: {url} (attempt {attempt + 1})") - # Use shorter timeout if termination is requested + # AGGRESSIVE: Use much shorter timeout if termination is requested request_timeout = self.timeout - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): - request_timeout = min(5, self.timeout) # Max 5 seconds if termination requested + if self._is_stop_requested(): + request_timeout = 2 # Max 2 seconds if termination requested + print(f"Stop requested - using short timeout: {request_timeout}s") # Make request if method.upper() == "GET": @@ -271,28 +275,28 @@ class BaseProvider(ABC): error = str(e) self.failed_requests += 1 print(f"Request failed (attempt {attempt + 1}): {error}") + last_exception = e - # Check for cancellation before retrying - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): - print(f"Request cancelled, not retrying: {url}") + # AGGRESSIVE: Immediately abort retries if stop requested + if self._is_stop_requested(): + print(f"Stop requested - aborting retries for: {url}") break - # Check if we should retry - if attempt < max_retries and self._should_retry(e): - backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s - print(f"Retrying in {backoff_time} seconds...") + # Check if we should retry (but only if stop not requested) + if attempt < effective_max_retries and self._should_retry(e): + # Use a longer, more respectful backoff for 429 errors + if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429: + # Start with a 10-second backoff and increase exponentially + backoff_time = 10 * (2 ** attempt) + print(f"Rate limit hit. Retrying in {backoff_time} seconds...") + else: + backoff_time = min(1.0, (2 ** attempt) * 0.5) # Shorter backoff for other errors + print(f"Retrying in {backoff_time} seconds...") - # Shorter backoff if termination is requested - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): - backoff_time = min(0.5, backoff_time) - - # Sleep with cancellation checking - sleep_start = time.time() - while time.time() - sleep_start < backoff_time: - if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): - print(f"Request cancelled during backoff: {url}") - return None - time.sleep(0.1) # Check every 100ms + # AGGRESSIVE: Much shorter backoff and more frequent checking + if not self._sleep_with_cancellation_check(backoff_time): + print(f"Stop requested during backoff - aborting: {url}") + return None continue else: break @@ -301,6 +305,7 @@ class BaseProvider(ABC): error = f"Unexpected error: {str(e)}" self.failed_requests += 1 print(f"Unexpected error: {error}") + last_exception = e break # All attempts failed - log and return None @@ -316,8 +321,57 @@ class BaseProvider(ABC): target_indicator=target_indicator ) + if error and last_exception: + raise last_exception + return None + def _is_stop_requested(self) -> bool: + """ + Enhanced stop signal checking that handles both local and Redis-based signals. + """ + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + return True + return False + + + def _wait_with_cancellation_check(self) -> bool: + """ + Wait for rate limiting while aggressively checking for cancellation. + Returns False if cancelled during wait. + """ + current_time = time.time() + time_since_last = current_time - self.rate_limiter.last_request_time + + if time_since_last < self.rate_limiter.min_interval: + sleep_time = self.rate_limiter.min_interval - time_since_last + if not self._sleep_with_cancellation_check(sleep_time): + return False + + self.rate_limiter.last_request_time = time.time() + return True + + def _sleep_with_cancellation_check(self, sleep_time: float) -> bool: + """ + Sleep for the specified time while aggressively checking for cancellation. + + Args: + sleep_time: Time to sleep in seconds + + Returns: + bool: True if sleep completed, False if cancelled + """ + sleep_start = time.time() + check_interval = 0.05 # Check every 50ms for aggressive responsiveness + + while time.time() - sleep_start < sleep_time: + if self._is_stop_requested(): + return False + remaining_time = sleep_time - (time.time() - sleep_start) + time.sleep(min(check_interval, remaining_time)) + + return True + def set_stop_event(self, stop_event: threading.Event) -> None: """ Set the stop event for this provider to enable cancellation. @@ -337,15 +391,15 @@ class BaseProvider(ABC): Returns: True if the request should be retried """ - # Retry on connection errors, timeouts, and 5xx server errors + # Retry on connection errors and timeouts if isinstance(exception, (requests.exceptions.ConnectionError, requests.exceptions.Timeout)): return True if isinstance(exception, requests.exceptions.HTTPError): if hasattr(exception, 'response') and exception.response: - # Retry on server errors (5xx) but not client errors (4xx) - return exception.response.status_code >= 500 + # Retry on server errors (5xx) AND on rate-limiting errors (429) + return exception.response.status_code >= 500 or exception.response.status_code == 429 return False diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index 30ec5ea..1b0343b 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -157,8 +157,7 @@ class CrtShProvider(BaseProvider): def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ Query crt.sh for certificates containing the domain. - Creates domain-to-domain relationships and stores certificate data as metadata. - Now supports early termination via stop_event. + Enhanced with more frequent stop signal checking for reliable termination. """ if not _is_valid_domain(domain): return [] @@ -197,10 +196,10 @@ class CrtShProvider(BaseProvider): domain_certificates = {} all_discovered_domains = set() - # Process certificates and group by domain (with cancellation checks) + # Process certificates with enhanced cancellation checking for i, cert_data in enumerate(certificates): - # Check for cancellation every 10 certificates - if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): + # Check for cancellation every 5 certificates instead of 10 for faster response + if i % 5 == 0 and self._stop_event and self._stop_event.is_set(): print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}") break @@ -209,6 +208,11 @@ class CrtShProvider(BaseProvider): # Add all domains from this certificate to our tracking for cert_domain in cert_domains: + # Additional stop check during domain processing + if i % 20 == 0 and self._stop_event and self._stop_event.is_set(): + print(f"CrtSh domain processing cancelled for domain: {domain}") + break + if not _is_valid_domain(cert_domain): continue @@ -226,13 +230,13 @@ class CrtShProvider(BaseProvider): print(f"CrtSh query cancelled before relationship creation for domain: {domain}") return [] - # Create relationships from query domain to ALL discovered domains - for discovered_domain in all_discovered_domains: + # Create relationships from query domain to ALL discovered domains with stop checking + for i, discovered_domain in enumerate(all_discovered_domains): if discovered_domain == domain: continue # Skip self-relationships - # Check for cancellation during relationship creation - if self._stop_event and self._stop_event.is_set(): + # Check for cancellation every 10 relationships + if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): print(f"CrtSh relationship creation cancelled for domain: {domain}") break @@ -284,8 +288,6 @@ class CrtShProvider(BaseProvider): except json.JSONDecodeError as e: self.logger.logger.error(f"Failed to parse JSON response from crt.sh: {e}") - except Exception as e: - self.logger.logger.error(f"Error querying crt.sh for {domain}: {e}") return relationships diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index f41e8f8..4bc9d4a 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -134,8 +134,6 @@ class ShodanProvider(BaseProvider): except json.JSONDecodeError as e: self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}") - except Exception as e: - self.logger.logger.error(f"Error querying Shodan for domain {domain}: {e}") return relationships @@ -231,9 +229,7 @@ class ShodanProvider(BaseProvider): except json.JSONDecodeError as e: self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}") - except Exception as e: - self.logger.logger.error(f"Error querying Shodan for IP {ip}: {e}") - + return relationships def search_by_organization(self, org_name: str) -> List[Dict[str, Any]]: diff --git a/static/js/main.js b/static/js/main.js index 795b3f2..0a451fc 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -246,7 +246,7 @@ class DNSReconApp { } /** - * Start a reconnaissance scan + * Enhanced start scan with better error handling */ async startScan(clearGraph = true) { console.log('=== STARTING SCAN ==='); @@ -292,7 +292,6 @@ class DNSReconApp { if (response.success) { this.currentSessionId = response.scan_id; - this.startPolling(); this.showSuccess('Reconnaissance scan started successfully'); if (clearGraph) { @@ -301,6 +300,9 @@ class DNSReconApp { console.log(`Scan started for ${targetDomain} with depth ${maxDepth}`); + // Start polling immediately with faster interval for responsiveness + this.startPolling(1000); + // Force an immediate status update console.log('Forcing immediate status update...'); setTimeout(() => { @@ -318,18 +320,43 @@ class DNSReconApp { this.setUIState('idle'); } } - /** - * Stop the current scan + * Enhanced scan stop with immediate UI feedback */ async stopScan() { try { console.log('Stopping scan...'); + + // Immediately disable stop button and show stopping state + if (this.elements.stopScan) { + this.elements.stopScan.disabled = true; + this.elements.stopScan.innerHTML = '[STOPPING]Stopping...'; + } + + // Show immediate feedback + this.showInfo('Stopping scan...'); + const response = await this.apiCall('/api/scan/stop', 'POST'); if (response.success) { this.showSuccess('Scan stop requested'); - console.log('Scan stop requested'); + console.log('Scan stop requested successfully'); + + // Force immediate status update + setTimeout(() => { + this.updateStatus(); + }, 100); + + // Continue polling for a bit to catch the status change + this.startPolling(500); // Fast polling to catch status change + + // Stop fast polling after 10 seconds + setTimeout(() => { + if (this.scanStatus === 'stopped' || this.scanStatus === 'idle') { + this.stopPolling(); + } + }, 10000); + } else { throw new Error(response.error || 'Failed to stop scan'); } @@ -337,6 +364,12 @@ class DNSReconApp { } catch (error) { console.error('Failed to stop scan:', error); this.showError(`Failed to stop scan: ${error.message}`); + + // Re-enable stop button on error + if (this.elements.stopScan) { + this.elements.stopScan.disabled = false; + this.elements.stopScan.innerHTML = '[STOP]Terminate Scan'; + } } } @@ -365,9 +398,9 @@ class DNSReconApp { } /** - * Start polling for scan updates + * Start polling for scan updates with configurable interval */ - startPolling() { + startPolling(interval = 2000) { console.log('=== STARTING POLLING ==='); if (this.pollInterval) { @@ -380,9 +413,9 @@ class DNSReconApp { this.updateStatus(); this.updateGraph(); this.loadProviders(); - }, 1000); // Poll every 1 second for debugging + }, interval); - console.log('Polling started with 1 second interval'); + console.log(`Polling started with ${interval}ms interval`); } /** @@ -397,7 +430,7 @@ class DNSReconApp { } /** - * Update scan status from server + * Enhanced status update with better error handling */ async updateStatus() { try { @@ -406,7 +439,7 @@ class DNSReconApp { console.log('Status response:', response); - if (response.success) { + if (response.success && response.status) { const status = response.status; console.log('Current scan status:', status.status); console.log('Current progress:', status.progress_percentage + '%'); @@ -423,6 +456,7 @@ class DNSReconApp { this.scanStatus = status.status; } else { console.error('Status update failed:', response); + // Don't show error for status updates to avoid spam } } catch (error) { @@ -551,7 +585,7 @@ class DNSReconApp { } /** - * Handle status changes + * Handle status changes with improved state synchronization * @param {string} newStatus - New scan status */ handleStatusChange(newStatus) { @@ -561,8 +595,8 @@ class DNSReconApp { case 'running': this.setUIState('scanning'); this.showSuccess('Scan is running'); - // Reset polling frequency for active scans - this.pollFrequency = 2000; + // Increase polling frequency for active scans + this.startPolling(1000); // Poll every 1 second for running scans this.updateConnectionStatus('active'); break; @@ -598,6 +632,10 @@ class DNSReconApp { this.stopPolling(); this.updateConnectionStatus('idle'); break; + + default: + console.warn(`Unknown status: ${newStatus}`); + break; } } @@ -633,8 +671,7 @@ class DNSReconApp { } /** - * Set UI state based on scan status - * @param {string} state - UI state + * Enhanced UI state management with immediate button updates */ setUIState(state) { console.log(`Setting UI state to: ${state}`); @@ -645,6 +682,7 @@ class DNSReconApp { if (this.elements.startScan) { this.elements.startScan.disabled = true; this.elements.startScan.classList.add('loading'); + this.elements.startScan.innerHTML = '[SCANNING]Scanning...'; } if (this.elements.addToGraph) { this.elements.addToGraph.disabled = true; @@ -653,6 +691,7 @@ class DNSReconApp { if (this.elements.stopScan) { this.elements.stopScan.disabled = false; this.elements.stopScan.classList.remove('loading'); + this.elements.stopScan.innerHTML = '[STOP]Terminate Scan'; } if (this.elements.targetDomain) this.elements.targetDomain.disabled = true; if (this.elements.maxDepth) this.elements.maxDepth.disabled = true; @@ -667,6 +706,7 @@ class DNSReconApp { if (this.elements.startScan) { this.elements.startScan.disabled = false; this.elements.startScan.classList.remove('loading'); + this.elements.startScan.innerHTML = '[RUN]Start Reconnaissance'; } if (this.elements.addToGraph) { this.elements.addToGraph.disabled = false; @@ -674,6 +714,7 @@ class DNSReconApp { } if (this.elements.stopScan) { this.elements.stopScan.disabled = true; + this.elements.stopScan.innerHTML = '[STOP]Terminate Scan'; } if (this.elements.targetDomain) this.elements.targetDomain.disabled = false; if (this.elements.maxDepth) this.elements.maxDepth.disabled = false;