This commit is contained in:
overcuriousity 2025-09-12 23:54:06 +02:00
parent 2d62191aa0
commit 03c52abd1b
8 changed files with 819 additions and 223 deletions

148
app.py
View File

@ -20,28 +20,20 @@ app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2) # 2 hour session
def get_user_scanner(): def get_user_scanner():
""" """
Get or create scanner instance for current user session with enhanced debugging. Enhanced user scanner retrieval with better error handling and debugging.
Returns:
Tuple of (session_id, scanner_instance)
""" """
# Get current Flask session info for debugging # Get current Flask session info for debugging
current_flask_session_id = session.get('dnsrecon_session_id') current_flask_session_id = session.get('dnsrecon_session_id')
client_ip = request.remote_addr client_ip = request.remote_addr
user_agent = request.headers.get('User-Agent', '')[:100] # Truncate for logging user_agent = request.headers.get('User-Agent', '')[:100] # Truncate for logging
print("=== SESSION DEBUG ===")
print(f"Client IP: {client_ip}")
print(f"User Agent: {user_agent}")
print(f"Flask Session ID: {current_flask_session_id}")
print(f"Flask Session Keys: {list(session.keys())}")
# Try to get existing session # Try to get existing session
if current_flask_session_id: if current_flask_session_id:
existing_scanner = session_manager.get_session(current_flask_session_id) existing_scanner = session_manager.get_session(current_flask_session_id)
if existing_scanner: if existing_scanner:
print(f"Using existing session: {current_flask_session_id}")
print(f"Scanner status: {existing_scanner.status}") print(f"Scanner status: {existing_scanner.status}")
# Ensure session ID is set
existing_scanner.session_id = current_flask_session_id
return current_flask_session_id, existing_scanner return current_flask_session_id, existing_scanner
else: else:
print(f"Session {current_flask_session_id} not found in session manager") print(f"Session {current_flask_session_id} not found in session manager")
@ -51,17 +43,23 @@ def get_user_scanner():
new_session_id = session_manager.create_session() new_session_id = session_manager.create_session()
new_scanner = session_manager.get_session(new_session_id) new_scanner = session_manager.get_session(new_session_id)
if not new_scanner:
print(f"ERROR: Failed to retrieve newly created session {new_session_id}")
raise Exception("Failed to create new scanner session")
# Store in Flask session # Store in Flask session
session['dnsrecon_session_id'] = new_session_id session['dnsrecon_session_id'] = new_session_id
session.permanent = True session.permanent = True
# Ensure session ID is set on scanner
new_scanner.session_id = new_session_id
print(f"Created new session: {new_session_id}") print(f"Created new session: {new_session_id}")
print(f"New scanner status: {new_scanner.status}") print(f"New scanner status: {new_scanner.status}")
print("=== END SESSION DEBUG ===") print("=== END SESSION DEBUG ===")
return new_session_id, new_scanner return new_session_id, new_scanner
@app.route('/') @app.route('/')
def index(): def index():
"""Serve the main web interface.""" """Serve the main web interface."""
@ -71,8 +69,7 @@ def index():
@app.route('/api/scan/start', methods=['POST']) @app.route('/api/scan/start', methods=['POST'])
def start_scan(): def start_scan():
""" """
Start a new reconnaissance scan for the current user session. Start a new reconnaissance scan with immediate GUI feedback.
Enhanced with better error handling and debugging.
""" """
print("=== API: /api/scan/start called ===") print("=== API: /api/scan/start called ===")
@ -111,9 +108,13 @@ def start_scan():
print("Validation passed, getting user scanner...") print("Validation passed, getting user scanner...")
# Get user-specific scanner with enhanced debugging # Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
scanner.user_session_id = user_session_id
# Ensure session ID is properly set
if not scanner.session_id:
scanner.session_id = user_session_id
print(f"Using session: {user_session_id}") print(f"Using session: {user_session_id}")
print(f"Scanner object ID: {id(scanner)}") print(f"Scanner object ID: {id(scanner)}")
@ -121,6 +122,8 @@ def start_scan():
print(f"Calling start_scan on scanner {id(scanner)}...") print(f"Calling start_scan on scanner {id(scanner)}...")
success = scanner.start_scan(target_domain, max_depth, clear_graph=clear_graph) success = scanner.start_scan(target_domain, max_depth, clear_graph=clear_graph)
# Immediately update session state regardless of success
session_manager.update_session_scanner(user_session_id, scanner)
if success: if success:
scan_session_id = scanner.logger.session_id scan_session_id = scanner.logger.session_id
@ -130,6 +133,7 @@ def start_scan():
'message': 'Scan started successfully', 'message': 'Scan started successfully',
'scan_id': scan_session_id, 'scan_id': scan_session_id,
'user_session_id': user_session_id, 'user_session_id': user_session_id,
'scanner_status': scanner.status,
'debug_info': { 'debug_info': {
'scanner_object_id': id(scanner), 'scanner_object_id': id(scanner),
'scanner_status': scanner.status 'scanner_status': scanner.status
@ -160,9 +164,10 @@ def start_scan():
'error': f'Internal server error: {str(e)}' 'error': f'Internal server error: {str(e)}'
}), 500 }), 500
@app.route('/api/scan/stop', methods=['POST']) @app.route('/api/scan/stop', methods=['POST'])
def stop_scan(): def stop_scan():
"""Stop the current scan for the user session.""" """Stop the current scan with immediate GUI feedback."""
print("=== API: /api/scan/stop called ===") print("=== API: /api/scan/stop called ===")
try: try:
@ -170,19 +175,37 @@ def stop_scan():
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
print(f"Stopping scan for session: {user_session_id}") print(f"Stopping scan for session: {user_session_id}")
if not scanner:
return jsonify({
'success': False,
'error': 'No scanner found for session'
}), 404
# Ensure session ID is set
if not scanner.session_id:
scanner.session_id = user_session_id
# Use the enhanced stop mechanism
success = scanner.stop_scan() success = scanner.stop_scan()
if success: # Also set the Redis stop signal directly for extra reliability
return jsonify({ session_manager.set_stop_signal(user_session_id)
'success': True,
'message': 'Scan stop requested', # Force immediate status update
'user_session_id': user_session_id session_manager.update_scanner_status(user_session_id, 'stopped')
})
else: # Update the full scanner state
return jsonify({ session_manager.update_session_scanner(user_session_id, scanner)
'success': True,
'message': 'No active scan to stop for this session' print(f"Stop scan completed. Success: {success}, Scanner status: {scanner.status}")
})
return jsonify({
'success': True,
'message': 'Scan stop requested - termination initiated',
'user_session_id': user_session_id,
'scanner_status': scanner.status,
'stop_method': 'enhanced_cross_process'
})
except Exception as e: except Exception as e:
print(f"ERROR: Exception in stop_scan endpoint: {e}") print(f"ERROR: Exception in stop_scan endpoint: {e}")
@ -195,14 +218,44 @@ def stop_scan():
@app.route('/api/scan/status', methods=['GET']) @app.route('/api/scan/status', methods=['GET'])
def get_scan_status(): def get_scan_status():
"""Get current scan status and progress for the user session.""" """Get current scan status with enhanced error handling."""
try: try:
# Get user-specific scanner # Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
if not scanner:
# Return default idle status if no scanner
return jsonify({
'success': True,
'status': {
'status': 'idle',
'target_domain': None,
'current_depth': 0,
'max_depth': 0,
'current_indicator': '',
'total_indicators_found': 0,
'indicators_processed': 0,
'progress_percentage': 0.0,
'enabled_providers': [],
'graph_statistics': {},
'user_session_id': user_session_id
}
})
# Ensure session ID is set
if not scanner.session_id:
scanner.session_id = user_session_id
status = scanner.get_scan_status() status = scanner.get_scan_status()
status['user_session_id'] = user_session_id status['user_session_id'] = user_session_id
# Additional debug info
status['debug_info'] = {
'scanner_object_id': id(scanner),
'session_id_set': bool(scanner.session_id),
'has_scan_thread': bool(scanner.scan_thread and scanner.scan_thread.is_alive())
}
return jsonify({ return jsonify({
'success': True, 'success': True,
'status': status 'status': status
@ -213,17 +266,42 @@ def get_scan_status():
traceback.print_exc() traceback.print_exc()
return jsonify({ return jsonify({
'success': False, 'success': False,
'error': f'Internal server error: {str(e)}' 'error': f'Internal server error: {str(e)}',
'fallback_status': {
'status': 'error',
'target_domain': None,
'current_depth': 0,
'max_depth': 0,
'progress_percentage': 0.0
}
}), 500 }), 500
@app.route('/api/graph', methods=['GET']) @app.route('/api/graph', methods=['GET'])
def get_graph_data(): def get_graph_data():
"""Get current graph data for visualization for the user session.""" """Get current graph data with enhanced error handling."""
try: try:
# Get user-specific scanner # Get user-specific scanner
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
if not scanner:
# Return empty graph if no scanner
return jsonify({
'success': True,
'graph': {
'nodes': [],
'edges': [],
'statistics': {
'node_count': 0,
'edge_count': 0,
'creation_time': datetime.now(timezone.utc).isoformat(),
'last_modified': datetime.now(timezone.utc).isoformat()
}
},
'user_session_id': user_session_id
})
graph_data = scanner.get_graph_data() graph_data = scanner.get_graph_data()
return jsonify({ return jsonify({
'success': True, 'success': True,
@ -236,10 +314,16 @@ def get_graph_data():
traceback.print_exc() traceback.print_exc()
return jsonify({ return jsonify({
'success': False, 'success': False,
'error': f'Internal server error: {str(e)}' 'error': f'Internal server error: {str(e)}',
'fallback_graph': {
'nodes': [],
'edges': [],
'statistics': {'node_count': 0, 'edge_count': 0}
}
}), 500 }), 500
@app.route('/api/export', methods=['GET']) @app.route('/api/export', methods=['GET'])
def export_results(): def export_results():
"""Export complete scan results as downloadable JSON for the user session.""" """Export complete scan results as downloadable JSON for the user session."""

View File

@ -19,7 +19,7 @@ class Config:
# Default settings # Default settings
self.default_recursion_depth = 2 self.default_recursion_depth = 2
self.default_timeout = 30 self.default_timeout = 10
self.max_concurrent_requests = 5 self.max_concurrent_requests = 5
self.large_entity_threshold = 100 self.large_entity_threshold = 100

View File

@ -2,8 +2,9 @@
import threading import threading
import traceback import traceback
import time
from typing import List, Set, Dict, Any, Tuple from typing import List, Set, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
@ -27,7 +28,7 @@ class ScanStatus:
class Scanner: class Scanner:
""" """
Main scanning orchestrator for DNSRecon passive reconnaissance. Main scanning orchestrator for DNSRecon passive reconnaissance.
REFACTORED: Simplified recursion with forensic provider state tracking. Enhanced with reliable cross-process termination capabilities.
""" """
def __init__(self, session_config=None): def __init__(self, session_config=None):
@ -49,6 +50,7 @@ class Scanner:
self.max_depth = 2 self.max_depth = 2
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.scan_thread = None self.scan_thread = None
self.session_id = None # Will be set by session manager
# Scanning progress tracking # Scanning progress tracking
self.total_indicators_found = 0 self.total_indicators_found = 0
@ -82,6 +84,42 @@ class Scanner:
traceback.print_exc() traceback.print_exc()
raise raise
def _is_stop_requested(self) -> bool:
    """
    Report whether this scan has been asked to stop.

    Two signals are consulted: the in-process ``self.stop_event`` and, when
    ``self.session_id`` is set, the session manager's stop flag for that
    session (the cross-process signal).

    The first positive answer from the session manager is latched into
    ``self.stop_event``. Later checks are then answered locally, and a
    lookup that fails afterwards cannot make the scan look "not stopped".

    Returns:
        True if a stop was requested locally or through the session manager,
        otherwise False. If the session manager lookup raises, the error is
        printed and the local event alone decides.
    """
    # Check local threading event first (fastest)
    if self.stop_event.is_set():
        return True

    # Without a session ID there is no cross-process signal to consult.
    if not self.session_id:
        return False

    try:
        # Imported lazily, as in the rest of this class, rather than at module load.
        from core.session_manager import session_manager
        stop_requested = bool(session_manager.is_stop_requested(self.session_id))
    except Exception as e:
        print(f"Error checking Redis stop signal: {e}")
        # Fall back to local event
        return self.stop_event.is_set()

    if stop_requested:
        # Latch the remote signal so the stop decision is sticky for this scan.
        self.stop_event.set()
    return stop_requested
def _set_stop_signal(self) -> None:
    """
    Raise the stop signal for this scan, locally and in Redis.

    The local ``stop_event`` is always set. The session manager's stop
    signal is set as well when ``self.session_id`` is available; a failure
    there is printed and otherwise ignored (best effort).
    """
    self.stop_event.set()

    # No session ID means there is no Redis signal to set.
    if not self.session_id:
        return

    try:
        from core.session_manager import session_manager
        session_manager.set_stop_signal(self.session_id)
    except Exception as e:
        print(f"Error setting Redis stop signal: {e}")
def __getstate__(self): def __getstate__(self):
"""Prepare object for pickling by excluding unpicklable attributes.""" """Prepare object for pickling by excluding unpicklable attributes."""
state = self.__dict__.copy() state = self.__dict__.copy()
@ -159,8 +197,9 @@ class Scanner:
print("Session configuration updated") print("Session configuration updated")
def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool: def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool:
"""Start a new reconnaissance scan with forensic tracking.""" """Start a new reconnaissance scan with immediate GUI feedback."""
print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
print(f"Session ID: {self.session_id}")
print(f"Initial scanner status: {self.status}") print(f"Initial scanner status: {self.status}")
# Clean up previous scan thread if needed # Clean up previous scan thread if needed
@ -172,11 +211,13 @@ class Scanner:
if self.scan_thread.is_alive(): if self.scan_thread.is_alive():
print("ERROR: The previous scan thread is unresponsive and could not be stopped.") print("ERROR: The previous scan thread is unresponsive and could not be stopped.")
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self._update_session_state()
return False return False
print("Previous scan thread terminated successfully.") print("Previous scan thread terminated successfully.")
# Reset state for new scan # Reset state for new scan
self.status = ScanStatus.IDLE self.status = ScanStatus.IDLE
self._update_session_state() # Update GUI immediately
print("Scanner state is now clean for a new scan.") print("Scanner state is now clean for a new scan.")
try: try:
@ -191,11 +232,20 @@ class Scanner:
self.current_target = target_domain.lower().strip() self.current_target = target_domain.lower().strip()
self.max_depth = max_depth self.max_depth = max_depth
self.current_depth = 0 self.current_depth = 0
# Clear both local and Redis stop signals
self.stop_event.clear() self.stop_event.clear()
if self.session_id:
from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id)
self.total_indicators_found = 0 self.total_indicators_found = 0
self.indicators_processed = 0 self.indicators_processed = 0
self.current_indicator = self.current_target self.current_indicator = self.current_target
# Update GUI with scan preparation
self._update_session_state()
# Start new forensic session # Start new forensic session
print(f"Starting new forensic session for scanner {id(self)}...") print(f"Starting new forensic session for scanner {id(self)}...")
self.logger = new_session() self.logger = new_session()
@ -216,16 +266,20 @@ class Scanner:
print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
traceback.print_exc() traceback.print_exc()
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self._update_session_state() # Update failed status immediately
return False return False
def _execute_scan(self, target_domain: str, max_depth: int) -> None: def _execute_scan(self, target_domain: str, max_depth: int) -> None:
"""Execute the reconnaissance scan with simplified recursion and forensic tracking.""" """Execute the reconnaissance scan with frequent state updates for GUI."""
print(f"_execute_scan started for {target_domain} with depth {max_depth}") print(f"_execute_scan started for {target_domain} with depth {max_depth}")
self.executor = ThreadPoolExecutor(max_workers=self.max_workers) self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_targets = set() processed_targets = set()
try: try:
self.status = ScanStatus.RUNNING self.status = ScanStatus.RUNNING
# Immediate status update for GUI
self._update_session_state()
enabled_providers = [provider.get_name() for provider in self.providers] enabled_providers = [provider.get_name() for provider in self.providers]
self.logger.log_scan_start(target_domain, max_depth, enabled_providers) self.logger.log_scan_start(target_domain, max_depth, enabled_providers)
self.graph.add_node(target_domain, NodeType.DOMAIN) self.graph.add_node(target_domain, NodeType.DOMAIN)
@ -235,11 +289,13 @@ class Scanner:
all_discovered_targets = {target_domain} all_discovered_targets = {target_domain}
for depth in range(max_depth + 1): for depth in range(max_depth + 1):
if self.stop_event.is_set(): if self._is_stop_requested():
print(f"Stop requested at depth {depth}") print(f"Stop requested at depth {depth}")
break break
self.current_depth = depth self.current_depth = depth
self._update_session_state()
targets_to_process = current_level_targets - processed_targets targets_to_process = current_level_targets - processed_targets
if not targets_to_process: if not targets_to_process:
print("No new targets to process at this level.") print("No new targets to process at this level.")
@ -247,8 +303,9 @@ class Scanner:
print(f"Processing depth level {depth} with {len(targets_to_process)} new targets") print(f"Processing depth level {depth} with {len(targets_to_process)} new targets")
self.total_indicators_found += len(targets_to_process) self.total_indicators_found += len(targets_to_process)
self._update_session_state()
target_results = self._process_targets_concurrent_forensic( target_results = self._process_targets_sequential_with_stop_checks(
targets_to_process, processed_targets, all_discovered_targets, depth targets_to_process, processed_targets, all_discovered_targets, depth
) )
processed_targets.update(targets_to_process) processed_targets.update(targets_to_process)
@ -256,31 +313,57 @@ class Scanner:
next_level_targets = set() next_level_targets = set()
for _target, new_targets in target_results: for _target, new_targets in target_results:
all_discovered_targets.update(new_targets) all_discovered_targets.update(new_targets)
# This is the critical change: add all new targets to the next level
next_level_targets.update(new_targets) next_level_targets.update(new_targets)
current_level_targets = next_level_targets # Filter out already processed targets before the next iteration
current_level_targets = next_level_targets - processed_targets
if self._is_stop_requested():
print(f"Stop requested after processing depth {depth}")
break
except Exception as e: except Exception as e:
print(f"ERROR: Scan execution failed with error: {e}") print(f"ERROR: Scan execution failed with error: {e}")
traceback.print_exc() traceback.print_exc()
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self._update_session_state() # Update failed status immediately
self.logger.logger.error(f"Scan failed: {e}") self.logger.logger.error(f"Scan failed: {e}")
finally: finally:
if self.stop_event.is_set(): if self._is_stop_requested():
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED
else: else:
self.status = ScanStatus.COMPLETED self.status = ScanStatus.COMPLETED
# Final status update for GUI
self._update_session_state()
self.logger.log_scan_complete() self.logger.log_scan_complete()
self.executor.shutdown(wait=False, cancel_futures=True) if self.executor:
self.executor.shutdown(wait=False, cancel_futures=True)
stats = self.graph.get_statistics() stats = self.graph.get_statistics()
print("Final scan statistics:") print("Final scan statistics:")
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}") print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
print(f" - Total edges: {stats['basic_metrics']['total_edges']}") print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
print(f" - Targets processed: {len(processed_targets)}") print(f" - Targets processed: {len(processed_targets)}")
def _update_session_state(self) -> None:
    """
    Push the current scanner state to the session manager for GUI updates.

    Does nothing when ``self.session_id`` is not set. A falsy result from
    ``update_session_scanner`` is reported as a warning; any exception is
    printed and swallowed so a failed update never interrupts the caller.
    """
    if not self.session_id:
        return

    try:
        from core.session_manager import session_manager
        if not session_manager.update_session_scanner(self.session_id, self):
            print(f"WARNING: Failed to update session state for {self.session_id}")
    except Exception as e:
        print(f"ERROR: Failed to update session state: {e}")
def _initialize_provider_states(self, target: str) -> None: def _initialize_provider_states(self, target: str) -> None:
"""Initialize provider states for forensic tracking.""" """Initialize provider states for forensic tracking."""
if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node() if not self.graph.graph.has_node(target):
return return
node_data = self.graph.graph.nodes[target] node_data = self.graph.graph.nodes[target]
@ -292,7 +375,6 @@ class Scanner:
def _should_recurse_on_target(self, target: str, processed_targets: Set[str], all_discovered: Set[str]) -> bool: def _should_recurse_on_target(self, target: str, processed_targets: Set[str], all_discovered: Set[str]) -> bool:
""" """
Simplified recursion logic: only recurse on valid IPs and domains that haven't been processed. Simplified recursion logic: only recurse on valid IPs and domains that haven't been processed.
FORENSIC PRINCIPLE: Clear, simple rules for what gets recursed.
""" """
# Don't recurse on already processed targets # Don't recurse on already processed targets
if target in processed_targets: if target in processed_targets:
@ -318,51 +400,129 @@ class Scanner:
return True return True
return False return False
def _process_targets_sequential_with_stop_checks(self, targets: Set[str], processed_targets: Set[str],
                                                 all_discovered: Set[str], current_depth: int) -> List[Tuple[str, Set[str]]]:
    """
    Process a batch of targets on ``self.executor`` with bounded concurrency.

    At most ``self.max_workers`` queries are in flight at once. The stop
    signal is re-checked before every submission and after every wait, and
    pending futures are cancelled as soon as a stop is seen.

    Args:
        targets: Candidate targets for this depth level.
        processed_targets: Targets already handled; they are skipped.
        all_discovered: Not used by this method; kept for signature compatibility.
        current_depth: Depth passed through to ``_query_providers_forensic``.

    Returns:
        One ``(target, new_targets)`` tuple per target whose query finished
        without raising. Targets whose query raised are reported through
        ``_log_target_processing_error`` and left out of the result.
    """
    # Function-scope import: ``wait`` is not among the names the module imports.
    from concurrent.futures import FIRST_COMPLETED, wait

    results: List[Tuple[str, Set[str]]] = []
    targets_to_process = targets - processed_targets
    if not targets_to_process:
        return results

    print(f"Processing {len(targets_to_process)} targets with controlled concurrency")

    target_list = list(targets_to_process)
    active_futures: Dict[Future, str] = {}
    target_index = 0
    last_gui_update = time.time()

    while target_index < len(target_list) or active_futures:
        # Check stop signal before any new work
        if self._is_stop_requested():
            print("Stop requested - canceling active futures and exiting")
            for pending in list(active_futures):
                # cancel() only affects futures that have not started running yet.
                pending.cancel()
            break

        # Submit new futures up to max_workers limit (controlled concurrency)
        while len(active_futures) < self.max_workers and target_index < len(target_list):
            if self._is_stop_requested():
                break

            target = target_list[target_index]
            self.current_indicator = target
            print(f"Submitting target {target_index + 1}/{len(target_list)}: {target}")

            future = self.executor.submit(self._query_providers_forensic, target, current_depth)
            active_futures[future] = target
            target_index += 1

            # Update GUI periodically
            current_time = time.time()
            if target_index % 2 == 0 or (current_time - last_gui_update) > 2.0:
                self._update_session_state()
                last_gui_update = current_time

        if not active_futures:
            # Nothing in flight (submission was interrupted); re-check the stop signal.
            continue

        # Block until at least one query finishes, but never longer than the
        # timeout, so the stop signal is re-checked at the top of the loop.
        # wait() returns an empty "done" set on timeout instead of raising.
        done, _ = wait(active_futures, timeout=15.0, return_when=FIRST_COMPLETED)

        for completed_future in done:
            target = active_futures.pop(completed_future)
            try:
                new_targets = completed_future.result()
            except Exception as e:
                print(f"Error processing target {target}: {e}")
                self._log_target_processing_error(target, str(e))
                continue

            results.append((target, new_targets))
            self.indicators_processed += 1
            print(f"Completed processing target: {target} (found {len(new_targets)} new targets)")

        # Update GUI after completions, rate-limited to roughly once per second
        current_time = time.time()
        if done and (current_time - last_gui_update) > 1.0:
            self._update_session_state()
            last_gui_update = current_time

    print(f"Completed processing all targets at depth {current_depth}")

    # Final state update
    self._update_session_state()

    return results
def _query_providers_forensic(self, target: str, current_depth: int) -> Set[str]: def _query_providers_forensic(self, target: str, current_depth: int) -> Set[str]:
""" """
Query providers for a target with forensic state tracking and simplified recursion. Query providers for a target with enhanced stop signal checking.
REFACTORED: Simplified logic with complete forensic audit trail.
""" """
is_ip = _is_valid_ip(target) is_ip = _is_valid_ip(target)
target_type = NodeType.IP if is_ip else NodeType.DOMAIN target_type = NodeType.IP if is_ip else NodeType.DOMAIN
print(f"Querying providers for {target_type.value}: {target} at depth {current_depth}") print(f"Querying providers for {target_type.value}: {target} at depth {current_depth}")
# Early stop check
if self._is_stop_requested():
print(f"Stop requested before querying providers for {target}")
return set()
# Initialize node and provider states # Initialize node and provider states
self.graph.add_node(target, target_type) self.graph.add_node(target, target_type)
self._initialize_provider_states(target) self._initialize_provider_states(target)
@ -377,34 +537,27 @@ class Scanner:
self._log_no_eligible_providers(target, is_ip) self._log_no_eligible_providers(target, is_ip)
return new_targets return new_targets
# Query each eligible provider with forensic tracking # Query each eligible provider sequentially with stop checks
with ThreadPoolExecutor(max_workers=len(eligible_providers)) as provider_executor: for provider in eligible_providers:
future_to_provider = { if self._is_stop_requested():
provider_executor.submit(self._query_single_provider_forensic, provider, target, is_ip, current_depth): provider print(f"Stop requested while querying providers for {target}")
for provider in eligible_providers break
}
for future in as_completed(future_to_provider): try:
if self.stop_event.is_set(): provider_results = self._query_single_provider_forensic(provider, target, is_ip, current_depth)
future.cancel() if provider_results and not self._is_stop_requested():
continue discovered_targets = self._process_provider_results_forensic(
target, provider, provider_results, target_metadata, current_depth
provider = future_to_provider[future] )
try: new_targets.update(discovered_targets)
provider_results = future.result() except Exception as e:
if provider_results: self._log_provider_error(target, provider.get_name(), str(e))
discovered_targets = self._process_provider_results_forensic(
target, provider, provider_results, target_metadata, current_depth
)
new_targets.update(discovered_targets)
except (Exception, CancelledError) as e:
self._log_provider_error(target, provider.get_name(), str(e))
# Update node metadata
for node_id, metadata_dict in target_metadata.items(): for node_id, metadata_dict in target_metadata.items():
if self.graph.graph.has_node(node_id): if self.graph.graph.has_node(node_id):
node_is_ip = _is_valid_ip(node_id) node_is_ip = _is_valid_ip(node_id)
node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN
# This call updates the existing node with the new metadata
self.graph.add_node(node_id, node_type_to_add, metadata=metadata_dict) self.graph.add_node(node_id, node_type_to_add, metadata=metadata_dict)
return new_targets return new_targets
@ -428,7 +581,7 @@ class Scanner:
def _already_queried_provider(self, target: str, provider_name: str) -> bool: def _already_queried_provider(self, target: str, provider_name: str) -> bool:
"""Check if we already queried a provider for a target.""" """Check if we already queried a provider for a target."""
if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node() if not self.graph.graph.has_node(target):
return False return False
node_data = self.graph.graph.nodes[target] node_data = self.graph.graph.nodes[target]
@ -436,10 +589,15 @@ class Scanner:
return provider_name in provider_states return provider_name in provider_states
def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> List: def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> List:
"""Query a single provider with complete forensic logging.""" """Query a single provider with stop signal checking."""
provider_name = provider.get_name() provider_name = provider.get_name()
start_time = datetime.now(timezone.utc) start_time = datetime.now(timezone.utc)
# Check stop signal before querying
if self._is_stop_requested():
print(f"Stop requested before querying {provider_name} for {target}")
return []
print(f"Querying {provider_name} for {target}") print(f"Querying {provider_name} for {target}")
# Log attempt # Log attempt
@ -452,6 +610,11 @@ class Scanner:
else: else:
results = provider.query_domain(target) results = provider.query_domain(target)
# Check stop signal after querying
if self._is_stop_requested():
print(f"Stop requested after querying {provider_name} for {target}")
return []
# Track successful state # Track successful state
self._update_provider_state(target, provider_name, 'success', len(results), None, start_time) self._update_provider_state(target, provider_name, 'success', len(results), None, start_time)
@ -467,7 +630,7 @@ class Scanner:
def _update_provider_state(self, target: str, provider_name: str, status: str, def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: str, start_time: datetime) -> None: results_count: int, error: str, start_time: datetime) -> None:
"""Update provider state in node metadata for forensic tracking.""" """Update provider state in node metadata for forensic tracking."""
if not self.graph.graph.has_node(target): # Fix: Use .graph.has_node() if not self.graph.graph.has_node(target):
return return
node_data = self.graph.graph.nodes[target] node_data = self.graph.graph.nodes[target]
@ -489,10 +652,15 @@ class Scanner:
def _process_provider_results_forensic(self, target: str, provider, results: List, def _process_provider_results_forensic(self, target: str, provider, results: List,
target_metadata: Dict, current_depth: int) -> Set[str]: target_metadata: Dict, current_depth: int) -> Set[str]:
"""Process provider results with large entity detection and forensic logging.""" """Process provider results with large entity detection and stop signal checking."""
provider_name = provider.get_name() provider_name = provider.get_name()
discovered_targets = set() discovered_targets = set()
# Check for stop signal before processing results
if self._is_stop_requested():
print(f"Stop requested before processing results from {provider_name} for {target}")
return discovered_targets
# Check for large entity threshold per provider # Check for large entity threshold per provider
if len(results) > self.config.large_entity_threshold: if len(results) > self.config.large_entity_threshold:
print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}") print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}")
@ -503,8 +671,10 @@ class Scanner:
# Process each relationship # Process each relationship
dns_records_to_create = {} dns_records_to_create = {}
for source, rel_target, rel_type, confidence, raw_data in results: for i, (source, rel_target, rel_type, confidence, raw_data) in enumerate(results):
if self.stop_event.is_set(): # Check stop signal periodically during result processing
if i % 10 == 0 and self._is_stop_requested():
print(f"Stop requested while processing results from {provider_name} for {target}")
break break
# Enhanced forensic logging for each relationship # Enhanced forensic logging for each relationship
@ -539,7 +709,7 @@ class Scanner:
print(f"Added domain relationship: {source} -> {rel_target} ({rel_type.relationship_name})") print(f"Added domain relationship: {source} -> {rel_target} ({rel_type.relationship_name})")
discovered_targets.add(rel_target) discovered_targets.add(rel_target)
# *** NEW: Enrich the newly discovered domain *** # Enrich the newly discovered domain
self._collect_node_metadata_forensic(rel_target, provider_name, rel_type, source, raw_data, target_metadata[rel_target]) self._collect_node_metadata_forensic(rel_target, provider_name, rel_type, source, raw_data, target_metadata[rel_target])
else: else:
@ -691,25 +861,24 @@ class Scanner:
self.logger.logger.warning(f"No eligible providers for {target_type}: {target}") self.logger.logger.warning(f"No eligible providers for {target_type}: {target}")
def stop_scan(self) -> bool: def stop_scan(self) -> bool:
"""Request immediate scan termination with forensic logging.""" """Request immediate scan termination with immediate GUI feedback."""
try: try:
if not self.scan_thread or not self.scan_thread.is_alive():
print("No active scan thread to stop.")
if self.status == ScanStatus.RUNNING:
self.status = ScanStatus.STOPPED
return False
print("=== INITIATING IMMEDIATE SCAN TERMINATION ===") print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
self.logger.logger.info("Scan termination requested by user") self.logger.logger.info("Scan termination requested by user")
# Set both local and Redis stop signals
self._set_stop_signal()
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED
self.stop_event.set()
# Immediately update GUI with stopped status
self._update_session_state()
# Cancel executor futures if running
if self.executor: if self.executor:
print("Shutting down executor with immediate cancellation...") print("Shutting down executor with immediate cancellation...")
self.executor.shutdown(wait=False, cancel_futures=True) self.executor.shutdown(wait=False, cancel_futures=True)
print("Termination signal sent. The scan thread will stop shortly.") print("Termination signals sent. The scan will stop as soon as possible.")
return True return True
except Exception as e: except Exception as e:
@ -774,7 +943,8 @@ class Scanner:
'final_status': self.status, 'final_status': self.status,
'total_indicators_processed': self.indicators_processed, 'total_indicators_processed': self.indicators_processed,
'enabled_providers': list(provider_stats.keys()), 'enabled_providers': list(provider_stats.keys()),
'forensic_note': 'Refactored scanner with simplified recursion and forensic tracking' 'session_id': self.session_id,
'forensic_note': 'Enhanced scanner with reliable cross-process termination'
}, },
'graph_data': graph_data, 'graph_data': graph_data,
'forensic_audit': audit_trail, 'forensic_audit': audit_trail,

View File

@ -5,7 +5,7 @@ import time
import uuid import uuid
import redis import redis
import pickle import pickle
from typing import Dict, Optional, Any from typing import Dict, Optional, Any, List
from core.scanner import Scanner from core.scanner import Scanner
@ -16,7 +16,7 @@ from core.scanner import Scanner
class SessionManager: class SessionManager:
""" """
Manages multiple scanner instances for concurrent user sessions using Redis. Manages multiple scanner instances for concurrent user sessions using Redis.
This allows session state to be shared across multiple Gunicorn worker processes. Enhanced with reliable cross-process stop signal management and immediate state updates.
""" """
def __init__(self, session_timeout_minutes: int = 60): def __init__(self, session_timeout_minutes: int = 60):
@ -57,6 +57,10 @@ class SessionManager:
"""Generates the Redis key for a session.""" """Generates the Redis key for a session."""
return f"dnsrecon:session:{session_id}" return f"dnsrecon:session:{session_id}"
def _get_stop_signal_key(self, session_id: str) -> str:
"""Generates the Redis key for a session's stop signal."""
return f"dnsrecon:stop:{session_id}"
def create_session(self) -> str: def create_session(self) -> str:
""" """
Create a new user session and store it in Redis. Create a new user session and store it in Redis.
@ -69,6 +73,9 @@ class SessionManager:
session_config = create_session_config() session_config = create_session_config()
scanner_instance = Scanner(session_config=session_config) scanner_instance = Scanner(session_config=session_config)
# Set the session ID on the scanner for cross-process stop signal management
scanner_instance.session_id = session_id
session_data = { session_data = {
'scanner': scanner_instance, 'scanner': scanner_instance,
'config': session_config, 'config': session_config,
@ -84,38 +91,166 @@ class SessionManager:
session_key = self._get_session_key(session_id) session_key = self._get_session_key(session_id)
self.redis_client.setex(session_key, self.session_timeout, serialized_data) self.redis_client.setex(session_key, self.session_timeout, serialized_data)
print(f"Session {session_id} stored in Redis") # Initialize stop signal as False
stop_key = self._get_stop_signal_key(session_id)
self.redis_client.setex(stop_key, self.session_timeout, b'0')
print(f"Session {session_id} stored in Redis with stop signal initialized")
return session_id return session_id
except Exception as e: except Exception as e:
print(f"ERROR: Failed to create session {session_id}: {e}") print(f"ERROR: Failed to create session {session_id}: {e}")
raise raise
def set_stop_signal(self, session_id: str) -> bool:
    """
    Raise the stop flag for a session by writing b'1' to its stop key in Redis.

    The key is written with the session timeout as its TTL. Any exception
    raised while writing is reported on stdout and turned into a False result.

    Args:
        session_id: Session identifier

    Returns:
        bool: True if the flag was written, False if an error occurred
    """
    signalled = False
    try:
        self.redis_client.setex(
            self._get_stop_signal_key(session_id),
            self.session_timeout,
            b'1',
        )
        print(f"Stop signal set for session {session_id}")
        signalled = True
    except Exception as e:
        print(f"ERROR: Failed to set stop signal for session {session_id}: {e}")
    return signalled
def is_stop_requested(self, session_id: str) -> bool:
    """
    Report whether the stop flag stored in Redis for a session is raised.

    A missing key, any stored value other than b'1', and any error while
    reading are all treated as "no stop requested".

    Args:
        session_id: Session identifier

    Returns:
        bool: True only if the stored flag equals b'1'
    """
    try:
        raw_flag = self.redis_client.get(self._get_stop_signal_key(session_id))
        if raw_flag is None:
            return False
        return raw_flag == b'1'
    except Exception as e:
        print(f"ERROR: Failed to check stop signal for session {session_id}: {e}")
        return False
def clear_stop_signal(self, session_id: str) -> bool:
    """
    Lower the stop flag for a session by writing b'0' to its stop key in Redis.

    The key is rewritten (not deleted) with the session timeout as its TTL.
    Any exception raised while writing is reported on stdout and turned into
    a False result.

    Args:
        session_id: Session identifier

    Returns:
        bool: True if the flag was written, False if an error occurred
    """
    cleared = False
    try:
        self.redis_client.setex(
            self._get_stop_signal_key(session_id),
            self.session_timeout,
            b'0',
        )
        print(f"Stop signal cleared for session {session_id}")
        cleared = True
    except Exception as e:
        print(f"ERROR: Failed to clear stop signal for session {session_id}: {e}")
    return cleared
def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]: def _get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Retrieves and deserializes session data from Redis.""" """Retrieves and deserializes session data from Redis."""
session_key = self._get_session_key(session_id) try:
serialized_data = self.redis_client.get(session_key) session_key = self._get_session_key(session_id)
if serialized_data: serialized_data = self.redis_client.get(session_key)
return pickle.loads(serialized_data) if serialized_data:
return None session_data = pickle.loads(serialized_data)
# Ensure the scanner has the correct session ID for stop signal checking
if 'scanner' in session_data and session_data['scanner']:
session_data['scanner'].session_id = session_id
return session_data
return None
except Exception as e:
print(f"ERROR: Failed to get session data for {session_id}: {e}")
return None
def _save_session_data(self, session_id: str, session_data: Dict[str, Any]): def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
"""Serializes and saves session data back to Redis with updated TTL.""" """
session_key = self._get_session_key(session_id) Serializes and saves session data back to Redis with updated TTL.
serialized_data = pickle.dumps(session_data)
self.redis_client.setex(session_key, self.session_timeout, serialized_data)
def update_session_scanner(self, session_id: str, scanner: 'Scanner'): Returns:
"""Updates just the scanner object in a session.""" bool: True if save was successful
session_data = self._get_session_data(session_id) """
if session_data: try:
session_data['scanner'] = scanner session_key = self._get_session_key(session_id)
# We don't need to update last_activity here, as that's for user interaction serialized_data = pickle.dumps(session_data)
self._save_session_data(session_id, session_data) result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
return result
except Exception as e:
print(f"ERROR: Failed to save session data for {session_id}: {e}")
return False
def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
"""
Updates just the scanner object in a session with immediate persistence.
Returns:
bool: True if update was successful
"""
try:
session_data = self._get_session_data(session_id)
if session_data:
# Ensure scanner has the session ID
scanner.session_id = session_id
session_data['scanner'] = scanner
session_data['last_activity'] = time.time()
# Immediately save to Redis for GUI updates
success = self._save_session_data(session_id, session_data)
if success:
print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
else:
print(f"WARNING: Failed to save scanner state for session {session_id}")
return success
else:
print(f"WARNING: Session {session_id} not found for scanner update")
return False
except Exception as e:
print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
return False
def update_scanner_status(self, session_id: str, status: str) -> bool:
    """
    Overwrite the status of a session's stored scanner and persist it.

    Loads the session data, sets `status` on its scanner, refreshes the
    session's 'last_activity' timestamp and saves the session back. Errors
    are reported on stdout and turned into a False result.

    Args:
        session_id: Session identifier
        status: New scanner status

    Returns:
        bool: Result of the save, or False if the session (or its scanner
              entry) is missing or an error occurred
    """
    try:
        session_data = self._get_session_data(session_id)
        # Nothing to update without a session that carries a scanner entry.
        if not session_data or 'scanner' not in session_data:
            return False

        session_data['scanner'].status = status
        session_data['last_activity'] = time.time()

        saved = self._save_session_data(session_id, session_data)
        if not saved:
            print(f"WARNING: Failed to save status update for session {session_id}")
        else:
            print(f"Scanner status updated to '{status}' for session {session_id}")
        return saved
    except Exception as e:
        print(f"ERROR: Failed to update scanner status for session {session_id}: {e}")
        return False
def get_session(self, session_id: str) -> Optional[Scanner]: def get_session(self, session_id: str) -> Optional[Scanner]:
""" """
Get scanner instance for a session from Redis. Get scanner instance for a session from Redis with enhanced session ID management.
""" """
if not session_id: if not session_id:
return None return None
@ -129,37 +264,151 @@ class SessionManager:
session_data['last_activity'] = time.time() session_data['last_activity'] = time.time()
self._save_session_data(session_id, session_data) self._save_session_data(session_id, session_data)
return session_data.get('scanner') scanner = session_data.get('scanner')
if scanner:
# Ensure the scanner can check the Redis-based stop signal
scanner.session_id = session_id
print(f"Retrieved scanner for session {session_id} (status: {scanner.status})")
return scanner
def get_session_status_only(self, session_id: str) -> Optional[str]:
    """
    Return only the status of a session's stored scanner.

    Args:
        session_id: Session identifier

    Returns:
        The scanner's status, or None if the session (or its scanner entry)
        is missing or an error occurred while reading it
    """
    try:
        session_data = self._get_session_data(session_id)
        if not session_data or 'scanner' not in session_data:
            return None
        return session_data['scanner'].status
    except Exception as e:
        print(f"ERROR: Failed to get session status for {session_id}: {e}")
        return None
def terminate_session(self, session_id: str) -> bool: def terminate_session(self, session_id: str) -> bool:
""" """
Terminate a specific session in Redis. Terminate a specific session in Redis with reliable stop signal and immediate status update.
""" """
session_data = self._get_session_data(session_id) print(f"=== TERMINATING SESSION {session_id} ===")
if not session_data:
try:
# First, set the stop signal
self.set_stop_signal(session_id)
# Update scanner status to stopped immediately for GUI feedback
self.update_scanner_status(session_id, 'stopped')
session_data = self._get_session_data(session_id)
if not session_data:
print(f"Session {session_id} not found")
return False
scanner = session_data.get('scanner')
if scanner and scanner.status == 'running':
print(f"Stopping scan for session: {session_id}")
# The scanner will check the Redis stop signal
scanner.stop_scan()
# Update the scanner state immediately
self.update_session_scanner(session_id, scanner)
# Wait a moment for graceful shutdown
time.sleep(0.5)
# Delete session data and stop signal from Redis
session_key = self._get_session_key(session_id)
stop_key = self._get_stop_signal_key(session_id)
self.redis_client.delete(session_key)
self.redis_client.delete(stop_key)
print(f"Terminated and removed session from Redis: {session_id}")
return True
except Exception as e:
print(f"ERROR: Failed to terminate session {session_id}: {e}")
return False return False
scanner = session_data.get('scanner')
if scanner and scanner.status == 'running':
scanner.stop_scan()
print(f"Stopped scan for session: {session_id}")
# Delete from Redis
session_key = self._get_session_key(session_id)
self.redis_client.delete(session_key)
print(f"Terminated and removed session from Redis: {session_id}")
return True
def _cleanup_loop(self) -> None: def _cleanup_loop(self) -> None:
""" """
Background thread to cleanup inactive sessions. Background thread to cleanup inactive sessions and orphaned stop signals.
Redis's TTL (setex) handles most of this automatically. This loop is a failsafe.
""" """
while True: while True:
# Redis handles expiration automatically, so this loop can be simplified or removed try:
# For now, we'll keep it as a failsafe check for non-expiring keys if any get created by mistake # Clean up orphaned stop signals
time.sleep(300) # Sleep for 5 minutes stop_keys = self.redis_client.keys("dnsrecon:stop:*")
for stop_key in stop_keys:
# Extract session ID from stop key
session_id = stop_key.decode('utf-8').split(':')[-1]
session_key = self._get_session_key(session_id)
# If session doesn't exist but stop signal does, clean it up
if not self.redis_client.exists(session_key):
self.redis_client.delete(stop_key)
print(f"Cleaned up orphaned stop signal for session {session_id}")
except Exception as e:
print(f"Error in cleanup loop: {e}")
time.sleep(300) # Sleep for 5 minutes
def list_active_sessions(self) -> List[Dict[str, Any]]:
    """
    Summarise every session currently stored under "dnsrecon:session:*".

    Keys whose session data cannot be loaded are skipped. If anything raises
    while listing, the error is printed and an empty list is returned.

    Returns:
        One dict per session with the keys 'session_id', 'created_at',
        'last_activity', 'scanner_status' and 'current_target'
    """
    try:
        summaries = []
        for raw_key in self.redis_client.keys("dnsrecon:session:*"):
            # The session ID is the last ':'-separated part of the key.
            session_id = raw_key.decode('utf-8').split(':')[-1]
            session_data = self._get_session_data(session_id)
            if not session_data:
                continue

            scanner = session_data.get('scanner')
            if scanner:
                scanner_status = scanner.status
                current_target = scanner.current_target
            else:
                scanner_status = 'unknown'
                current_target = None

            summaries.append({
                'session_id': session_id,
                'created_at': session_data.get('created_at'),
                'last_activity': session_data.get('last_activity'),
                'scanner_status': scanner_status,
                'current_target': current_target
            })
        return summaries
    except Exception as e:
        print(f"ERROR: Failed to list active sessions: {e}")
        return []
def get_statistics(self) -> Dict[str, Any]:
    """
    Count stored sessions, running scans and stored stop-signal keys.

    If anything raises while collecting the numbers, the error is printed
    and all three counters are reported as 0.

    Returns:
        Dict with 'total_active_sessions', 'running_scans' and
        'total_stop_signals'
    """
    try:
        session_keys = self.redis_client.keys("dnsrecon:session:*")
        stop_keys = self.redis_client.keys("dnsrecon:stop:*")

        # The session ID is the last ':'-separated part of each session key.
        session_ids = (
            raw_key.decode('utf-8').split(':')[-1] for raw_key in session_keys
        )
        running_scans = sum(
            1 for session_id in session_ids
            if self.get_session_status_only(session_id) == 'running'
        )

        return {
            'total_active_sessions': len(session_keys),
            'running_scans': running_scans,
            'total_stop_signals': len(stop_keys)
        }
    except Exception as e:
        print(f"ERROR: Failed to get statistics: {e}")
        return {
            'total_active_sessions': 0,
            'running_scans': 0,
            'total_stop_signals': 0
        }
# Global session manager instance # Global session manager instance
session_manager = SessionManager(session_timeout_minutes=60) session_manager = SessionManager(session_timeout_minutes=60)

View File

@ -163,11 +163,11 @@ class BaseProvider(ABC):
target_indicator: str = "", target_indicator: str = "",
max_retries: int = 3) -> Optional[requests.Response]: max_retries: int = 3) -> Optional[requests.Response]:
""" """
Make a rate-limited HTTP request with forensic logging and retry logic. Make a rate-limited HTTP request with aggressive stop signal handling.
Now supports cancellation via stop_event from scanner. Terminates immediately when stop is requested, including during retries.
""" """
# Check for cancellation before starting # Check for cancellation before starting
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): if self._is_stop_requested():
print(f"Request cancelled before start: {url}") print(f"Request cancelled before start: {url}")
return None return None
@ -188,21 +188,24 @@ class BaseProvider(ABC):
response.headers = cached_data['headers'] response.headers = cached_data['headers']
return response return response
for attempt in range(max_retries + 1): # Determine effective max_retries based on stop signal
# Check for cancellation before each attempt effective_max_retries = 0 if self._is_stop_requested() else max_retries
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): last_exception = None
for attempt in range(effective_max_retries + 1):
# AGGRESSIVE: Check for cancellation before each attempt
if self._is_stop_requested():
print(f"Request cancelled during attempt {attempt + 1}: {url}") print(f"Request cancelled during attempt {attempt + 1}: {url}")
return None return None
# Apply rate limiting (but reduce wait time if cancellation is requested) # Apply rate limiting with cancellation awareness
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): if not self._wait_with_cancellation_check():
break print(f"Request cancelled during rate limiting: {url}")
return None
self.rate_limiter.wait_if_needed() # AGGRESSIVE: Final check before making HTTP request
if self._is_stop_requested():
# Check again after rate limiting print(f"Request cancelled before HTTP call: {url}")
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled after rate limiting: {url}")
return None return None
start_time = time.time() start_time = time.time()
@ -219,10 +222,11 @@ class BaseProvider(ABC):
print(f"Making {method} request to: {url} (attempt {attempt + 1})") print(f"Making {method} request to: {url} (attempt {attempt + 1})")
# Use shorter timeout if termination is requested # AGGRESSIVE: Use much shorter timeout if termination is requested
request_timeout = self.timeout request_timeout = self.timeout
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): if self._is_stop_requested():
request_timeout = min(5, self.timeout) # Max 5 seconds if termination requested request_timeout = 2 # Max 2 seconds if termination requested
print(f"Stop requested - using short timeout: {request_timeout}s")
# Make request # Make request
if method.upper() == "GET": if method.upper() == "GET":
@ -271,28 +275,28 @@ class BaseProvider(ABC):
error = str(e) error = str(e)
self.failed_requests += 1 self.failed_requests += 1
print(f"Request failed (attempt {attempt + 1}): {error}") print(f"Request failed (attempt {attempt + 1}): {error}")
last_exception = e
# Check for cancellation before retrying # AGGRESSIVE: Immediately abort retries if stop requested
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): if self._is_stop_requested():
print(f"Request cancelled, not retrying: {url}") print(f"Stop requested - aborting retries for: {url}")
break break
# Check if we should retry # Check if we should retry (but only if stop not requested)
if attempt < max_retries and self._should_retry(e): if attempt < effective_max_retries and self._should_retry(e):
backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s # Use a longer, more respectful backoff for 429 errors
print(f"Retrying in {backoff_time} seconds...") if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429:
# Start with a 10-second backoff and increase exponentially
backoff_time = 10 * (2 ** attempt)
print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
else:
backoff_time = min(1.0, (2 ** attempt) * 0.5) # Shorter backoff for other errors
print(f"Retrying in {backoff_time} seconds...")
# Shorter backoff if termination is requested # AGGRESSIVE: Much shorter backoff and more frequent checking
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): if not self._sleep_with_cancellation_check(backoff_time):
backoff_time = min(0.5, backoff_time) print(f"Stop requested during backoff - aborting: {url}")
return None
# Sleep with cancellation checking
sleep_start = time.time()
while time.time() - sleep_start < backoff_time:
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled during backoff: {url}")
return None
time.sleep(0.1) # Check every 100ms
continue continue
else: else:
break break
@ -301,6 +305,7 @@ class BaseProvider(ABC):
error = f"Unexpected error: {str(e)}" error = f"Unexpected error: {str(e)}"
self.failed_requests += 1 self.failed_requests += 1
print(f"Unexpected error: {error}") print(f"Unexpected error: {error}")
last_exception = e
break break
# All attempts failed - log and return None # All attempts failed - log and return None
@ -316,8 +321,57 @@ class BaseProvider(ABC):
target_indicator=target_indicator target_indicator=target_indicator
) )
if error and last_exception:
raise last_exception
return None return None
def _is_stop_requested(self) -> bool:
"""
Enhanced stop signal checking that handles both local and Redis-based signals.
"""
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
return True
return False
def _wait_with_cancellation_check(self) -> bool:
"""
Wait for rate limiting while aggressively checking for cancellation.
Returns False if cancelled during wait.
"""
current_time = time.time()
time_since_last = current_time - self.rate_limiter.last_request_time
if time_since_last < self.rate_limiter.min_interval:
sleep_time = self.rate_limiter.min_interval - time_since_last
if not self._sleep_with_cancellation_check(sleep_time):
return False
self.rate_limiter.last_request_time = time.time()
return True
def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
"""
Sleep for the specified time while aggressively checking for cancellation.
Args:
sleep_time: Time to sleep in seconds
Returns:
bool: True if sleep completed, False if cancelled
"""
sleep_start = time.time()
check_interval = 0.05 # Check every 50ms for aggressive responsiveness
while time.time() - sleep_start < sleep_time:
if self._is_stop_requested():
return False
remaining_time = sleep_time - (time.time() - sleep_start)
time.sleep(min(check_interval, remaining_time))
return True
def set_stop_event(self, stop_event: threading.Event) -> None: def set_stop_event(self, stop_event: threading.Event) -> None:
""" """
Set the stop event for this provider to enable cancellation. Set the stop event for this provider to enable cancellation.
@ -337,15 +391,15 @@ class BaseProvider(ABC):
Returns: Returns:
True if the request should be retried True if the request should be retried
""" """
# Retry on connection errors, timeouts, and 5xx server errors # Retry on connection errors and timeouts
if isinstance(exception, (requests.exceptions.ConnectionError, if isinstance(exception, (requests.exceptions.ConnectionError,
requests.exceptions.Timeout)): requests.exceptions.Timeout)):
return True return True
if isinstance(exception, requests.exceptions.HTTPError): if isinstance(exception, requests.exceptions.HTTPError):
if hasattr(exception, 'response') and exception.response: if hasattr(exception, 'response') and exception.response:
# Retry on server errors (5xx) but not client errors (4xx) # Retry on server errors (5xx) AND on rate-limiting errors (429)
return exception.response.status_code >= 500 return exception.response.status_code >= 500 or exception.response.status_code == 429
return False return False

View File

@ -157,8 +157,7 @@ class CrtShProvider(BaseProvider):
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
""" """
Query crt.sh for certificates containing the domain. Query crt.sh for certificates containing the domain.
Creates domain-to-domain relationships and stores certificate data as metadata. Enhanced with more frequent stop signal checking for reliable termination.
Now supports early termination via stop_event.
""" """
if not _is_valid_domain(domain): if not _is_valid_domain(domain):
return [] return []
@ -197,10 +196,10 @@ class CrtShProvider(BaseProvider):
domain_certificates = {} domain_certificates = {}
all_discovered_domains = set() all_discovered_domains = set()
# Process certificates and group by domain (with cancellation checks) # Process certificates with enhanced cancellation checking
for i, cert_data in enumerate(certificates): for i, cert_data in enumerate(certificates):
# Check for cancellation every 10 certificates # Check for cancellation every 5 certificates instead of 10 for faster response
if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): if i % 5 == 0 and self._stop_event and self._stop_event.is_set():
print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}") print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}")
break break
@ -209,6 +208,11 @@ class CrtShProvider(BaseProvider):
# Add all domains from this certificate to our tracking # Add all domains from this certificate to our tracking
for cert_domain in cert_domains: for cert_domain in cert_domains:
# Additional stop check during domain processing
if i % 20 == 0 and self._stop_event and self._stop_event.is_set():
print(f"CrtSh domain processing cancelled for domain: {domain}")
break
if not _is_valid_domain(cert_domain): if not _is_valid_domain(cert_domain):
continue continue
@ -226,13 +230,13 @@ class CrtShProvider(BaseProvider):
print(f"CrtSh query cancelled before relationship creation for domain: {domain}") print(f"CrtSh query cancelled before relationship creation for domain: {domain}")
return [] return []
# Create relationships from query domain to ALL discovered domains # Create relationships from query domain to ALL discovered domains with stop checking
for discovered_domain in all_discovered_domains: for i, discovered_domain in enumerate(all_discovered_domains):
if discovered_domain == domain: if discovered_domain == domain:
continue # Skip self-relationships continue # Skip self-relationships
# Check for cancellation during relationship creation # Check for cancellation every 10 relationships
if self._stop_event and self._stop_event.is_set(): if i % 10 == 0 and self._stop_event and self._stop_event.is_set():
print(f"CrtSh relationship creation cancelled for domain: {domain}") print(f"CrtSh relationship creation cancelled for domain: {domain}")
break break
@ -284,8 +288,6 @@ class CrtShProvider(BaseProvider):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
self.logger.logger.error(f"Failed to parse JSON response from crt.sh: {e}") self.logger.logger.error(f"Failed to parse JSON response from crt.sh: {e}")
except Exception as e:
self.logger.logger.error(f"Error querying crt.sh for {domain}: {e}")
return relationships return relationships

View File

@ -134,8 +134,6 @@ class ShodanProvider(BaseProvider):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}") self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}")
except Exception as e:
self.logger.logger.error(f"Error querying Shodan for domain {domain}: {e}")
return relationships return relationships
@ -231,8 +229,6 @@ class ShodanProvider(BaseProvider):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}") self.logger.logger.error(f"Failed to parse JSON response from Shodan: {e}")
except Exception as e:
self.logger.logger.error(f"Error querying Shodan for IP {ip}: {e}")
return relationships return relationships

View File

@ -246,7 +246,7 @@ class DNSReconApp {
} }
/** /**
* Start a reconnaissance scan * Enhanced start scan with better error handling
*/ */
async startScan(clearGraph = true) { async startScan(clearGraph = true) {
console.log('=== STARTING SCAN ==='); console.log('=== STARTING SCAN ===');
@ -292,7 +292,6 @@ class DNSReconApp {
if (response.success) { if (response.success) {
this.currentSessionId = response.scan_id; this.currentSessionId = response.scan_id;
this.startPolling();
this.showSuccess('Reconnaissance scan started successfully'); this.showSuccess('Reconnaissance scan started successfully');
if (clearGraph) { if (clearGraph) {
@ -301,6 +300,9 @@ class DNSReconApp {
console.log(`Scan started for ${targetDomain} with depth ${maxDepth}`); console.log(`Scan started for ${targetDomain} with depth ${maxDepth}`);
// Start polling immediately with faster interval for responsiveness
this.startPolling(1000);
// Force an immediate status update // Force an immediate status update
console.log('Forcing immediate status update...'); console.log('Forcing immediate status update...');
setTimeout(() => { setTimeout(() => {
@ -318,18 +320,43 @@ class DNSReconApp {
this.setUIState('idle'); this.setUIState('idle');
} }
} }
/** /**
* Stop the current scan * Enhanced scan stop with immediate UI feedback
*/ */
async stopScan() { async stopScan() {
try { try {
console.log('Stopping scan...'); console.log('Stopping scan...');
// Immediately disable stop button and show stopping state
if (this.elements.stopScan) {
this.elements.stopScan.disabled = true;
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOPPING]</span><span>Stopping...</span>';
}
// Show immediate feedback
this.showInfo('Stopping scan...');
const response = await this.apiCall('/api/scan/stop', 'POST'); const response = await this.apiCall('/api/scan/stop', 'POST');
if (response.success) { if (response.success) {
this.showSuccess('Scan stop requested'); this.showSuccess('Scan stop requested');
console.log('Scan stop requested'); console.log('Scan stop requested successfully');
// Force immediate status update
setTimeout(() => {
this.updateStatus();
}, 100);
// Continue polling for a bit to catch the status change
this.startPolling(500); // Fast polling to catch status change
// Stop fast polling after 10 seconds
setTimeout(() => {
if (this.scanStatus === 'stopped' || this.scanStatus === 'idle') {
this.stopPolling();
}
}, 10000);
} else { } else {
throw new Error(response.error || 'Failed to stop scan'); throw new Error(response.error || 'Failed to stop scan');
} }
@ -337,6 +364,12 @@ class DNSReconApp {
} catch (error) { } catch (error) {
console.error('Failed to stop scan:', error); console.error('Failed to stop scan:', error);
this.showError(`Failed to stop scan: ${error.message}`); this.showError(`Failed to stop scan: ${error.message}`);
// Re-enable stop button on error
if (this.elements.stopScan) {
this.elements.stopScan.disabled = false;
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOP]</span><span>Terminate Scan</span>';
}
} }
} }
@ -365,9 +398,9 @@ class DNSReconApp {
} }
/** /**
* Start polling for scan updates * Start polling for scan updates with configurable interval
*/ */
startPolling() { startPolling(interval = 2000) {
console.log('=== STARTING POLLING ==='); console.log('=== STARTING POLLING ===');
if (this.pollInterval) { if (this.pollInterval) {
@ -380,9 +413,9 @@ class DNSReconApp {
this.updateStatus(); this.updateStatus();
this.updateGraph(); this.updateGraph();
this.loadProviders(); this.loadProviders();
}, 1000); // Poll every 1 second for debugging }, interval);
console.log('Polling started with 1 second interval'); console.log(`Polling started with ${interval}ms interval`);
} }
/** /**
@ -397,7 +430,7 @@ class DNSReconApp {
} }
/** /**
* Update scan status from server * Enhanced status update with better error handling
*/ */
async updateStatus() { async updateStatus() {
try { try {
@ -406,7 +439,7 @@ class DNSReconApp {
console.log('Status response:', response); console.log('Status response:', response);
if (response.success) { if (response.success && response.status) {
const status = response.status; const status = response.status;
console.log('Current scan status:', status.status); console.log('Current scan status:', status.status);
console.log('Current progress:', status.progress_percentage + '%'); console.log('Current progress:', status.progress_percentage + '%');
@ -423,6 +456,7 @@ class DNSReconApp {
this.scanStatus = status.status; this.scanStatus = status.status;
} else { } else {
console.error('Status update failed:', response); console.error('Status update failed:', response);
// Don't show error for status updates to avoid spam
} }
} catch (error) { } catch (error) {
@ -551,7 +585,7 @@ class DNSReconApp {
} }
/** /**
* Handle status changes * Handle status changes with improved state synchronization
* @param {string} newStatus - New scan status * @param {string} newStatus - New scan status
*/ */
handleStatusChange(newStatus) { handleStatusChange(newStatus) {
@ -561,8 +595,8 @@ class DNSReconApp {
case 'running': case 'running':
this.setUIState('scanning'); this.setUIState('scanning');
this.showSuccess('Scan is running'); this.showSuccess('Scan is running');
// Reset polling frequency for active scans // Increase polling frequency for active scans
this.pollFrequency = 2000; this.startPolling(1000); // Poll every 1 second for running scans
this.updateConnectionStatus('active'); this.updateConnectionStatus('active');
break; break;
@ -598,6 +632,10 @@ class DNSReconApp {
this.stopPolling(); this.stopPolling();
this.updateConnectionStatus('idle'); this.updateConnectionStatus('idle');
break; break;
default:
console.warn(`Unknown status: ${newStatus}`);
break;
} }
} }
@ -633,8 +671,7 @@ class DNSReconApp {
} }
/** /**
* Set UI state based on scan status * Enhanced UI state management with immediate button updates
* @param {string} state - UI state
*/ */
setUIState(state) { setUIState(state) {
console.log(`Setting UI state to: ${state}`); console.log(`Setting UI state to: ${state}`);
@ -645,6 +682,7 @@ class DNSReconApp {
if (this.elements.startScan) { if (this.elements.startScan) {
this.elements.startScan.disabled = true; this.elements.startScan.disabled = true;
this.elements.startScan.classList.add('loading'); this.elements.startScan.classList.add('loading');
this.elements.startScan.innerHTML = '<span class="btn-icon">[SCANNING]</span><span>Scanning...</span>';
} }
if (this.elements.addToGraph) { if (this.elements.addToGraph) {
this.elements.addToGraph.disabled = true; this.elements.addToGraph.disabled = true;
@ -653,6 +691,7 @@ class DNSReconApp {
if (this.elements.stopScan) { if (this.elements.stopScan) {
this.elements.stopScan.disabled = false; this.elements.stopScan.disabled = false;
this.elements.stopScan.classList.remove('loading'); this.elements.stopScan.classList.remove('loading');
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOP]</span><span>Terminate Scan</span>';
} }
if (this.elements.targetDomain) this.elements.targetDomain.disabled = true; if (this.elements.targetDomain) this.elements.targetDomain.disabled = true;
if (this.elements.maxDepth) this.elements.maxDepth.disabled = true; if (this.elements.maxDepth) this.elements.maxDepth.disabled = true;
@ -667,6 +706,7 @@ class DNSReconApp {
if (this.elements.startScan) { if (this.elements.startScan) {
this.elements.startScan.disabled = false; this.elements.startScan.disabled = false;
this.elements.startScan.classList.remove('loading'); this.elements.startScan.classList.remove('loading');
this.elements.startScan.innerHTML = '<span class="btn-icon">[RUN]</span><span>Start Reconnaissance</span>';
} }
if (this.elements.addToGraph) { if (this.elements.addToGraph) {
this.elements.addToGraph.disabled = false; this.elements.addToGraph.disabled = false;
@ -674,6 +714,7 @@ class DNSReconApp {
} }
if (this.elements.stopScan) { if (this.elements.stopScan) {
this.elements.stopScan.disabled = true; this.elements.stopScan.disabled = true;
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOP]</span><span>Terminate Scan</span>';
} }
if (this.elements.targetDomain) this.elements.targetDomain.disabled = false; if (this.elements.targetDomain) this.elements.targetDomain.disabled = false;
if (this.elements.maxDepth) this.elements.maxDepth.disabled = false; if (this.elements.maxDepth) this.elements.maxDepth.disabled = false;