iteration on ws implementation

This commit is contained in:
overcuriousity 2025-09-20 16:52:05 +02:00
parent 75a595c9cb
commit c4e6a8998a
9 changed files with 1224 additions and 290 deletions

139
app.py
View File

@ -3,6 +3,7 @@
""" """
Flask application entry point for DNSRecon web interface. Flask application entry point for DNSRecon web interface.
Provides REST API endpoints and serves the web interface with user session support. Provides REST API endpoints and serves the web interface with user session support.
FIXED: Enhanced WebSocket integration with proper connection management.
""" """
import traceback import traceback
@ -21,30 +22,38 @@ from decimal import Decimal
app = Flask(__name__) app = Flask(__name__)
socketio = SocketIO(app) socketio = SocketIO(app, cors_allowed_origins="*")
app.config['SECRET_KEY'] = config.flask_secret_key app.config['SECRET_KEY'] = config.flask_secret_key
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours) app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours)
def get_user_scanner(): def get_user_scanner():
""" """
Retrieves the scanner for the current session, or creates a new one if none exists. FIXED: Retrieves the scanner for the current session with proper socketio management.
""" """
current_flask_session_id = session.get('dnsrecon_session_id') current_flask_session_id = session.get('dnsrecon_session_id')
if current_flask_session_id: if current_flask_session_id:
existing_scanner = session_manager.get_session(current_flask_session_id) existing_scanner = session_manager.get_session(current_flask_session_id)
if existing_scanner: if existing_scanner:
# FIXED: Ensure socketio is properly maintained
existing_scanner.socketio = socketio
print(f"✓ Retrieved existing scanner for session {current_flask_session_id[:8]}... with socketio restored")
return current_flask_session_id, existing_scanner return current_flask_session_id, existing_scanner
# FIXED: Register socketio connection when creating new session
new_session_id = session_manager.create_session(socketio) new_session_id = session_manager.create_session(socketio)
new_scanner = session_manager.get_session(new_session_id) new_scanner = session_manager.get_session(new_session_id)
if not new_scanner: if not new_scanner:
raise Exception("Failed to create new scanner session") raise Exception("Failed to create new scanner session")
# FIXED: Ensure new scanner has socketio reference and register the connection
new_scanner.socketio = socketio
session_manager.register_socketio_connection(new_session_id, socketio)
session['dnsrecon_session_id'] = new_session_id session['dnsrecon_session_id'] = new_session_id
session.permanent = True session.permanent = True
print(f"✓ Created new scanner for session {new_session_id[:8]}... with socketio registered")
return new_session_id, new_scanner return new_session_id, new_scanner
@ -57,7 +66,7 @@ def index():
@app.route('/api/scan/start', methods=['POST']) @app.route('/api/scan/start', methods=['POST'])
def start_scan(): def start_scan():
""" """
Starts a new reconnaissance scan. FIXED: Starts a new reconnaissance scan with proper socketio management.
""" """
try: try:
data = request.get_json() data = request.get_json()
@ -81,9 +90,17 @@ def start_scan():
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500 return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500
# FIXED: Ensure scanner has socketio reference and is registered
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
print(f"🚀 Starting scan for {target} with socketio enabled and registered")
success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target) success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target)
if success: if success:
# Update session with socketio-enabled scanner
session_manager.update_session_scanner(user_session_id, scanner)
return jsonify({ return jsonify({
'success': True, 'success': True,
'message': 'Reconnaissance scan started successfully', 'message': 'Reconnaissance scan started successfully',
@ -112,6 +129,10 @@ def stop_scan():
if not scanner.session_id: if not scanner.session_id:
scanner.session_id = user_session_id scanner.session_id = user_session_id
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
scanner.stop_scan() scanner.stop_scan()
session_manager.set_stop_signal(user_session_id) session_manager.set_stop_signal(user_session_id)
session_manager.update_scanner_status(user_session_id, 'stopped') session_manager.update_scanner_status(user_session_id, 'stopped')
@ -128,31 +149,83 @@ def stop_scan():
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@socketio.on('connect')
def handle_connect():
"""
FIXED: Handle WebSocket connection with proper session management.
"""
print(f'✓ WebSocket client connected: {request.sid}')
# Try to restore existing session connection
current_flask_session_id = session.get('dnsrecon_session_id')
if current_flask_session_id:
# Register this socketio connection for the existing session
session_manager.register_socketio_connection(current_flask_session_id, socketio)
print(f'✓ Registered WebSocket for existing session: {current_flask_session_id[:8]}...')
# Immediately send current status to new connection
get_scan_status()
@socketio.on('disconnect')
def handle_disconnect():
"""
FIXED: Handle WebSocket disconnection gracefully.
"""
print(f'✗ WebSocket client disconnected: {request.sid}')
# Note: We don't immediately remove the socketio connection from session_manager
# because the user might reconnect. The cleanup will happen during session cleanup.
@socketio.on('get_status') @socketio.on('get_status')
def get_scan_status(): def get_scan_status():
"""Get current scan status.""" """
FIXED: Get current scan status and emit real-time update with proper error handling.
"""
try: try:
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
if not scanner: if not scanner:
status = { status = {
'status': 'idle', 'target_domain': None, 'current_depth': 0, 'status': 'idle',
'max_depth': 0, 'progress_percentage': 0.0, 'target_domain': None,
'user_session_id': user_session_id 'current_depth': 0,
'max_depth': 0,
'progress_percentage': 0.0,
'user_session_id': user_session_id,
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
} }
print(f"📡 Emitting idle status for session {user_session_id[:8] if user_session_id else 'none'}...")
else: else:
if not scanner.session_id: if not scanner.session_id:
scanner.session_id = user_session_id scanner.session_id = user_session_id
# FIXED: Ensure scanner has socketio reference for future updates
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
status = scanner.get_scan_status() status = scanner.get_scan_status()
status['user_session_id'] = user_session_id status['user_session_id'] = user_session_id
print(f"📡 Emitting status update: {status['status']} - "
f"Nodes: {len(status.get('graph', {}).get('nodes', []))}, "
f"Edges: {len(status.get('graph', {}).get('edges', []))}")
# Update session with socketio-enabled scanner
session_manager.update_session_scanner(user_session_id, scanner)
socketio.emit('scan_update', status) socketio.emit('scan_update', status)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
socketio.emit('scan_update', { error_status = {
'status': 'error', 'message': 'Failed to get status' 'status': 'error',
}) 'message': 'Failed to get status',
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
}
print(f"⚠️ Error getting status, emitting error status")
socketio.emit('scan_update', error_status)
@app.route('/api/graph', methods=['GET']) @app.route('/api/graph', methods=['GET'])
@ -169,6 +242,10 @@ def get_graph_data():
if not scanner: if not scanner:
return jsonify({'success': True, 'graph': empty_graph, 'user_session_id': user_session_id}) return jsonify({'success': True, 'graph': empty_graph, 'user_session_id': user_session_id})
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
graph_data = scanner.get_graph_data() or empty_graph graph_data = scanner.get_graph_data() or empty_graph
return jsonify({'success': True, 'graph': graph_data, 'user_session_id': user_session_id}) return jsonify({'success': True, 'graph': graph_data, 'user_session_id': user_session_id})
@ -195,6 +272,10 @@ def extract_from_large_entity():
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'No active session found'}), 404 return jsonify({'success': False, 'error': 'No active session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
success = scanner.extract_node_from_large_entity(large_entity_id, node_id) success = scanner.extract_node_from_large_entity(large_entity_id, node_id)
if success: if success:
@ -215,6 +296,10 @@ def delete_graph_node(node_id):
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'No active session found'}), 404 return jsonify({'success': False, 'error': 'No active session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
success = scanner.graph.remove_node(node_id) success = scanner.graph.remove_node(node_id)
if success: if success:
@ -240,6 +325,10 @@ def revert_graph_action():
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'No active session found'}), 404 return jsonify({'success': False, 'error': 'No active session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
action_type = data['type'] action_type = data['type']
action_data = data['data'] action_data = data['data']
@ -284,6 +373,10 @@ def export_results():
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'No active scanner session found'}), 404 return jsonify({'success': False, 'error': 'No active scanner session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Get export data using the new export manager # Get export data using the new export manager
try: try:
results = export_manager.export_scan_results(scanner) results = export_manager.export_scan_results(scanner)
@ -335,6 +428,10 @@ def export_targets():
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'No active scanner session found'}), 404 return jsonify({'success': False, 'error': 'No active scanner session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Use export manager for targets export # Use export manager for targets export
targets_txt = export_manager.export_targets_list(scanner) targets_txt = export_manager.export_targets_list(scanner)
@ -365,6 +462,10 @@ def export_summary():
if not scanner: if not scanner:
return jsonify({'success': False, 'error': 'No active scanner session found'}), 404 return jsonify({'success': False, 'error': 'No active scanner session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Use export manager for summary generation # Use export manager for summary generation
summary_txt = export_manager.generate_executive_summary(scanner) summary_txt = export_manager.generate_executive_summary(scanner)
@ -397,6 +498,10 @@ def set_api_keys():
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
session_config = scanner.config session_config = scanner.config
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
updated_providers = [] updated_providers = []
for provider_name, api_key in data.items(): for provider_name, api_key in data.items():
@ -429,6 +534,10 @@ def get_providers():
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
base_provider_info = scanner.get_provider_info() base_provider_info = scanner.get_provider_info()
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Enhance provider info with API key source information # Enhance provider info with API key source information
enhanced_provider_info = {} enhanced_provider_info = {}
@ -493,6 +602,10 @@ def configure_providers():
user_session_id, scanner = get_user_scanner() user_session_id, scanner = get_user_scanner()
session_config = scanner.config session_config = scanner.config
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
updated_providers = [] updated_providers = []
for provider_name, settings in data.items(): for provider_name, settings in data.items():
@ -521,7 +634,6 @@ def configure_providers():
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.errorhandler(404) @app.errorhandler(404)
def not_found(error): def not_found(error):
"""Handle 404 errors.""" """Handle 404 errors."""
@ -537,4 +649,9 @@ def internal_error(error):
if __name__ == '__main__': if __name__ == '__main__':
config.load_from_env() config.load_from_env()
print("🚀 Starting DNSRecon with enhanced WebSocket support...")
print(f" Host: {config.flask_host}")
print(f" Port: {config.flask_port}")
print(f" Debug: {config.flask_debug}")
print(" WebSocket: Enhanced connection management enabled")
socketio.run(app, host=config.flask_host, port=config.flask_port, debug=config.flask_debug) socketio.run(app, host=config.flask_host, port=config.flask_port, debug=config.flask_debug)

View File

@ -4,8 +4,7 @@
Graph data model for DNSRecon using NetworkX. Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata. Manages in-memory graph storage with confidence scoring and forensic metadata.
Now fully compatible with the unified ProviderResult data model. Now fully compatible with the unified ProviderResult data model.
UPDATED: Fixed correlation exclusion keys to match actual attribute names. FIXED: Added proper pickle support to prevent weakref serialization errors.
UPDATED: Removed export_json() method - now handled by ExportManager.
""" """
import re import re
from datetime import datetime, timezone from datetime import datetime, timezone
@ -33,6 +32,7 @@ class GraphManager:
Thread-safe graph manager for DNSRecon infrastructure mapping. Thread-safe graph manager for DNSRecon infrastructure mapping.
Uses NetworkX for in-memory graph storage with confidence scoring. Uses NetworkX for in-memory graph storage with confidence scoring.
Compatible with unified ProviderResult data model. Compatible with unified ProviderResult data model.
FIXED: Added proper pickle support to handle NetworkX graph serialization.
""" """
def __init__(self): def __init__(self):
@ -41,6 +41,57 @@ class GraphManager:
self.creation_time = datetime.now(timezone.utc).isoformat() self.creation_time = datetime.now(timezone.utc).isoformat()
self.last_modified = self.creation_time self.last_modified = self.creation_time
def __getstate__(self):
"""Prepare GraphManager for pickling by converting NetworkX graph to serializable format."""
state = self.__dict__.copy()
# Convert NetworkX graph to a serializable format
if hasattr(self, 'graph') and self.graph:
# Extract all nodes with their data
nodes_data = {}
for node_id, attrs in self.graph.nodes(data=True):
nodes_data[node_id] = dict(attrs)
# Extract all edges with their data
edges_data = []
for source, target, attrs in self.graph.edges(data=True):
edges_data.append({
'source': source,
'target': target,
'attributes': dict(attrs)
})
# Replace the NetworkX graph with serializable data
state['_graph_nodes'] = nodes_data
state['_graph_edges'] = edges_data
del state['graph']
return state
def __setstate__(self, state):
"""Restore GraphManager after unpickling by reconstructing NetworkX graph."""
# Restore basic attributes
self.__dict__.update(state)
# Reconstruct NetworkX graph from serializable data
self.graph = nx.DiGraph()
# Restore nodes
if hasattr(self, '_graph_nodes'):
for node_id, attrs in self._graph_nodes.items():
self.graph.add_node(node_id, **attrs)
del self._graph_nodes
# Restore edges
if hasattr(self, '_graph_edges'):
for edge_data in self._graph_edges:
self.graph.add_edge(
edge_data['source'],
edge_data['target'],
**edge_data['attributes']
)
del self._graph_edges
def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None, def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None,
description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool: description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:
""" """

View File

@ -40,6 +40,7 @@ class ForensicLogger:
""" """
Thread-safe forensic logging system for DNSRecon. Thread-safe forensic logging system for DNSRecon.
Maintains detailed audit trail of all reconnaissance activities. Maintains detailed audit trail of all reconnaissance activities.
FIXED: Enhanced pickle support to prevent weakref issues in logging handlers.
""" """
def __init__(self, session_id: str = ""): def __init__(self, session_id: str = ""):
@ -65,45 +66,74 @@ class ForensicLogger:
'target_domains': set() 'target_domains': set()
} }
# Configure standard logger # Configure standard logger with simple setup to avoid weakrefs
self.logger = logging.getLogger(f'dnsrecon.{self.session_id}') self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
self.logger.setLevel(logging.INFO) self.logger.setLevel(logging.INFO)
# Create formatter for structured logging # Create minimal formatter
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s' '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
) )
# Add console handler if not already present # Add console handler only if not already present (avoid duplicate handlers)
if not self.logger.handlers: if not self.logger.handlers:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler) self.logger.addHandler(console_handler)
def __getstate__(self): def __getstate__(self):
"""Prepare ForensicLogger for pickling by excluding unpicklable objects.""" """
FIXED: Prepare ForensicLogger for pickling by excluding problematic objects.
"""
state = self.__dict__.copy() state = self.__dict__.copy()
# Remove the unpickleable 'logger' attribute
if 'logger' in state: # Remove potentially unpickleable attributes that may contain weakrefs
del state['logger'] unpicklable_attrs = ['logger', 'lock']
if 'lock' in state: for attr in unpicklable_attrs:
del state['lock'] if attr in state:
del state[attr]
# Convert sets to lists for JSON serialization compatibility
if 'session_metadata' in state:
metadata = state['session_metadata'].copy()
if 'providers_used' in metadata and isinstance(metadata['providers_used'], set):
metadata['providers_used'] = list(metadata['providers_used'])
if 'target_domains' in metadata and isinstance(metadata['target_domains'], set):
metadata['target_domains'] = list(metadata['target_domains'])
state['session_metadata'] = metadata
return state return state
def __setstate__(self, state): def __setstate__(self, state):
"""Restore ForensicLogger after unpickling by reconstructing logger.""" """
FIXED: Restore ForensicLogger after unpickling by reconstructing components.
"""
self.__dict__.update(state) self.__dict__.update(state)
# Re-initialize the 'logger' attribute
# Re-initialize threading lock
self.lock = threading.Lock()
# Re-initialize logger with minimal setup
self.logger = logging.getLogger(f'dnsrecon.{self.session_id}') self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
self.logger.setLevel(logging.INFO) self.logger.setLevel(logging.INFO)
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s' '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
) )
# Only add handler if not already present
if not self.logger.handlers: if not self.logger.handlers:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler) self.logger.addHandler(console_handler)
self.lock = threading.Lock()
# Convert lists back to sets if needed
if 'session_metadata' in self.__dict__:
metadata = self.session_metadata
if 'providers_used' in metadata and isinstance(metadata['providers_used'], list):
metadata['providers_used'] = set(metadata['providers_used'])
if 'target_domains' in metadata and isinstance(metadata['target_domains'], list):
metadata['target_domains'] = set(metadata['target_domains'])
def _generate_session_id(self) -> str: def _generate_session_id(self) -> str:
"""Generate unique session identifier.""" """Generate unique session identifier."""
@ -143,18 +173,23 @@ class ForensicLogger:
discovery_context=discovery_context discovery_context=discovery_context
) )
self.api_requests.append(api_request) with self.lock:
self.session_metadata['total_requests'] += 1 self.api_requests.append(api_request)
self.session_metadata['providers_used'].add(provider) self.session_metadata['total_requests'] += 1
self.session_metadata['providers_used'].add(provider)
if target_indicator: if target_indicator:
self.session_metadata['target_domains'].add(target_indicator) self.session_metadata['target_domains'].add(target_indicator)
# Log to standard logger # Log to standard logger with error handling
if error: try:
self.logger.error(f"API Request Failed.") if error:
else: self.logger.error(f"API Request Failed - {provider}: {url}")
self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}") else:
self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
except Exception:
# If logging fails, continue without breaking the application
pass
def log_relationship_discovery(self, source_node: str, target_node: str, def log_relationship_discovery(self, source_node: str, target_node: str,
relationship_type: str, confidence_score: float, relationship_type: str, confidence_score: float,
@ -183,29 +218,44 @@ class ForensicLogger:
discovery_method=discovery_method discovery_method=discovery_method
) )
self.relationships.append(relationship) with self.lock:
self.session_metadata['total_relationships'] += 1 self.relationships.append(relationship)
self.session_metadata['total_relationships'] += 1
self.logger.info( # Log to standard logger with error handling
f"Relationship Discovered - {source_node} -> {target_node} " try:
f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}" self.logger.info(
) f"Relationship Discovered - {source_node} -> {target_node} "
f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
)
except Exception:
# If logging fails, continue without breaking the application
pass
def log_scan_start(self, target_domain: str, recursion_depth: int, def log_scan_start(self, target_domain: str, recursion_depth: int,
enabled_providers: List[str]) -> None: enabled_providers: List[str]) -> None:
"""Log the start of a reconnaissance scan.""" """Log the start of a reconnaissance scan."""
self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}") try:
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}") self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
self.session_metadata['target_domains'].update(target_domain) with self.lock:
self.session_metadata['target_domains'].add(target_domain)
except Exception:
pass
def log_scan_complete(self) -> None: def log_scan_complete(self) -> None:
"""Log the completion of a reconnaissance scan.""" """Log the completion of a reconnaissance scan."""
self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat() with self.lock:
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used']) self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains']) # Convert sets to lists for serialization
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
self.logger.info(f"Scan Complete - Session: {self.session_id}") try:
self.logger.info(f"Scan Complete - Session: {self.session_id}")
except Exception:
pass
def export_audit_trail(self) -> Dict[str, Any]: def export_audit_trail(self) -> Dict[str, Any]:
""" """
@ -214,12 +264,13 @@ class ForensicLogger:
Returns: Returns:
Dictionary containing complete session audit trail Dictionary containing complete session audit trail
""" """
return { with self.lock:
'session_metadata': self.session_metadata.copy(), return {
'api_requests': [asdict(req) for req in self.api_requests], 'session_metadata': self.session_metadata.copy(),
'relationships': [asdict(rel) for rel in self.relationships], 'api_requests': [asdict(req) for req in self.api_requests],
'export_timestamp': datetime.now(timezone.utc).isoformat() 'relationships': [asdict(rel) for rel in self.relationships],
} 'export_timestamp': datetime.now(timezone.utc).isoformat()
}
def get_forensic_summary(self) -> Dict[str, Any]: def get_forensic_summary(self) -> Dict[str, Any]:
""" """
@ -229,7 +280,13 @@ class ForensicLogger:
Dictionary containing summary statistics Dictionary containing summary statistics
""" """
provider_stats = {} provider_stats = {}
for provider in self.session_metadata['providers_used']:
# Ensure providers_used is a set for iteration
providers_used = self.session_metadata['providers_used']
if isinstance(providers_used, list):
providers_used = set(providers_used)
for provider in providers_used:
provider_requests = [req for req in self.api_requests if req.provider == provider] provider_requests = [req for req in self.api_requests if req.provider == provider]
provider_relationships = [rel for rel in self.relationships if rel.provider == provider] provider_relationships = [rel for rel in self.relationships if rel.provider == provider]

View File

@ -35,6 +35,7 @@ class Scanner:
""" """
Main scanning orchestrator for DNSRecon passive reconnaissance. Main scanning orchestrator for DNSRecon passive reconnaissance.
UNIFIED: Combines comprehensive features with improved display formatting. UNIFIED: Combines comprehensive features with improved display formatting.
FIXED: Enhanced threading object initialization to prevent None references.
""" """
def __init__(self, session_config=None, socketio=None): def __init__(self, session_config=None, socketio=None):
@ -44,6 +45,11 @@ class Scanner:
if session_config is None: if session_config is None:
from core.session_config import create_session_config from core.session_config import create_session_config
session_config = create_session_config() session_config = create_session_config()
# FIXED: Initialize all threading objects first
self._initialize_threading_objects()
# Set socketio (but will be set to None for storage)
self.socketio = socketio self.socketio = socketio
self.config = session_config self.config = session_config
@ -53,17 +59,12 @@ class Scanner:
self.current_target = None self.current_target = None
self.current_depth = 0 self.current_depth = 0
self.max_depth = 2 self.max_depth = 2
self.stop_event = threading.Event()
self.scan_thread = None self.scan_thread = None
self.session_id: Optional[str] = None # Will be set by session manager self.session_id: Optional[str] = None # Will be set by session manager
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
self.initial_targets = set() self.initial_targets = set()
# Thread-safe processing tracking (from Document 1) # Thread-safe processing tracking (from Document 1)
self.currently_processing = set() self.currently_processing = set()
self.processing_lock = threading.Lock()
# Display-friendly processing list (from Document 2) # Display-friendly processing list (from Document 2)
self.currently_processing_display = [] self.currently_processing_display = []
@ -81,9 +82,10 @@ class Scanner:
self.max_workers = self.config.max_concurrent_requests self.max_workers = self.config.max_concurrent_requests
self.executor = None self.executor = None
# Status logger thread with improved formatting # Initialize collections that will be recreated during unpickling
self.status_logger_thread = None self.task_queue = PriorityQueue()
self.status_logger_stop_event = threading.Event() self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
# Initialize providers with session config # Initialize providers with session config
self._initialize_providers() self._initialize_providers()
@ -99,12 +101,24 @@ class Scanner:
traceback.print_exc() traceback.print_exc()
raise raise
def _initialize_threading_objects(self):
"""
FIXED: Initialize all threading objects with proper error handling.
This method can be called during both __init__ and __setstate__.
"""
self.stop_event = threading.Event()
self.processing_lock = threading.Lock()
self.status_logger_stop_event = threading.Event()
self.status_logger_thread = None
def _is_stop_requested(self) -> bool: def _is_stop_requested(self) -> bool:
""" """
Check if stop is requested using both local and Redis-based signals. Check if stop is requested using both local and Redis-based signals.
This ensures reliable termination across process boundaries. This ensures reliable termination across process boundaries.
FIXED: Added None check for stop_event.
""" """
if self.stop_event.is_set(): # FIXED: Ensure stop_event exists before checking
if hasattr(self, 'stop_event') and self.stop_event and self.stop_event.is_set():
return True return True
if self.session_id: if self.session_id:
@ -112,16 +126,24 @@ class Scanner:
from core.session_manager import session_manager from core.session_manager import session_manager
return session_manager.is_stop_requested(self.session_id) return session_manager.is_stop_requested(self.session_id)
except Exception as e: except Exception as e:
# Fall back to local event # Fall back to local event if it exists
return self.stop_event.is_set() if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
return self.stop_event.is_set() # Final fallback
if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
def _set_stop_signal(self) -> None: def _set_stop_signal(self) -> None:
""" """
Set stop signal both locally and in Redis. Set stop signal both locally and in Redis.
FIXED: Added None check for stop_event.
""" """
self.stop_event.set() # FIXED: Ensure stop_event exists before setting
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.set()
if self.session_id: if self.session_id:
try: try:
@ -162,17 +184,21 @@ class Scanner:
"""Restore object after unpickling by reconstructing threading objects.""" """Restore object after unpickling by reconstructing threading objects."""
self.__dict__.update(state) self.__dict__.update(state)
self.stop_event = threading.Event() # FIXED: Ensure all threading objects are properly initialized
self._initialize_threading_objects()
# Re-initialize other objects
self.scan_thread = None self.scan_thread = None
self.executor = None self.executor = None
self.processing_lock = threading.Lock()
self.task_queue = PriorityQueue() self.task_queue = PriorityQueue()
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0)) self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
self.logger = get_forensic_logger() self.logger = get_forensic_logger()
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
self.socketio = None
# FIXED: Initialize socketio as None but preserve ability to set it
if not hasattr(self, 'socketio'):
self.socketio = None
# Initialize missing attributes with defaults
if not hasattr(self, 'providers') or not self.providers: if not hasattr(self, 'providers') or not self.providers:
self._initialize_providers() self._initialize_providers()
@ -182,11 +208,36 @@ class Scanner:
if not hasattr(self, 'currently_processing_display'): if not hasattr(self, 'currently_processing_display'):
self.currently_processing_display = [] self.currently_processing_display = []
if not hasattr(self, 'target_retries'):
self.target_retries = defaultdict(int)
if not hasattr(self, 'scan_failed_due_to_retries'):
self.scan_failed_due_to_retries = False
if not hasattr(self, 'initial_targets'):
self.initial_targets = set()
# Ensure providers have stop events
if hasattr(self, 'providers'): if hasattr(self, 'providers'):
for provider in self.providers: for provider in self.providers:
if hasattr(provider, 'set_stop_event'): if hasattr(provider, 'set_stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event) provider.set_stop_event(self.stop_event)
def _ensure_threading_objects_exist(self):
"""
FIXED: Utility method to ensure threading objects exist before use.
Call this before any method that might use threading objects.
"""
if not hasattr(self, 'stop_event') or self.stop_event is None:
print("WARNING: Threading objects not initialized, recreating...")
self._initialize_threading_objects()
if not hasattr(self, 'processing_lock') or self.processing_lock is None:
self.processing_lock = threading.Lock()
if not hasattr(self, 'task_queue') or self.task_queue is None:
self.task_queue = PriorityQueue()
def _initialize_providers(self) -> None: def _initialize_providers(self) -> None:
"""Initialize all available providers based on session configuration.""" """Initialize all available providers based on session configuration."""
self.providers = [] self.providers = []
@ -224,7 +275,9 @@ class Scanner:
print(f" Available: {is_available}") print(f" Available: {is_available}")
if is_available: if is_available:
provider.set_stop_event(self.stop_event) # FIXED: Ensure stop_event exists before setting it
if hasattr(self, 'stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event)
if isinstance(provider, CorrelationProvider): if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph) provider.set_graph_manager(self.graph)
self.providers.append(provider) self.providers.append(provider)
@ -254,15 +307,25 @@ class Scanner:
BOLD = "\033[1m" BOLD = "\033[1m"
last_status_str = "" last_status_str = ""
while not self.status_logger_stop_event.is_set():
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
while not (hasattr(self, 'status_logger_stop_event') and
self.status_logger_stop_event and
self.status_logger_stop_event.is_set()):
try: try:
with self.processing_lock: # FIXED: Check if processing_lock exists before using
in_flight_tasks = list(self.currently_processing) if hasattr(self, 'processing_lock') and self.processing_lock:
self.currently_processing_display = in_flight_tasks.copy() with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
self.currently_processing_display = in_flight_tasks.copy()
else:
in_flight_tasks = list(getattr(self, 'currently_processing', []))
status_str = ( status_str = (
f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | " f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | " f"{CYAN}Queued: {self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0}{ENDC} | "
f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | " f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
f"{GREEN}Completed: {self.indicators_completed}{ENDC} | " f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
f"Skipped: {self.tasks_skipped} | " f"Skipped: {self.tasks_skipped} | "
@ -290,22 +353,30 @@ class Scanner:
time.sleep(2) time.sleep(2)
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool: def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
"""
FIXED: Enhanced start_scan with proper threading object initialization and socketio management.
"""
# FIXED: Ensure threading objects exist before proceeding
self._ensure_threading_objects_exist()
if self.scan_thread and self.scan_thread.is_alive(): if self.scan_thread and self.scan_thread.is_alive():
self.logger.logger.info("Stopping existing scan before starting new one") self.logger.logger.info("Stopping existing scan before starting new one")
self._set_stop_signal() self._set_stop_signal()
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED
# Clean up processing state # Clean up processing state
with self.processing_lock: if hasattr(self, 'processing_lock') and self.processing_lock:
self.currently_processing.clear() with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = [] self.currently_processing_display = []
# Clear task queue # Clear task queue
while not self.task_queue.empty(): if hasattr(self, 'task_queue') and self.task_queue:
try: while not self.task_queue.empty():
self.task_queue.get_nowait() try:
except: self.task_queue.get_nowait()
break except:
break
# Shutdown executor # Shutdown executor
if self.executor: if self.executor:
@ -322,14 +393,26 @@ class Scanner:
self.logger.logger.warning("Previous scan thread did not terminate cleanly") self.logger.logger.warning("Previous scan thread did not terminate cleanly")
self.status = ScanStatus.IDLE self.status = ScanStatus.IDLE
self.stop_event.clear()
# FIXED: Ensure stop_event exists before clearing
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.clear()
if self.session_id: if self.session_id:
from core.session_manager import session_manager from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id) session_manager.clear_stop_signal(self.session_id)
with self.processing_lock: # FIXED: Restore socketio connection if missing
self.currently_processing.clear() if not hasattr(self, 'socketio') or not self.socketio:
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"✓ Restored socketio connection for scan start")
# FIXED: Safe cleanup with existence checks
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = [] self.currently_processing_display = []
self.task_queue = PriorityQueue() self.task_queue = PriorityQueue()
@ -397,7 +480,10 @@ class Scanner:
) )
self.scan_thread.start() self.scan_thread.start()
self.status_logger_stop_event.clear() # FIXED: Ensure status_logger_stop_event exists before clearing
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.clear()
self.status_logger_thread = threading.Thread( self.status_logger_thread = threading.Thread(
target=self._status_logger_thread, target=self._status_logger_thread,
daemon=True, daemon=True,
@ -451,6 +537,13 @@ class Scanner:
return 10 # Very low rate limit = very low priority return 10 # Very low rate limit = very low priority
def _execute_scan(self, target: str, max_depth: int) -> None: def _execute_scan(self, target: str, max_depth: int) -> None:
"""
FIXED: Enhanced execute_scan with proper threading object handling.
"""
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
update_counter = 0 # Track updates for throttling
last_update_time = time.time()
self.executor = ThreadPoolExecutor(max_workers=self.max_workers) self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping
@ -482,8 +575,13 @@ class Scanner:
print(f"\n=== PHASE 1: Running non-correlation providers ===") print(f"\n=== PHASE 1: Running non-correlation providers ===")
while not self._is_stop_requested(): while not self._is_stop_requested():
queue_empty = self.task_queue.empty() queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0 # FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing: if queue_empty and no_active_processing:
consecutive_empty_iterations += 1 consecutive_empty_iterations += 1
@ -536,10 +634,23 @@ class Scanner:
continue continue
# Thread-safe processing state management # Thread-safe processing state management
with self.processing_lock: processing_key = (provider_name, target_item)
# FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested(): if self._is_stop_requested():
break break
processing_key = (provider_name, target_item) if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing: if processing_key in self.currently_processing:
self.tasks_skipped += 1 self.tasks_skipped += 1
self.indicators_completed += 1 self.indicators_completed += 1
@ -558,7 +669,12 @@ class Scanner:
if provider and not isinstance(provider, CorrelationProvider): if provider and not isinstance(provider, CorrelationProvider):
new_targets, _, success = self._process_provider_task(provider, target_item, depth) new_targets, _, success = self._process_provider_task(provider, target_item, depth)
update_counter += 1
current_time = time.time()
if (update_counter % 5 == 0) or (current_time - last_update_time > 3.0):
self._update_session_state()
last_update_time = current_time
update_counter = 0
if self._is_stop_requested(): if self._is_stop_requested():
break break
@ -603,9 +719,13 @@ class Scanner:
self.indicators_completed += 1 self.indicators_completed += 1
finally: finally:
with self.processing_lock: # FIXED: Safe processing lock usage for cleanup
processing_key = (provider_name, target_item) if hasattr(self, 'processing_lock') and self.processing_lock:
self.currently_processing.discard(processing_key) with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
# PHASE 2: Run correlations on all discovered nodes # PHASE 2: Run correlations on all discovered nodes
if not self._is_stop_requested(): if not self._is_stop_requested():
@ -618,8 +738,9 @@ class Scanner:
self.logger.logger.error(f"Scan failed: {e}") self.logger.logger.error(f"Scan failed: {e}")
finally: finally:
# Comprehensive cleanup (same as before) # Comprehensive cleanup (same as before)
with self.processing_lock: if hasattr(self, 'processing_lock') and self.processing_lock:
self.currently_processing.clear() with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = [] self.currently_processing_display = []
while not self.task_queue.empty(): while not self.task_queue.empty():
@ -635,7 +756,9 @@ class Scanner:
else: else:
self.status = ScanStatus.COMPLETED self.status = ScanStatus.COMPLETED
self.status_logger_stop_event.set() # FIXED: Safe stop event handling
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive(): if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0) self.status_logger_thread.join(timeout=2.0)
@ -689,8 +812,13 @@ class Scanner:
while not self._is_stop_requested() and correlation_tasks: while not self._is_stop_requested() and correlation_tasks:
queue_empty = self.task_queue.empty() queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0 # FIXED: Safe processing check
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing: if queue_empty and no_active_processing:
consecutive_empty_iterations += 1 consecutive_empty_iterations += 1
@ -722,10 +850,23 @@ class Scanner:
correlation_tasks.remove(task_tuple) correlation_tasks.remove(task_tuple)
continue continue
with self.processing_lock: processing_key = (provider_name, target_item)
# FIXED: Safe processing management
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested(): if self._is_stop_requested():
break break
processing_key = (provider_name, target_item) if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing: if processing_key in self.currently_processing:
self.tasks_skipped += 1 self.tasks_skipped += 1
self.indicators_completed += 1 self.indicators_completed += 1
@ -754,16 +895,165 @@ class Scanner:
correlation_tasks.remove(task_tuple) correlation_tasks.remove(task_tuple)
finally: finally:
with self.processing_lock: # FIXED: Safe cleanup
processing_key = (provider_name, target_item) if hasattr(self, 'processing_lock') and self.processing_lock:
self.currently_processing.discard(processing_key) with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}") print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
# Rest of the methods remain the same but with similar threading object safety checks...
# I'll include the key methods that need fixes:
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# FIXED: Safe cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
self._update_session_state()
return True
except Exception as e:
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive graph data for real-time updates."""
try:
# FIXED: Safe processing state access
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
else:
currently_processing_count = len(getattr(self, 'currently_processing', []))
currently_processing_list = list(getattr(self, 'currently_processing', []))
# FIXED: Always include complete graph data for real-time updates
graph_data = self.get_graph_data()
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph': graph_data, # FIXED: Always include complete graph data
'task_queue_size': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception as e:
traceback.print_exc()
return {
'status': 'error',
'message': 'Failed to get status',
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
}
def _update_session_state(self) -> None:
"""
FIXED: Update the scanner state in Redis and emit real-time WebSocket updates.
Enhanced with better error handling and socketio management.
"""
if self.session_id:
try:
# Get current status for WebSocket emission
current_status = self.get_scan_status()
# FIXED: Emit real-time update via WebSocket with better error handling
socketio_available = False
if hasattr(self, 'socketio') and self.socketio:
try:
print(f"📡 EMITTING WebSocket update: {current_status.get('status', 'unknown')} - "
f"Nodes: {len(current_status.get('graph', {}).get('nodes', []))}, "
f"Edges: {len(current_status.get('graph', {}).get('edges', []))}")
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted successfully")
socketio_available = True
except Exception as ws_error:
print(f"⚠️ WebSocket emission failed: {ws_error}")
# Try to get socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print("🔄 Attempting to use registered socketio connection...")
registered_socketio.emit('scan_update', current_status)
self.socketio = registered_socketio # Update our reference
print("✅ WebSocket update emitted via registered connection")
socketio_available = True
else:
print("⚠️ No registered socketio connection found")
except Exception as fallback_error:
print(f"⚠️ Fallback socketio emission also failed: {fallback_error}")
else:
# Try to restore socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print(f"🔄 Restoring socketio connection for session {self.session_id}")
self.socketio = registered_socketio
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted via restored connection")
socketio_available = True
else:
print(f"⚠️ No socketio connection available for session {self.session_id}")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio connection: {restore_error}")
if not socketio_available:
print(f"⚠️ Real-time updates unavailable for session {self.session_id}")
# Update session in Redis for persistence (always do this)
try:
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception as redis_error:
print(f"⚠️ Failed to update session in Redis: {redis_error}")
except Exception as e:
print(f"⚠️ Error updating session state: {e}")
import traceback
traceback.print_exc()
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
""" """
Manages the entire process for a given target and provider. FIXED: Manages the entire process for a given target and provider with enhanced real-time updates.
This version is generalized to handle all relationships dynamically.
""" """
if self._is_stop_requested(): if self._is_stop_requested():
return set(), set(), False return set(), set(), False
@ -783,22 +1073,36 @@ class Scanner:
if provider_result is None: if provider_result is None:
provider_successful = False provider_successful = False
elif not self._is_stop_requested(): elif not self._is_stop_requested():
# Pass all relationships to be processed
discovered, is_large_entity = self._process_provider_result_unified( discovered, is_large_entity = self._process_provider_result_unified(
target, provider, provider_result, depth target, provider, provider_result, depth
) )
new_targets.update(discovered) new_targets.update(discovered)
# FIXED: Emit real-time update after processing provider result
if discovered or provider_result.get_relationship_count() > 0:
# Ensure we have socketio connection for real-time updates
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio connection during provider processing")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio during provider processing: {restore_error}")
self._update_session_state()
except Exception as e: except Exception as e:
provider_successful = False provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e)) self._log_provider_error(target, provider.get_name(), str(e))
return new_targets, set(), provider_successful return new_targets, set(), provider_successful
def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]: def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
""" """The "worker" function that directly communicates with the provider to fetch data."""
The "worker" function that directly communicates with the provider to fetch data.
"""
provider_name = provider.get_name() provider_name = provider.get_name()
start_time = datetime.now(timezone.utc) start_time = datetime.now(timezone.utc)
@ -825,9 +1129,7 @@ class Scanner:
def _create_large_entity_from_result(self, source_node: str, provider_name: str, def _create_large_entity_from_result(self, source_node: str, provider_name: str,
provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]: provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
""" """Creates a large entity node, tags all member nodes, and returns its ID and members."""
Creates a large entity node, tags all member nodes, and returns its ID and members.
"""
members = {rel.target_node for rel in provider_result.relationships members = {rel.target_node for rel in provider_result.relationships
if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)} if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}
@ -860,7 +1162,7 @@ class Scanner:
def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool: def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
""" """
Removes a node from a large entity, allowing it to be processed normally. FIXED: Removes a node from a large entity with immediate real-time update.
""" """
if not self.graph.graph.has_node(node_id): if not self.graph.graph.has_node(node_id):
return False return False
@ -879,7 +1181,6 @@ class Scanner:
for provider in eligible_providers: for provider in eligible_providers:
provider_name = provider.get_name() provider_name = provider.get_name()
priority = self._get_priority(provider_name) priority = self._get_priority(provider_name)
# Use current depth of the large entity if available, else 0
depth = 0 depth = 0
if self.graph.graph.has_node(large_entity_id): if self.graph.graph.has_node(large_entity_id):
le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', []) le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
@ -890,6 +1191,19 @@ class Scanner:
self.task_queue.put((time.time(), priority, (provider_name, node_id, depth))) self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
self.total_tasks_ever_enqueued += 1 self.total_tasks_ever_enqueued += 1
# FIXED: Emit real-time update after extraction with socketio management
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for node extraction update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for extraction: {restore_error}")
self._update_session_state()
return True return True
return False return False
@ -897,8 +1211,7 @@ class Scanner:
def _process_provider_result_unified(self, target: str, provider: BaseProvider, def _process_provider_result_unified(self, target: str, provider: BaseProvider,
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]: provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
""" """
Process a unified ProviderResult object to update the graph. FIXED: Process a unified ProviderResult object to update the graph with enhanced real-time updates.
This version dynamically re-routes edges to a large entity container.
""" """
provider_name = provider.get_name() provider_name = provider.get_name()
discovered_targets = set() discovered_targets = set()
@ -918,6 +1231,10 @@ class Scanner:
target, provider_name, provider_result, current_depth target, provider_name, provider_result, current_depth
) )
# Track if we added anything significant
nodes_added = 0
edges_added = 0
for i, relationship in enumerate(provider_result.relationships): for i, relationship in enumerate(provider_result.relationships):
if i % 5 == 0 and self._is_stop_requested(): if i % 5 == 0 and self._is_stop_requested():
break break
@ -949,17 +1266,20 @@ class Scanner:
max_depth_reached = current_depth >= self.max_depth max_depth_reached = current_depth >= self.max_depth
# Add actual nodes to the graph (they might be hidden by the UI) # Add actual nodes to the graph (they might be hidden by the UI)
self.graph.add_node(source_node_id, source_type) if self.graph.add_node(source_node_id, source_type):
self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}) nodes_added += 1
if self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}):
nodes_added += 1
# Add the visual edge to the graph # Add the visual edge to the graph
self.graph.add_edge( if self.graph.add_edge(
visual_source, visual_target, visual_source, visual_target,
relationship.relationship_type, relationship.relationship_type,
relationship.confidence, relationship.confidence,
provider_name, provider_name,
relationship.raw_data relationship.raw_data
) ):
edges_added += 1
if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached: if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
if target_node_id not in large_entity_members: if target_node_id not in large_entity_members:
@ -987,88 +1307,32 @@ class Scanner:
if not self.graph.graph.has_node(node_id): if not self.graph.graph.has_node(node_id):
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list) self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
nodes_added += 1
else: else:
existing_attrs = self.graph.graph.nodes[node_id].get('attributes', []) existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list
return discovered_targets, is_large_entity # FIXED: Emit real-time update if we added anything significant
if nodes_added > 0 or edges_added > 0:
print(f"🔄 Added {nodes_added} nodes, {edges_added} edges - triggering real-time update")
def stop_scan(self) -> bool: # Ensure we have socketio connection for immediate update
"""Request immediate scan termination with proper cleanup.""" if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
try: try:
self.executor.shutdown(wait=False, cancel_futures=True) from core.session_manager import session_manager
except Exception: registered_socketio = session_manager.get_socketio_connection(self.session_id)
pass if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for immediate update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for immediate update: {restore_error}")
self._update_session_state() self._update_session_state()
return True
except Exception as e: return discovered_targets, is_large_entity
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def _update_session_state(self) -> None:
"""
Update the scanner state in Redis for GUI updates.
"""
if self.session_id:
try:
if self.socketio:
self.socketio.emit('scan_update', self.get_scan_status())
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception:
pass
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive processing information."""
try:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph': self.get_graph_data(),
'task_queue_size': self.task_queue.qsize(),
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize(),
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception:
traceback.print_exc()
return { 'status': 'error', 'message': 'Failed to get status' }
def _initialize_provider_states(self, target: str) -> None: def _initialize_provider_states(self, target: str) -> None:
""" """FIXED: Safer provider state initialization with error handling."""
FIXED: Safer provider state initialization with error handling.
"""
try: try:
if not self.graph.graph.has_node(target): if not self.graph.graph.has_node(target):
return return
@ -1081,11 +1345,8 @@ class Scanner:
except Exception as e: except Exception as e:
self.logger.logger.warning(f"Error initializing provider states for {target}: {e}") self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")
def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List: def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
""" """FIXED: Improved provider eligibility checking with better filtering."""
FIXED: Improved provider eligibility checking with better filtering.
"""
if dns_only: if dns_only:
return [p for p in self.providers if p.get_name() == 'dns'] return [p for p in self.providers if p.get_name() == 'dns']
@ -1124,9 +1385,7 @@ class Scanner:
return eligible return eligible
def _already_queried_provider(self, target: str, provider_name: str) -> bool: def _already_queried_provider(self, target: str, provider_name: str) -> bool:
""" """FIXED: More robust check for already queried providers with proper error handling."""
FIXED: More robust check for already queried providers with proper error handling.
"""
try: try:
if not self.graph.graph.has_node(target): if not self.graph.graph.has_node(target):
return False return False
@ -1145,9 +1404,7 @@ class Scanner:
def _update_provider_state(self, target: str, provider_name: str, status: str, def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: Optional[str], start_time: datetime) -> None: results_count: int, error: Optional[str], start_time: datetime) -> None:
""" """FIXED: More robust provider state updates with validation."""
FIXED: More robust provider state updates with validation.
"""
try: try:
if not self.graph.graph.has_node(target): if not self.graph.graph.has_node(target):
self.logger.logger.warning(f"Cannot update provider state: node {target} not found") self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
@ -1174,7 +1431,8 @@ class Scanner:
} }
# Update last modified time for forensic integrity # Update last modified time for forensic integrity
self.last_modified = datetime.now(timezone.utc).isoformat() if hasattr(self, 'last_modified'):
self.last_modified = datetime.now(timezone.utc).isoformat()
except Exception as e: except Exception as e:
self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}") self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
@ -1191,9 +1449,14 @@ class Scanner:
return 0.0 return 0.0
# Add small buffer for tasks still in queue to avoid showing 100% too early # Add small buffer for tasks still in queue to avoid showing 100% too early
queue_size = max(0, self.task_queue.qsize()) queue_size = max(0, self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0)
with self.processing_lock:
active_tasks = len(self.currently_processing) # FIXED: Safe processing count
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
active_tasks = len(self.currently_processing)
else:
active_tasks = len(getattr(self, 'currently_processing', []))
# Adjust total to account for remaining work # Adjust total to account for remaining work
adjusted_total = max(self.total_tasks_ever_enqueued, adjusted_total = max(self.total_tasks_ever_enqueued,
@ -1210,12 +1473,13 @@ class Scanner:
return 0.0 return 0.0
def get_graph_data(self) -> Dict[str, Any]: def get_graph_data(self) -> Dict[str, Any]:
"""Get current graph data formatted for frontend visualization."""
graph_data = self.graph.get_graph_data() graph_data = self.graph.get_graph_data()
graph_data['initial_targets'] = list(self.initial_targets) graph_data['initial_targets'] = list(self.initial_targets)
return graph_data return graph_data
def get_provider_info(self) -> Dict[str, Dict[str, Any]]: def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
"""Get comprehensive information about all available providers."""
info = {} info = {}
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
for filename in os.listdir(provider_dir): for filename in os.listdir(provider_dir):

View File

@ -6,6 +6,7 @@ import uuid
import redis import redis
import pickle import pickle
from typing import Dict, Optional, Any from typing import Dict, Optional, Any
import copy
from core.scanner import Scanner from core.scanner import Scanner
from config import config from config import config
@ -13,7 +14,7 @@ from config import config
class SessionManager: class SessionManager:
""" """
FIXED: Manages multiple scanner instances for concurrent user sessions using Redis. FIXED: Manages multiple scanner instances for concurrent user sessions using Redis.
Now more conservative about session creation to preserve API keys and configuration. Enhanced to properly maintain WebSocket connections throughout scan lifecycle.
""" """
def __init__(self, session_timeout_minutes: int = 0): def __init__(self, session_timeout_minutes: int = 0):
@ -30,6 +31,9 @@ class SessionManager:
# FIXED: Add a creation lock to prevent race conditions # FIXED: Add a creation lock to prevent race conditions
self.creation_lock = threading.Lock() self.creation_lock = threading.Lock()
# Track active socketio connections per session
self.active_socketio_connections = {}
# Start cleanup thread # Start cleanup thread
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start() self.cleanup_thread.start()
@ -40,7 +44,7 @@ class SessionManager:
"""Prepare SessionManager for pickling.""" """Prepare SessionManager for pickling."""
state = self.__dict__.copy() state = self.__dict__.copy()
# Exclude unpickleable attributes - Redis client and threading objects # Exclude unpickleable attributes - Redis client and threading objects
unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock'] unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections']
for attr in unpicklable_attrs: for attr in unpicklable_attrs:
if attr in state: if attr in state:
del state[attr] del state[attr]
@ -53,6 +57,7 @@ class SessionManager:
self.redis_client = redis.StrictRedis(db=0, decode_responses=False) self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
self.lock = threading.Lock() self.lock = threading.Lock()
self.creation_lock = threading.Lock() self.creation_lock = threading.Lock()
self.active_socketio_connections = {}
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start() self.cleanup_thread.start()
@ -64,22 +69,70 @@ class SessionManager:
"""Generates the Redis key for a session's stop signal.""" """Generates the Redis key for a session's stop signal."""
return f"dnsrecon:stop:{session_id}" return f"dnsrecon:stop:{session_id}"
def register_socketio_connection(self, session_id: str, socketio) -> None:
"""
FIXED: Register a socketio connection for a session.
This ensures the connection is maintained throughout the session lifecycle.
"""
with self.lock:
self.active_socketio_connections[session_id] = socketio
print(f"Registered socketio connection for session {session_id}")
def get_socketio_connection(self, session_id: str):
"""
FIXED: Get the active socketio connection for a session.
"""
with self.lock:
return self.active_socketio_connections.get(session_id)
def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner:
"""
FIXED: Prepare scanner for storage by ensuring proper cleanup of unpicklable objects.
Now preserves socketio connection info for restoration.
"""
# Set the session ID on the scanner for cross-process stop signal management
scanner.session_id = session_id
# FIXED: Don't set socketio to None if we want to preserve real-time updates
# Instead, we'll restore it when loading the scanner
scanner.socketio = None
# Force cleanup of any threading objects that might cause issues
if hasattr(scanner, 'stop_event'):
scanner.stop_event = None
if hasattr(scanner, 'scan_thread'):
scanner.scan_thread = None
if hasattr(scanner, 'executor'):
scanner.executor = None
if hasattr(scanner, 'status_logger_thread'):
scanner.status_logger_thread = None
if hasattr(scanner, 'status_logger_stop_event'):
scanner.status_logger_stop_event = None
return scanner
def create_session(self, socketio=None) -> str: def create_session(self, socketio=None) -> str:
""" """
FIXED: Create a new user session with thread-safe creation to prevent duplicates. FIXED: Create a new user session with enhanced WebSocket management.
""" """
# FIXED: Use creation lock to prevent race conditions # FIXED: Use creation lock to prevent race conditions
with self.creation_lock: with self.creation_lock:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
print(f"=== CREATING SESSION {session_id} IN REDIS ===") print(f"=== CREATING SESSION {session_id} IN REDIS ===")
# FIXED: Register socketio connection first
if socketio:
self.register_socketio_connection(session_id, socketio)
try: try:
from core.session_config import create_session_config from core.session_config import create_session_config
session_config = create_session_config() session_config = create_session_config()
scanner_instance = Scanner(session_config=session_config, socketio=socketio)
# Set the session ID on the scanner for cross-process stop signal management # Create scanner WITHOUT socketio to avoid weakref issues
scanner_instance.session_id = session_id scanner_instance = Scanner(session_config=session_config, socketio=None)
# Prepare scanner for storage (removes problematic objects)
scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id)
session_data = { session_data = {
'scanner': scanner_instance, 'scanner': scanner_instance,
@ -89,12 +142,24 @@ class SessionManager:
'status': 'active' 'status': 'active'
} }
# Serialize the entire session data dictionary using pickle # Test serialization before storing to catch issues early
serialized_data = pickle.dumps(session_data) try:
test_serialization = pickle.dumps(session_data)
print(f"Session serialization test successful ({len(test_serialization)} bytes)")
except Exception as pickle_error:
print(f"PICKLE TEST FAILED: {pickle_error}")
# Try to identify the problematic object
for key, value in session_data.items():
try:
pickle.dumps(value)
print(f" {key}: OK")
except Exception as item_error:
print(f" {key}: FAILED - {item_error}")
raise pickle_error
# Store in Redis # Store in Redis
session_key = self._get_session_key(session_id) session_key = self._get_session_key(session_id)
self.redis_client.setex(session_key, self.session_timeout, serialized_data) self.redis_client.setex(session_key, self.session_timeout, test_serialization)
# Initialize stop signal as False # Initialize stop signal as False
stop_key = self._get_stop_signal_key(session_id) stop_key = self._get_stop_signal_key(session_id)
@ -106,6 +171,8 @@ class SessionManager:
except Exception as e: except Exception as e:
print(f"ERROR: Failed to create session {session_id}: {e}") print(f"ERROR: Failed to create session {session_id}: {e}")
import traceback
traceback.print_exc()
raise raise
def set_stop_signal(self, session_id: str) -> bool: def set_stop_signal(self, session_id: str) -> bool:
@ -175,31 +242,63 @@ class SessionManager:
# Ensure the scanner has the correct session ID for stop signal checking # Ensure the scanner has the correct session ID for stop signal checking
if 'scanner' in session_data and session_data['scanner']: if 'scanner' in session_data and session_data['scanner']:
session_data['scanner'].session_id = session_id session_data['scanner'].session_id = session_id
# FIXED: Restore socketio connection from our registry
socketio_conn = self.get_socketio_connection(session_id)
if socketio_conn:
session_data['scanner'].socketio = socketio_conn
print(f"Restored socketio connection for session {session_id}")
else:
print(f"No socketio connection found for session {session_id}")
session_data['scanner'].socketio = None
return session_data return session_data
return None return None
except Exception as e: except Exception as e:
print(f"ERROR: Failed to get session data for {session_id}: {e}") print(f"ERROR: Failed to get session data for {session_id}: {e}")
import traceback
traceback.print_exc()
return None return None
def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool: def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
""" """
Serializes and saves session data back to Redis with updated TTL. Serializes and saves session data back to Redis with updated TTL.
FIXED: Now preserves socketio connection during storage.
Returns: Returns:
bool: True if save was successful bool: True if save was successful
""" """
try: try:
session_key = self._get_session_key(session_id) session_key = self._get_session_key(session_id)
serialized_data = pickle.dumps(session_data)
# Create a deep copy to avoid modifying the original scanner object
session_data_to_save = copy.deepcopy(session_data)
# Prepare scanner for storage if it exists
if 'scanner' in session_data_to_save and session_data_to_save['scanner']:
# FIXED: Preserve the original socketio connection before preparing for storage
original_socketio = session_data_to_save['scanner'].socketio
session_data_to_save['scanner'] = self._prepare_scanner_for_storage(
session_data_to_save['scanner'],
session_id
)
# FIXED: If we had a socketio connection, make sure it's registered
if original_socketio and session_id not in self.active_socketio_connections:
self.register_socketio_connection(session_id, original_socketio)
serialized_data = pickle.dumps(session_data_to_save)
result = self.redis_client.setex(session_key, self.session_timeout, serialized_data) result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
return result return result
except Exception as e: except Exception as e:
print(f"ERROR: Failed to save session data for {session_id}: {e}") print(f"ERROR: Failed to save session data for {session_id}: {e}")
import traceback
traceback.print_exc()
return False return False
def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool: def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
""" """
Updates just the scanner object in a session with immediate persistence. FIXED: Updates just the scanner object in a session with immediate persistence.
Now maintains socketio connection throughout the update process.
Returns: Returns:
bool: True if update was successful bool: True if update was successful
@ -207,21 +306,27 @@ class SessionManager:
try: try:
session_data = self._get_session_data(session_id) session_data = self._get_session_data(session_id)
if session_data: if session_data:
# Ensure scanner has the session ID # FIXED: Preserve socketio connection before preparing for storage
scanner.session_id = session_id original_socketio = scanner.socketio
# Prepare scanner for storage
scanner = self._prepare_scanner_for_storage(scanner, session_id)
session_data['scanner'] = scanner session_data['scanner'] = scanner
session_data['last_activity'] = time.time() session_data['last_activity'] = time.time()
# FIXED: Restore socketio connection after preparation
if original_socketio:
self.register_socketio_connection(session_id, original_socketio)
session_data['scanner'].socketio = original_socketio
# Immediately save to Redis for GUI updates # Immediately save to Redis for GUI updates
success = self._save_session_data(session_id, session_data) success = self._save_session_data(session_id, session_data)
if success: if success:
# Only log occasionally to reduce noise # Only log occasionally to reduce noise
if hasattr(self, '_last_update_log'): if hasattr(self, '_last_update_log'):
if time.time() - self._last_update_log > 5: # Log every 5 seconds max if time.time() - self._last_update_log > 5: # Log every 5 seconds max
#print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
self._last_update_log = time.time() self._last_update_log = time.time()
else: else:
#print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
self._last_update_log = time.time() self._last_update_log = time.time()
else: else:
print(f"WARNING: Failed to save scanner state for session {session_id}") print(f"WARNING: Failed to save scanner state for session {session_id}")
@ -231,6 +336,8 @@ class SessionManager:
return False return False
except Exception as e: except Exception as e:
print(f"ERROR: Failed to update scanner for session {session_id}: {e}") print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
import traceback
traceback.print_exc()
return False return False
def update_scanner_status(self, session_id: str, status: str) -> bool: def update_scanner_status(self, session_id: str, status: str) -> bool:
@ -263,7 +370,7 @@ class SessionManager:
def get_session(self, session_id: str) -> Optional[Scanner]: def get_session(self, session_id: str) -> Optional[Scanner]:
""" """
Get scanner instance for a session from Redis with session ID management. FIXED: Get scanner instance for a session from Redis with proper socketio restoration.
""" """
if not session_id: if not session_id:
return None return None
@ -282,6 +389,15 @@ class SessionManager:
# Ensure the scanner can check the Redis-based stop signal # Ensure the scanner can check the Redis-based stop signal
scanner.session_id = session_id scanner.session_id = session_id
# FIXED: Restore socketio connection from our registry
socketio_conn = self.get_socketio_connection(session_id)
if socketio_conn:
scanner.socketio = socketio_conn
print(f"✓ Restored socketio connection for session {session_id}")
else:
scanner.socketio = None
print(f"⚠️ No socketio connection found for session {session_id}")
return scanner return scanner
def get_session_status_only(self, session_id: str) -> Optional[str]: def get_session_status_only(self, session_id: str) -> Optional[str]:
@ -333,6 +449,12 @@ class SessionManager:
# Wait a moment for graceful shutdown # Wait a moment for graceful shutdown
time.sleep(0.5) time.sleep(0.5)
# FIXED: Clean up socketio connection
with self.lock:
if session_id in self.active_socketio_connections:
del self.active_socketio_connections[session_id]
print(f"Cleaned up socketio connection for session {session_id}")
# Delete session data and stop signal from Redis # Delete session data and stop signal from Redis
session_key = self._get_session_key(session_id) session_key = self._get_session_key(session_id)
stop_key = self._get_stop_signal_key(session_id) stop_key = self._get_stop_signal_key(session_id)
@ -344,6 +466,8 @@ class SessionManager:
except Exception as e: except Exception as e:
print(f"ERROR: Failed to terminate session {session_id}: {e}") print(f"ERROR: Failed to terminate session {session_id}: {e}")
import traceback
traceback.print_exc()
return False return False
def _cleanup_loop(self) -> None: def _cleanup_loop(self) -> None:
@ -364,6 +488,12 @@ class SessionManager:
self.redis_client.delete(stop_key) self.redis_client.delete(stop_key)
print(f"Cleaned up orphaned stop signal for session {session_id}") print(f"Cleaned up orphaned stop signal for session {session_id}")
# Also clean up socketio connection
with self.lock:
if session_id in self.active_socketio_connections:
del self.active_socketio_connections[session_id]
print(f"Cleaned up orphaned socketio for session {session_id}")
except Exception as e: except Exception as e:
print(f"Error in cleanup loop: {e}") print(f"Error in cleanup loop: {e}")
@ -387,14 +517,16 @@ class SessionManager:
return { return {
'total_active_sessions': active_sessions, 'total_active_sessions': active_sessions,
'running_scans': running_scans, 'running_scans': running_scans,
'total_stop_signals': len(stop_keys) 'total_stop_signals': len(stop_keys),
'active_socketio_connections': len(self.active_socketio_connections)
} }
except Exception as e: except Exception as e:
print(f"ERROR: Failed to get statistics: {e}") print(f"ERROR: Failed to get statistics: {e}")
return { return {
'total_active_sessions': 0, 'total_active_sessions': 0,
'running_scans': 0, 'running_scans': 0,
'total_stop_signals': 0 'total_stop_signals': 0,
'active_socketio_connections': 0
} }
# Global session manager instance # Global session manager instance

View File

@ -15,6 +15,7 @@ class BaseProvider(ABC):
""" """
Abstract base class for all DNSRecon data providers. Abstract base class for all DNSRecon data providers.
Now supports session-specific configuration and returns standardized ProviderResult objects. Now supports session-specific configuration and returns standardized ProviderResult objects.
FIXED: Enhanced pickle support to prevent weakref serialization errors.
""" """
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None): def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
@ -53,22 +54,57 @@ class BaseProvider(ABC):
def __getstate__(self): def __getstate__(self):
"""Prepare BaseProvider for pickling by excluding unpicklable objects.""" """Prepare BaseProvider for pickling by excluding unpicklable objects."""
state = self.__dict__.copy() state = self.__dict__.copy()
# Exclude the unpickleable '_local' attribute (which holds the session) and stop event
unpicklable_attrs = ['_local', '_stop_event'] # Exclude unpickleable attributes that may contain weakrefs
unpicklable_attrs = [
'_local', # Thread-local storage (contains requests.Session)
'_stop_event', # Threading event
'logger', # Logger may contain weakrefs in handlers
]
for attr in unpicklable_attrs: for attr in unpicklable_attrs:
if attr in state: if attr in state:
del state[attr] del state[attr]
# Also handle any potential weakrefs in the config object
if 'config' in state and hasattr(state['config'], '__getstate__'):
# If config has its own pickle support, let it handle itself
pass
elif 'config' in state:
# Otherwise, ensure config doesn't contain unpicklable objects
try:
# Test if config can be pickled
import pickle
pickle.dumps(state['config'])
except (TypeError, AttributeError):
# If config can't be pickled, we'll recreate it during unpickling
state['_config_class'] = type(state['config']).__name__
del state['config']
return state return state
def __setstate__(self, state): def __setstate__(self, state):
"""Restore BaseProvider after unpickling by reconstructing threading objects.""" """Restore BaseProvider after unpickling by reconstructing threading objects."""
self.__dict__.update(state) self.__dict__.update(state)
# Re-initialize the '_local' attribute and stop event
# Re-initialize unpickleable attributes
self._local = threading.local() self._local = threading.local()
self._stop_event = None self._stop_event = None
self.logger = get_forensic_logger()
# Recreate config if it was removed during pickling
if not hasattr(self, 'config') and hasattr(self, '_config_class'):
if self._config_class == 'Config':
from config import config as global_config
self.config = global_config
elif self._config_class == 'SessionConfig':
from core.session_config import create_session_config
self.config = create_session_config()
del self._config_class
@property @property
def session(self): def session(self):
"""Get or create thread-local requests session."""
if not hasattr(self._local, 'session'): if not hasattr(self._local, 'session'):
self._local.session = requests.Session() self._local.session = requests.Session()
self._local.session.headers.update({ self._local.session.headers.update({

View File

@ -10,6 +10,7 @@ from core.graph_manager import NodeType, GraphManager
class CorrelationProvider(BaseProvider): class CorrelationProvider(BaseProvider):
""" """
A provider that finds correlations between nodes in the graph. A provider that finds correlations between nodes in the graph.
FIXED: Enhanced pickle support to prevent weakref issues with graph references.
""" """
def __init__(self, name: str = "correlation", session_config=None): def __init__(self, name: str = "correlation", session_config=None):
@ -38,6 +39,38 @@ class CorrelationProvider(BaseProvider):
'query_timestamp', 'query_timestamp',
] ]
def __getstate__(self):
"""
FIXED: Prepare CorrelationProvider for pickling by excluding graph reference.
"""
state = super().__getstate__()
# Remove graph reference to prevent circular dependencies and weakrefs
if 'graph' in state:
del state['graph']
# Also handle correlation_index which might contain complex objects
if 'correlation_index' in state:
# Clear correlation index as it will be rebuilt when needed
state['correlation_index'] = {}
return state
def __setstate__(self, state):
"""
FIXED: Restore CorrelationProvider after unpickling.
"""
super().__setstate__(state)
# Re-initialize graph reference (will be set by scanner)
self.graph = None
# Re-initialize correlation index
self.correlation_index = {}
# Re-compile regex pattern
self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
def get_name(self) -> str: def get_name(self) -> str:
"""Return the provider name.""" """Return the provider name."""
return "correlation" return "correlation"
@ -79,13 +112,20 @@ class CorrelationProvider(BaseProvider):
def _find_correlations(self, node_id: str) -> ProviderResult: def _find_correlations(self, node_id: str) -> ProviderResult:
""" """
Find correlations for a given node. Find correlations for a given node.
FIXED: Added safety checks to prevent issues when graph is None.
""" """
result = ProviderResult() result = ProviderResult()
# FIXED: Ensure self.graph is not None before proceeding.
# FIXED: Ensure self.graph is not None before proceeding
if not self.graph or not self.graph.graph.has_node(node_id): if not self.graph or not self.graph.graph.has_node(node_id):
return result return result
node_attributes = self.graph.graph.nodes[node_id].get('attributes', []) try:
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
except Exception as e:
# If there's any issue accessing the graph, return empty result
print(f"Warning: Could not access graph for correlation analysis: {e}")
return result
for attr in node_attributes: for attr in node_attributes:
attr_name = attr.get('name') attr_name = attr.get('name')
@ -134,6 +174,7 @@ class CorrelationProvider(BaseProvider):
if len(self.correlation_index[attr_value]['nodes']) > 1: if len(self.correlation_index[attr_value]['nodes']) > 1:
self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result) self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
return result return result
def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult): def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):

View File

@ -11,6 +11,7 @@ class DNSProvider(BaseProvider):
""" """
Provider for standard DNS resolution and reverse DNS lookups. Provider for standard DNS resolution and reverse DNS lookups.
Now returns standardized ProviderResult objects with IPv4 and IPv6 support. Now returns standardized ProviderResult objects with IPv4 and IPv6 support.
FIXED: Enhanced pickle support to prevent resolver serialization issues.
""" """
def __init__(self, name=None, session_config=None): def __init__(self, name=None, session_config=None):
@ -28,19 +29,20 @@ class DNSProvider(BaseProvider):
self.resolver.lifetime = 10 self.resolver.lifetime = 10
def __getstate__(self): def __getstate__(self):
"""Prepare the object for pickling.""" """Prepare the object for pickling by excluding resolver."""
state = self.__dict__.copy() state = super().__getstate__()
# Remove the unpickleable 'resolver' attribute # Remove the unpickleable 'resolver' attribute
if 'resolver' in state: if 'resolver' in state:
del state['resolver'] del state['resolver']
return state return state
def __setstate__(self, state): def __setstate__(self, state):
"""Restore the object after unpickling.""" """Restore the object after unpickling by reconstructing resolver."""
self.__dict__.update(state) super().__setstate__(state)
# Re-initialize the 'resolver' attribute # Re-initialize the 'resolver' attribute
self.resolver = resolver.Resolver() self.resolver = resolver.Resolver()
self.resolver.timeout = 5 self.resolver.timeout = 5
self.resolver.lifetime = 10
def get_name(self) -> str: def get_name(self) -> str:
"""Return the provider name.""" """Return the provider name."""
@ -121,10 +123,10 @@ class DNSProvider(BaseProvider):
if _is_valid_domain(hostname): if _is_valid_domain(hostname):
# Determine appropriate forward relationship type based on IP version # Determine appropriate forward relationship type based on IP version
if ip_version == 6: if ip_version == 6:
relationship_type = 'dns_aaaa_record' relationship_type = 'shodan_aaaa_record'
record_prefix = 'AAAA' record_prefix = 'AAAA'
else: else:
relationship_type = 'dns_a_record' relationship_type = 'shodan_a_record'
record_prefix = 'A' record_prefix = 'A'
# Add the relationship # Add the relationship

View File

@ -1,7 +1,7 @@
/** /**
* Main application logic for DNSRecon web interface * Main application logic for DNSRecon web interface
* Handles UI interactions, API communication, and data flow * Handles UI interactions, API communication, and data flow
* UPDATED: Now compatible with a strictly flat, unified data model for attributes. * FIXED: Enhanced real-time WebSocket graph updates
*/ */
class DNSReconApp { class DNSReconApp {
@ -17,6 +17,14 @@ class DNSReconApp {
this.isScanning = false; this.isScanning = false;
this.lastGraphUpdate = null; this.lastGraphUpdate = null;
// FIXED: Add connection state tracking
this.isConnected = false;
this.reconnectAttempts = 0;
this.maxReconnectAttempts = 5;
// FIXED: Track last graph data for debugging
this.lastGraphData = null;
this.init(); this.init();
} }
@ -45,22 +53,159 @@ class DNSReconApp {
} }
initializeSocket() { initializeSocket() {
this.socket = io(); console.log('🔌 Initializing WebSocket connection...');
this.socket.on('connect', () => { try {
console.log('Connected to WebSocket server'); this.socket = io({
this.updateConnectionStatus('idle'); transports: ['websocket', 'polling'],
this.socket.emit('get_status'); timeout: 10000,
}); reconnection: true,
reconnectionAttempts: 5,
reconnectionDelay: 2000
});
this.socket.on('scan_update', (data) => { this.socket.on('connect', () => {
if (data.status !== this.scanStatus) { console.log('✅ WebSocket connected successfully');
this.handleStatusChange(data.status, data.task_queue_size); this.isConnected = true;
} this.reconnectAttempts = 0;
this.scanStatus = data.status; this.updateConnectionStatus('idle');
this.updateStatusDisplay(data);
this.graphManager.updateGraph(data.graph); console.log('📡 Requesting initial status...');
}); this.socket.emit('get_status');
});
this.socket.on('disconnect', (reason) => {
console.log('❌ WebSocket disconnected:', reason);
this.isConnected = false;
this.updateConnectionStatus('error');
});
this.socket.on('connect_error', (error) => {
console.error('❌ WebSocket connection error:', error);
this.reconnectAttempts++;
this.updateConnectionStatus('error');
if (this.reconnectAttempts >= 5) {
this.showError('WebSocket connection failed. Please refresh the page.');
}
});
this.socket.on('reconnect', (attemptNumber) => {
console.log('✅ WebSocket reconnected after', attemptNumber, 'attempts');
this.isConnected = true;
this.reconnectAttempts = 0;
this.updateConnectionStatus('idle');
this.socket.emit('get_status');
});
// FIXED: Enhanced scan_update handler with detailed graph processing and debugging
this.socket.on('scan_update', (data) => {
console.log('📨 WebSocket update received:', {
status: data.status,
target: data.target_domain,
progress: data.progress_percentage,
graphNodes: data.graph?.nodes?.length || 0,
graphEdges: data.graph?.edges?.length || 0,
timestamp: new Date().toISOString()
});
try {
// Handle status change
if (data.status !== this.scanStatus) {
console.log(`📄 Status change: ${this.scanStatus}${data.status}`);
this.handleStatusChange(data.status, data.task_queue_size);
}
this.scanStatus = data.status;
// Update status display
this.updateStatusDisplay(data);
// FIXED: Always update graph if data is present and graph manager exists
if (data.graph && this.graphManager) {
console.log('📊 Processing graph update:', {
nodes: data.graph.nodes?.length || 0,
edges: data.graph.edges?.length || 0,
hasNodes: Array.isArray(data.graph.nodes),
hasEdges: Array.isArray(data.graph.edges),
isInitialized: this.graphManager.isInitialized
});
// FIXED: Initialize graph manager if not already done
if (!this.graphManager.isInitialized) {
console.log('🎯 Initializing graph manager...');
this.graphManager.initialize();
}
// FIXED: Force graph update and verify it worked
const previousNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0;
const previousEdgeCount = this.graphManager.edges ? this.graphManager.edges.length : 0;
console.log('🔄 Before update - Nodes:', previousNodeCount, 'Edges:', previousEdgeCount);
// Store the data for debugging
this.lastGraphData = data.graph;
// Update the graph
this.graphManager.updateGraph(data.graph);
this.lastGraphUpdate = Date.now();
// Verify the update worked
const newNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0;
const newEdgeCount = this.graphManager.edges ? this.graphManager.edges.length : 0;
console.log('🔄 After update - Nodes:', newNodeCount, 'Edges:', newEdgeCount);
if (newNodeCount !== data.graph.nodes.length || newEdgeCount !== data.graph.edges.length) {
console.warn('⚠️ Graph update mismatch!', {
expectedNodes: data.graph.nodes.length,
actualNodes: newNodeCount,
expectedEdges: data.graph.edges.length,
actualEdges: newEdgeCount
});
// Force a complete rebuild if there's a mismatch
console.log('🔧 Force rebuilding graph...');
this.graphManager.clear();
this.graphManager.updateGraph(data.graph);
}
console.log('✅ Graph updated successfully');
// FIXED: Force network redraw if we're using vis.js
if (this.graphManager.network) {
try {
this.graphManager.network.redraw();
console.log('🎨 Network redrawn');
} catch (redrawError) {
console.warn('⚠️ Network redraw failed:', redrawError);
}
}
} else {
if (!data.graph) {
console.log('⚠️ No graph data in WebSocket update');
}
if (!this.graphManager) {
console.log('⚠️ Graph manager not available');
}
}
} catch (error) {
console.error('❌ Error processing WebSocket update:', error);
console.error('Update data:', data);
console.error('Stack trace:', error.stack);
}
});
this.socket.on('error', (error) => {
console.error('❌ WebSocket error:', error);
this.showError('WebSocket communication error');
});
} catch (error) {
console.error('❌ Failed to initialize WebSocket:', error);
this.showError('Failed to establish real-time connection');
}
} }
/** /**
@ -280,12 +425,36 @@ class DNSReconApp {
} }
/** /**
* Initialize graph visualization * FIXED: Initialize graph visualization with enhanced debugging
*/ */
initializeGraph() { initializeGraph() {
try { try {
console.log('Initializing graph manager...'); console.log('Initializing graph manager...');
this.graphManager = new GraphManager('network-graph'); this.graphManager = new GraphManager('network-graph');
// FIXED: Add debugging hooks to graph manager
if (this.graphManager) {
// Override updateGraph to add debugging
const originalUpdateGraph = this.graphManager.updateGraph.bind(this.graphManager);
this.graphManager.updateGraph = (graphData) => {
console.log('🔧 GraphManager.updateGraph called with:', {
nodes: graphData?.nodes?.length || 0,
edges: graphData?.edges?.length || 0,
timestamp: new Date().toISOString()
});
const result = originalUpdateGraph(graphData);
console.log('🔧 GraphManager.updateGraph completed, network state:', {
networkExists: !!this.graphManager.network,
nodeDataSetLength: this.graphManager.nodes?.length || 0,
edgeDataSetLength: this.graphManager.edges?.length || 0
});
return result;
};
}
console.log('Graph manager initialized successfully'); console.log('Graph manager initialized successfully');
} catch (error) { } catch (error) {
console.error('Failed to initialize graph manager:', error); console.error('Failed to initialize graph manager:', error);
@ -305,7 +474,6 @@ class DNSReconApp {
console.log(`Target: "${target}", Max depth: ${maxDepth}`); console.log(`Target: "${target}", Max depth: ${maxDepth}`);
// Validation
if (!target) { if (!target) {
console.log('Validation failed: empty target'); console.log('Validation failed: empty target');
this.showError('Please enter a target domain or IP'); this.showError('Please enter a target domain or IP');
@ -320,6 +488,19 @@ class DNSReconApp {
return; return;
} }
// FIXED: Ensure WebSocket connection before starting scan
if (!this.isConnected) {
console.log('WebSocket not connected, attempting to connect...');
this.socket.connect();
// Wait a moment for connection
await new Promise(resolve => setTimeout(resolve, 1000));
if (!this.isConnected) {
this.showWarning('WebSocket connection not established. Updates may be delayed.');
}
}
console.log('Validation passed, setting UI state to scanning...'); console.log('Validation passed, setting UI state to scanning...');
this.setUIState('scanning'); this.setUIState('scanning');
this.showInfo('Starting reconnaissance scan...'); this.showInfo('Starting reconnaissance scan...');
@ -337,16 +518,28 @@ class DNSReconApp {
if (response.success) { if (response.success) {
this.currentSessionId = response.scan_id; this.currentSessionId = response.scan_id;
this.showSuccess('Reconnaissance scan started successfully'); this.showSuccess('Reconnaissance scan started - watching for real-time updates');
if (clearGraph) { if (clearGraph && this.graphManager) {
console.log('🧹 Clearing graph for new scan');
this.graphManager.clear(); this.graphManager.clear();
} }
console.log(`Scan started for ${target} with depth ${maxDepth}`); console.log(`Scan started for ${target} with depth ${maxDepth}`);
// Request initial status update via WebSocket // FIXED: Immediately start listening for updates
this.socket.emit('get_status'); if (this.socket && this.isConnected) {
console.log('📡 Requesting initial status update...');
this.socket.emit('get_status');
// Set up periodic status requests as backup (every 5 seconds during scan)
/*this.statusRequestInterval = setInterval(() => {
if (this.isScanning && this.socket && this.isConnected) {
console.log('📡 Periodic status request...');
this.socket.emit('get_status');
}
}, 5000);*/
}
} else { } else {
throw new Error(response.error || 'Failed to start scan'); throw new Error(response.error || 'Failed to start scan');
@ -358,26 +551,34 @@ class DNSReconApp {
this.setUIState('idle'); this.setUIState('idle');
} }
} }
/**
* Scan stop with immediate UI feedback // FIXED: Enhanced stop scan with interval cleanup
*/
async stopScan() { async stopScan() {
try { try {
console.log('Stopping scan...'); console.log('Stopping scan...');
// Immediately disable stop button and show stopping state // Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
if (this.elements.stopScan) { if (this.elements.stopScan) {
this.elements.stopScan.disabled = true; this.elements.stopScan.disabled = true;
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOPPING]</span><span>Stopping...</span>'; this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOPPING]</span><span>Stopping...</span>';
} }
// Show immediate feedback
this.showInfo('Stopping scan...'); this.showInfo('Stopping scan...');
const response = await this.apiCall('/api/scan/stop', 'POST'); const response = await this.apiCall('/api/scan/stop', 'POST');
if (response.success) { if (response.success) {
this.showSuccess('Scan stop requested'); this.showSuccess('Scan stop requested');
// Request final status update
if (this.socket && this.isConnected) {
setTimeout(() => this.socket.emit('get_status'), 500);
}
} else { } else {
throw new Error(response.error || 'Failed to stop scan'); throw new Error(response.error || 'Failed to stop scan');
} }
@ -386,7 +587,6 @@ class DNSReconApp {
console.error('Failed to stop scan:', error); console.error('Failed to stop scan:', error);
this.showError(`Failed to stop scan: ${error.message}`); this.showError(`Failed to stop scan: ${error.message}`);
// Re-enable stop button on error
if (this.elements.stopScan) { if (this.elements.stopScan) {
this.elements.stopScan.disabled = false; this.elements.stopScan.disabled = false;
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOP]</span><span>Terminate Scan</span>'; this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOP]</span><span>Terminate Scan</span>';
@ -543,23 +743,24 @@ class DNSReconApp {
} }
/** /**
* Update graph from server * FIXED: Update graph from server with enhanced debugging
*/ */
async updateGraph() { async updateGraph() {
try { try {
console.log('Updating graph...'); console.log('Updating graph via API call...');
const response = await this.apiCall('/api/graph'); const response = await this.apiCall('/api/graph');
if (response.success) { if (response.success) {
const graphData = response.graph; const graphData = response.graph;
console.log('Graph data received:'); console.log('Graph data received from API:');
console.log('- Nodes:', graphData.nodes ? graphData.nodes.length : 0); console.log('- Nodes:', graphData.nodes ? graphData.nodes.length : 0);
console.log('- Edges:', graphData.edges ? graphData.edges.length : 0); console.log('- Edges:', graphData.edges ? graphData.edges.length : 0);
// FIXED: Always update graph, even if empty - let GraphManager handle placeholder // FIXED: Always update graph, even if empty - let GraphManager handle placeholder
if (this.graphManager) { if (this.graphManager) {
console.log('🔧 Calling GraphManager.updateGraph from API response...');
this.graphManager.updateGraph(graphData); this.graphManager.updateGraph(graphData);
this.lastGraphUpdate = Date.now(); this.lastGraphUpdate = Date.now();
@ -568,6 +769,8 @@ class DNSReconApp {
if (this.elements.relationshipsDisplay) { if (this.elements.relationshipsDisplay) {
this.elements.relationshipsDisplay.textContent = edgeCount; this.elements.relationshipsDisplay.textContent = edgeCount;
} }
console.log('✅ Manual graph update completed');
} }
} else { } else {
console.error('Graph update failed:', response); console.error('Graph update failed:', response);
@ -663,12 +866,12 @@ class DNSReconApp {
* @param {string} newStatus - New scan status * @param {string} newStatus - New scan status
*/ */
handleStatusChange(newStatus, task_queue_size) { handleStatusChange(newStatus, task_queue_size) {
console.log(`=== STATUS CHANGE: ${this.scanStatus} -> ${newStatus} ===`); console.log(`📄 Status change handler: ${this.scanStatus}${newStatus}`);
switch (newStatus) { switch (newStatus) {
case 'running': case 'running':
this.setUIState('scanning', task_queue_size); this.setUIState('scanning', task_queue_size);
this.showSuccess('Scan is running'); this.showSuccess('Scan is running - updates in real-time');
this.updateConnectionStatus('active'); this.updateConnectionStatus('active');
break; break;
@ -677,8 +880,19 @@ class DNSReconApp {
this.showSuccess('Scan completed successfully'); this.showSuccess('Scan completed successfully');
this.updateConnectionStatus('completed'); this.updateConnectionStatus('completed');
this.loadProviders(); this.loadProviders();
// Force a final graph update console.log('✅ Scan completed - requesting final graph update');
console.log('Scan completed - forcing final graph update'); // Request final status to ensure we have the complete graph
setTimeout(() => {
if (this.socket && this.isConnected) {
this.socket.emit('get_status');
}
}, 1000);
// Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
break; break;
case 'failed': case 'failed':
@ -686,6 +900,12 @@ class DNSReconApp {
this.showError('Scan failed'); this.showError('Scan failed');
this.updateConnectionStatus('error'); this.updateConnectionStatus('error');
this.loadProviders(); this.loadProviders();
// Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
break; break;
case 'stopped': case 'stopped':
@ -693,11 +913,23 @@ class DNSReconApp {
this.showSuccess('Scan stopped'); this.showSuccess('Scan stopped');
this.updateConnectionStatus('stopped'); this.updateConnectionStatus('stopped');
this.loadProviders(); this.loadProviders();
// Clear status request interval
if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}
break; break;
case 'idle': case 'idle':
this.setUIState('idle', task_queue_size); this.setUIState('idle', task_queue_size);
this.updateConnectionStatus('idle'); this.updateConnectionStatus('idle');
// Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
break; break;
default: default:
@ -749,6 +981,7 @@ class DNSReconApp {
if (this.graphManager) { if (this.graphManager) {
this.graphManager.isScanning = true; this.graphManager.isScanning = true;
} }
if (this.elements.startScan) { if (this.elements.startScan) {
this.elements.startScan.disabled = true; this.elements.startScan.disabled = true;
this.elements.startScan.classList.add('loading'); this.elements.startScan.classList.add('loading');
@ -776,6 +1009,7 @@ class DNSReconApp {
if (this.graphManager) { if (this.graphManager) {
this.graphManager.isScanning = false; this.graphManager.isScanning = false;
} }
if (this.elements.startScan) { if (this.elements.startScan) {
this.elements.startScan.disabled = !isQueueEmpty; this.elements.startScan.disabled = !isQueueEmpty;
this.elements.startScan.classList.remove('loading'); this.elements.startScan.classList.remove('loading');
@ -1018,7 +1252,7 @@ class DNSReconApp {
} else { } else {
// API key not configured - ALWAYS show input field // API key not configured - ALWAYS show input field
const statusClass = info.enabled ? 'enabled' : 'api-key-required'; const statusClass = info.enabled ? 'enabled' : 'api-key-required';
const statusText = info.enabled ? ' Ready for API Key' : '⚠️ API Key Required'; const statusText = info.enabled ? ' Ready for API Key' : '⚠️ API Key Required';
inputGroup.innerHTML = ` inputGroup.innerHTML = `
<div class="provider-header"> <div class="provider-header">
@ -2000,8 +2234,8 @@ class DNSReconApp {
*/ */
getNodeTypeIcon(nodeType) { getNodeTypeIcon(nodeType) {
const icons = { const icons = {
'domain': '🌍', 'domain': '🌐',
'ip': '📍', 'ip': '🔢',
'asn': '🏢', 'asn': '🏢',
'large_entity': '📦', 'large_entity': '📦',
'correlation_object': '🔗' 'correlation_object': '🔗'