Compare commits: main ... websockets
2 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | c4e6a8998a |  |
|  | 75a595c9cb |  |

app.py (173 changes)
@@ -3,9 +3,9 @@
"""
Flask application entry point for DNSRecon web interface.
Provides REST API endpoints and serves the web interface with user session support.
FIXED: Enhanced WebSocket integration with proper connection management.
"""

import json
import traceback
from flask import Flask, render_template, request, jsonify, send_file, session
from datetime import datetime, timezone, timedelta

@@ -13,6 +13,7 @@ import io
import os

from core.session_manager import session_manager
from flask_socketio import SocketIO
from config import config
from core.graph_manager import NodeType
from utils.helpers import is_valid_target

@@ -21,29 +22,38 @@ from decimal import Decimal


app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*")
app.config['SECRET_KEY'] = config.flask_secret_key
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours)


def get_user_scanner():
    """
    Retrieves the scanner for the current session, or creates a new one if none exists.
    FIXED: Retrieves the scanner for the current session with proper socketio management.
    """
    current_flask_session_id = session.get('dnsrecon_session_id')

    if current_flask_session_id:
        existing_scanner = session_manager.get_session(current_flask_session_id)
        if existing_scanner:
            # FIXED: Ensure socketio is properly maintained
            existing_scanner.socketio = socketio
            print(f"✓ Retrieved existing scanner for session {current_flask_session_id[:8]}... with socketio restored")
            return current_flask_session_id, existing_scanner

    new_session_id = session_manager.create_session()
    # FIXED: Register socketio connection when creating new session
    new_session_id = session_manager.create_session(socketio)
    new_scanner = session_manager.get_session(new_session_id)

    if not new_scanner:
        raise Exception("Failed to create new scanner session")

    # FIXED: Ensure new scanner has socketio reference and register the connection
    new_scanner.socketio = socketio
    session_manager.register_socketio_connection(new_session_id, socketio)
    session['dnsrecon_session_id'] = new_session_id
    session.permanent = True

    print(f"✓ Created new scanner for session {new_session_id[:8]}... with socketio registered")
    return new_session_id, new_scanner
@@ -56,7 +66,7 @@ def index():

@app.route('/api/scan/start', methods=['POST'])
def start_scan():
    """
    Starts a new reconnaissance scan.
    FIXED: Starts a new reconnaissance scan with proper socketio management.
    """
    try:
        data = request.get_json()

@@ -80,9 +90,17 @@ def start_scan():
        if not scanner:
            return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500

        # FIXED: Ensure scanner has socketio reference and is registered
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)
        print(f"🚀 Starting scan for {target} with socketio enabled and registered")

        success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target)

        if success:
            # Update session with socketio-enabled scanner
            session_manager.update_session_scanner(user_session_id, scanner)

            return jsonify({
                'success': True,
                'message': 'Reconnaissance scan started successfully',

@@ -111,6 +129,10 @@ def stop_scan():
        if not scanner.session_id:
            scanner.session_id = user_session_id

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        scanner.stop_scan()
        session_manager.set_stop_signal(user_session_id)
        session_manager.update_scanner_status(user_session_id, 'stopped')

@@ -127,37 +149,83 @@ def stop_scan():
        return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500

@app.route('/api/scan/status', methods=['GET'])
@socketio.on('connect')
def handle_connect():
    """
    FIXED: Handle WebSocket connection with proper session management.
    """
    print(f'✓ WebSocket client connected: {request.sid}')

    # Try to restore existing session connection
    current_flask_session_id = session.get('dnsrecon_session_id')
    if current_flask_session_id:
        # Register this socketio connection for the existing session
        session_manager.register_socketio_connection(current_flask_session_id, socketio)
        print(f'✓ Registered WebSocket for existing session: {current_flask_session_id[:8]}...')

    # Immediately send current status to new connection
    get_scan_status()


@socketio.on('disconnect')
def handle_disconnect():
    """
    FIXED: Handle WebSocket disconnection gracefully.
    """
    print(f'✗ WebSocket client disconnected: {request.sid}')

    # Note: We don't immediately remove the socketio connection from session_manager
    # because the user might reconnect. The cleanup will happen during session cleanup.


@socketio.on('get_status')
def get_scan_status():
    """Get current scan status."""
    """
    FIXED: Get current scan status and emit real-time update with proper error handling.
    """
    try:
        user_session_id, scanner = get_user_scanner()

        if not scanner:
            return jsonify({
                'success': True,
                'status': {
                    'status': 'idle', 'target_domain': None, 'current_depth': 0,
                    'max_depth': 0, 'progress_percentage': 0.0,
                    'user_session_id': user_session_id
                }
            })
            status = {
                'status': 'idle',
                'target_domain': None,
                'current_depth': 0,
                'max_depth': 0,
                'progress_percentage': 0.0,
                'user_session_id': user_session_id,
                'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
            }
            print(f"📡 Emitting idle status for session {user_session_id[:8] if user_session_id else 'none'}...")
        else:
            if not scanner.session_id:
                scanner.session_id = user_session_id

            # FIXED: Ensure scanner has socketio reference for future updates
            scanner.socketio = socketio
            session_manager.register_socketio_connection(user_session_id, socketio)

            status = scanner.get_scan_status()
            status['user_session_id'] = user_session_id

            print(f"📡 Emitting status update: {status['status']} - "
                  f"Nodes: {len(status.get('graph', {}).get('nodes', []))}, "
                  f"Edges: {len(status.get('graph', {}).get('edges', []))}")

            # Update session with socketio-enabled scanner
            session_manager.update_session_scanner(user_session_id, scanner)

        if not scanner.session_id:
            scanner.session_id = user_session_id

        status = scanner.get_scan_status()
        status['user_session_id'] = user_session_id

        return jsonify({'success': True, 'status': status})
        socketio.emit('scan_update', status)

    except Exception as e:
        traceback.print_exc()
        return jsonify({
            'success': False, 'error': f'Internal server error: {str(e)}',
            'fallback_status': {'status': 'error', 'progress_percentage': 0.0}
        }), 500

        error_status = {
            'status': 'error',
            'message': 'Failed to get status',
            'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
        }
        print(f"⚠️ Error getting status, emitting error status")
        socketio.emit('scan_update', error_status)

@app.route('/api/graph', methods=['GET'])
@@ -174,6 +242,10 @@ def get_graph_data():
        if not scanner:
            return jsonify({'success': True, 'graph': empty_graph, 'user_session_id': user_session_id})

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        graph_data = scanner.get_graph_data() or empty_graph

        return jsonify({'success': True, 'graph': graph_data, 'user_session_id': user_session_id})

@@ -200,6 +272,10 @@ def extract_from_large_entity():
        if not scanner:
            return jsonify({'success': False, 'error': 'No active session found'}), 404

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        success = scanner.extract_node_from_large_entity(large_entity_id, node_id)

        if success:

@@ -220,6 +296,10 @@ def delete_graph_node(node_id):
        if not scanner:
            return jsonify({'success': False, 'error': 'No active session found'}), 404

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        success = scanner.graph.remove_node(node_id)

        if success:

@@ -245,6 +325,10 @@ def revert_graph_action():
        if not scanner:
            return jsonify({'success': False, 'error': 'No active session found'}), 404

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        action_type = data['type']
        action_data = data['data']

@@ -289,6 +373,10 @@ def export_results():
        if not scanner:
            return jsonify({'success': False, 'error': 'No active scanner session found'}), 404

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        # Get export data using the new export manager
        try:
            results = export_manager.export_scan_results(scanner)

@@ -340,6 +428,10 @@ def export_targets():
        if not scanner:
            return jsonify({'success': False, 'error': 'No active scanner session found'}), 404

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        # Use export manager for targets export
        targets_txt = export_manager.export_targets_list(scanner)

@@ -370,6 +462,10 @@ def export_summary():
        if not scanner:
            return jsonify({'success': False, 'error': 'No active scanner session found'}), 404

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        # Use export manager for summary generation
        summary_txt = export_manager.generate_executive_summary(scanner)

@@ -402,6 +498,10 @@ def set_api_keys():
        user_session_id, scanner = get_user_scanner()
        session_config = scanner.config

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        updated_providers = []

        for provider_name, api_key in data.items():

@@ -434,6 +534,10 @@ def get_providers():
        user_session_id, scanner = get_user_scanner()
        base_provider_info = scanner.get_provider_info()

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        # Enhance provider info with API key source information
        enhanced_provider_info = {}

@@ -498,6 +602,10 @@ def configure_providers():
        user_session_id, scanner = get_user_scanner()
        session_config = scanner.config

        # FIXED: Ensure scanner has socketio reference
        scanner.socketio = socketio
        session_manager.register_socketio_connection(user_session_id, socketio)

        updated_providers = []

        for provider_name, settings in data.items():

@@ -526,7 +634,6 @@ def configure_providers():
        return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500

@app.errorhandler(404)
def not_found(error):
    """Handle 404 errors."""

@@ -542,9 +649,9 @@ def internal_error(error):

if __name__ == '__main__':
    config.load_from_env()
    app.run(
        host=config.flask_host,
        port=config.flask_port,
        debug=config.flask_debug,
        threaded=True
    )
    print("🚀 Starting DNSRecon with enhanced WebSocket support...")
    print(f"   Host: {config.flask_host}")
    print(f"   Port: {config.flask_port}")
    print(f"   Debug: {config.flask_debug}")
    print("   WebSocket: Enhanced connection management enabled")
    socketio.run(app, host=config.flask_host, port=config.flask_port, debug=config.flask_debug)
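
Note: the endpoints above lean on two session-manager calls that this diff does not show, register_socketio_connection() and get_socketio_connection(). A minimal sketch of what such an implementation might look like, assuming the manager keeps an in-memory registry keyed by session ID — the method names mirror the calls above, but the storage details are an assumption, not code from this PR:

# Hypothetical sketch - not part of this PR. The real SessionManager may
# store connections differently (e.g. alongside its Redis-backed state).
import threading

class SessionManager:
    def __init__(self):
        self._socketio_connections = {}  # session_id -> SocketIO instance
        self._lock = threading.Lock()    # guard concurrent Flask workers

    def register_socketio_connection(self, session_id, socketio):
        """Remember which SocketIO instance serves this session."""
        with self._lock:
            self._socketio_connections[session_id] = socketio

    def get_socketio_connection(self, session_id):
        """Return the registered SocketIO instance, or None if absent."""
        with self._lock:
            return self._socketio_connections.get(session_id)

The SocketIO instance itself is process-wide, so the registry's real job is letting an unpickled Scanner find a live connection again after a round trip through Redis.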
core/graph_manager.py
@@ -4,8 +4,7 @@
Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
Now fully compatible with the unified ProviderResult data model.
UPDATED: Fixed correlation exclusion keys to match actual attribute names.
UPDATED: Removed export_json() method - now handled by ExportManager.
FIXED: Added proper pickle support to prevent weakref serialization errors.
"""
import re
from datetime import datetime, timezone

@@ -33,6 +32,7 @@ class GraphManager:
    Thread-safe graph manager for DNSRecon infrastructure mapping.
    Uses NetworkX for in-memory graph storage with confidence scoring.
    Compatible with unified ProviderResult data model.
    FIXED: Added proper pickle support to handle NetworkX graph serialization.
    """

    def __init__(self):

@@ -40,6 +40,57 @@ class GraphManager:
        self.graph = nx.DiGraph()
        self.creation_time = datetime.now(timezone.utc).isoformat()
        self.last_modified = self.creation_time

    def __getstate__(self):
        """Prepare GraphManager for pickling by converting NetworkX graph to serializable format."""
        state = self.__dict__.copy()

        # Convert NetworkX graph to a serializable format
        if hasattr(self, 'graph') and self.graph:
            # Extract all nodes with their data
            nodes_data = {}
            for node_id, attrs in self.graph.nodes(data=True):
                nodes_data[node_id] = dict(attrs)

            # Extract all edges with their data
            edges_data = []
            for source, target, attrs in self.graph.edges(data=True):
                edges_data.append({
                    'source': source,
                    'target': target,
                    'attributes': dict(attrs)
                })

            # Replace the NetworkX graph with serializable data
            state['_graph_nodes'] = nodes_data
            state['_graph_edges'] = edges_data
            del state['graph']

        return state

    def __setstate__(self, state):
        """Restore GraphManager after unpickling by reconstructing NetworkX graph."""
        # Restore basic attributes
        self.__dict__.update(state)

        # Reconstruct NetworkX graph from serializable data
        self.graph = nx.DiGraph()

        # Restore nodes
        if hasattr(self, '_graph_nodes'):
            for node_id, attrs in self._graph_nodes.items():
                self.graph.add_node(node_id, **attrs)
            del self._graph_nodes

        # Restore edges
        if hasattr(self, '_graph_edges'):
            for edge_data in self._graph_edges:
                self.graph.add_edge(
                    edge_data['source'],
                    edge_data['target'],
                    **edge_data['attributes']
                )
            del self._graph_edges

    def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None,
                 description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:

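
Note: because __getstate__/__setstate__ round-trip the graph through plain dicts and lists, the fix can be sanity-checked with nothing but pickle. A sketch, assuming GraphManager as defined above; the node names and attribute keys are illustrative, not from the PR:

# Sketch: verify the pickle round-trip preserves nodes, edges, and attributes.
import pickle

gm = GraphManager()
gm.graph.add_node('example.com', node_type='domain')
gm.graph.add_edge('example.com', '1.2.3.4', relationship_type='dns_a_record')

restored = pickle.loads(pickle.dumps(gm))
assert set(restored.graph.nodes) == {'example.com', '1.2.3.4'}
assert restored.graph.edges['example.com', '1.2.3.4']['relationship_type'] == 'dns_a_record'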
core/logger.py (145 changes)
@@ -40,6 +40,7 @@ class ForensicLogger:
    """
    Thread-safe forensic logging system for DNSRecon.
    Maintains detailed audit trail of all reconnaissance activities.
    FIXED: Enhanced pickle support to prevent weakref issues in logging handlers.
    """

    def __init__(self, session_id: str = ""):

@@ -65,45 +66,74 @@ class ForensicLogger:
            'target_domains': set()
        }

        # Configure standard logger
        # Configure standard logger with simple setup to avoid weakrefs
        self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
        self.logger.setLevel(logging.INFO)

        # Create formatter for structured logging
        # Create minimal formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        # Add console handler if not already present
        # Add console handler only if not already present (avoid duplicate handlers)
        if not self.logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

    def __getstate__(self):
        """Prepare ForensicLogger for pickling by excluding unpicklable objects."""
        """
        FIXED: Prepare ForensicLogger for pickling by excluding problematic objects.
        """
        state = self.__dict__.copy()
        # Remove the unpickleable 'logger' attribute
        if 'logger' in state:
            del state['logger']
        if 'lock' in state:
            del state['lock']

        # Remove potentially unpickleable attributes that may contain weakrefs
        unpicklable_attrs = ['logger', 'lock']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]

        # Convert sets to lists for JSON serialization compatibility
        if 'session_metadata' in state:
            metadata = state['session_metadata'].copy()
            if 'providers_used' in metadata and isinstance(metadata['providers_used'], set):
                metadata['providers_used'] = list(metadata['providers_used'])
            if 'target_domains' in metadata and isinstance(metadata['target_domains'], set):
                metadata['target_domains'] = list(metadata['target_domains'])
            state['session_metadata'] = metadata

        return state

    def __setstate__(self, state):
        """Restore ForensicLogger after unpickling by reconstructing logger."""
        """
        FIXED: Restore ForensicLogger after unpickling by reconstructing components.
        """
        self.__dict__.update(state)
        # Re-initialize the 'logger' attribute

        # Re-initialize threading lock
        self.lock = threading.Lock()

        # Re-initialize logger with minimal setup
        self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
        self.logger.setLevel(logging.INFO)

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        # Only add handler if not already present
        if not self.logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)
        self.lock = threading.Lock()

        # Convert lists back to sets if needed
        if 'session_metadata' in self.__dict__:
            metadata = self.session_metadata
            if 'providers_used' in metadata and isinstance(metadata['providers_used'], list):
                metadata['providers_used'] = set(metadata['providers_used'])
            if 'target_domains' in metadata and isinstance(metadata['target_domains'], list):
                metadata['target_domains'] = set(metadata['target_domains'])

    def _generate_session_id(self) -> str:
        """Generate unique session identifier."""

@@ -143,18 +173,23 @@ class ForensicLogger:
            discovery_context=discovery_context
        )

        self.api_requests.append(api_request)
        self.session_metadata['total_requests'] += 1
        self.session_metadata['providers_used'].add(provider)
        with self.lock:
            self.api_requests.append(api_request)
            self.session_metadata['total_requests'] += 1
            self.session_metadata['providers_used'].add(provider)

        if target_indicator:
            self.session_metadata['target_domains'].add(target_indicator)

            if target_indicator:
                self.session_metadata['target_domains'].add(target_indicator)

        # Log to standard logger
        if error:
            self.logger.error(f"API Request Failed.")
        else:
            self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
        # Log to standard logger with error handling
        try:
            if error:
                self.logger.error(f"API Request Failed - {provider}: {url}")
            else:
                self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
        except Exception:
            # If logging fails, continue without breaking the application
            pass

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str, confidence_score: float,

@@ -183,29 +218,44 @@ class ForensicLogger:
            discovery_method=discovery_method
        )

        self.relationships.append(relationship)
        self.session_metadata['total_relationships'] += 1
        with self.lock:
            self.relationships.append(relationship)
            self.session_metadata['total_relationships'] += 1

        self.logger.info(
            f"Relationship Discovered - {source_node} -> {target_node} "
            f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
        )
        # Log to standard logger with error handling
        try:
            self.logger.info(
                f"Relationship Discovered - {source_node} -> {target_node} "
                f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
            )
        except Exception:
            # If logging fails, continue without breaking the application
            pass

    def log_scan_start(self, target_domain: str, recursion_depth: int,
                       enabled_providers: List[str]) -> None:
        """Log the start of a reconnaissance scan."""
        self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
        self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")

        self.session_metadata['target_domains'].update(target_domain)
        try:
            self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
            self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")

            with self.lock:
                self.session_metadata['target_domains'].add(target_domain)
        except Exception:
            pass

    def log_scan_complete(self) -> None:
        """Log the completion of a reconnaissance scan."""
        self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
        self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
        self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
        with self.lock:
            self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
            # Convert sets to lists for serialization
            self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
            self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])

        self.logger.info(f"Scan Complete - Session: {self.session_id}")
        try:
            self.logger.info(f"Scan Complete - Session: {self.session_id}")
        except Exception:
            pass

    def export_audit_trail(self) -> Dict[str, Any]:
        """

@@ -214,12 +264,13 @@ class ForensicLogger:
        Returns:
            Dictionary containing complete session audit trail
        """
        return {
            'session_metadata': self.session_metadata.copy(),
            'api_requests': [asdict(req) for req in self.api_requests],
            'relationships': [asdict(rel) for rel in self.relationships],
            'export_timestamp': datetime.now(timezone.utc).isoformat()
        }
        with self.lock:
            return {
                'session_metadata': self.session_metadata.copy(),
                'api_requests': [asdict(req) for req in self.api_requests],
                'relationships': [asdict(rel) for rel in self.relationships],
                'export_timestamp': datetime.now(timezone.utc).isoformat()
            }

    def get_forensic_summary(self) -> Dict[str, Any]:
        """

@@ -229,7 +280,13 @@ class ForensicLogger:
        Dictionary containing summary statistics
        """
        provider_stats = {}
        for provider in self.session_metadata['providers_used']:

        # Ensure providers_used is a set for iteration
        providers_used = self.session_metadata['providers_used']
        if isinstance(providers_used, list):
            providers_used = set(providers_used)

        for provider in providers_used:
            provider_requests = [req for req in self.api_requests if req.provider == provider]
            provider_relationships = [rel for rel in self.relationships if rel.provider == provider]

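
Note: the set-to-list conversion in __getstate__ paired with the list-to-set restoration in __setstate__ is what lets the same state dict serve both pickle and JSON export. A quick round-trip check, sketched under the assumption that ForensicLogger is as defined above ('crtsh' is an arbitrary example provider name):

# Sketch: confirm session_metadata sets survive a pickle round-trip as sets.
import pickle

logger = ForensicLogger(session_id='demo')
logger.session_metadata['providers_used'].add('crtsh')

restored = pickle.loads(pickle.dumps(logger))
assert isinstance(restored.session_metadata['providers_used'], set)
assert 'crtsh' in restored.session_metadata['providers_used']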
core/scanner.py (600 changes)
@@ -6,7 +6,6 @@ import os
import importlib
import redis
import time
import math
import random  # Imported for jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor

@@ -36,9 +35,10 @@ class Scanner:
    """
    Main scanning orchestrator for DNSRecon passive reconnaissance.
    UNIFIED: Combines comprehensive features with improved display formatting.
    FIXED: Enhanced threading object initialization to prevent None references.
    """

    def __init__(self, session_config=None):
    def __init__(self, session_config=None, socketio=None):
        """Initialize scanner with session-specific configuration."""
        try:
            # Use provided session config or create default

@@ -46,6 +46,12 @@ class Scanner:
                from core.session_config import create_session_config
                session_config = create_session_config()

            # FIXED: Initialize all threading objects first
            self._initialize_threading_objects()

            # Set socketio (but will be set to None for storage)
            self.socketio = socketio

            self.config = session_config
            self.graph = GraphManager()
            self.providers = []

@@ -53,17 +59,12 @@ class Scanner:
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()
            self.scan_thread = None
            self.session_id: Optional[str] = None  # Will be set by session manager
            self.task_queue = PriorityQueue()
            self.target_retries = defaultdict(int)
            self.scan_failed_due_to_retries = False
            self.initial_targets = set()

            # Thread-safe processing tracking (from Document 1)
            self.currently_processing = set()
            self.processing_lock = threading.Lock()
            # Display-friendly processing list (from Document 2)
            self.currently_processing_display = []

@@ -81,9 +82,10 @@ class Scanner:
            self.max_workers = self.config.max_concurrent_requests
            self.executor = None

            # Status logger thread with improved formatting
            self.status_logger_thread = None
            self.status_logger_stop_event = threading.Event()
            # Initialize collections that will be recreated during unpickling
            self.task_queue = PriorityQueue()
            self.target_retries = defaultdict(int)
            self.scan_failed_due_to_retries = False

            # Initialize providers with session config
            self._initialize_providers()

@@ -99,12 +101,24 @@ class Scanner:
            traceback.print_exc()
            raise

    def _initialize_threading_objects(self):
        """
        FIXED: Initialize all threading objects with proper error handling.
        This method can be called during both __init__ and __setstate__.
        """
        self.stop_event = threading.Event()
        self.processing_lock = threading.Lock()
        self.status_logger_stop_event = threading.Event()
        self.status_logger_thread = None
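
    # Editor's sketch (not in this PR): the reason threading objects are
    # centralized in _initialize_threading_objects() is that locks and events
    # cannot cross a pickle boundary, which is exactly what happens when the
    # scanner is stored in Redis. Per CPython's standard library:
    #
    #     import pickle, threading
    #     pickle.dumps(threading.Lock())
    #     # TypeError: cannot pickle '_thread.lock' object
    #
    # so __getstate__ must drop these attributes and __setstate__ must call
    # this method to rebuild them.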
    def _is_stop_requested(self) -> bool:
        """
        Check if stop is requested using both local and Redis-based signals.
        This ensures reliable termination across process boundaries.
        FIXED: Added None check for stop_event.
        """
        if self.stop_event.is_set():
        # FIXED: Ensure stop_event exists before checking
        if hasattr(self, 'stop_event') and self.stop_event and self.stop_event.is_set():
            return True

        if self.session_id:

@@ -112,16 +126,24 @@ class Scanner:
                from core.session_manager import session_manager
                return session_manager.is_stop_requested(self.session_id)
            except Exception as e:
                # Fall back to local event
                return self.stop_event.is_set()
                # Fall back to local event if it exists
                if hasattr(self, 'stop_event') and self.stop_event:
                    return self.stop_event.is_set()
                return False

        return self.stop_event.is_set()
        # Final fallback
        if hasattr(self, 'stop_event') and self.stop_event:
            return self.stop_event.is_set()
        return False

    def _set_stop_signal(self) -> None:
        """
        Set stop signal both locally and in Redis.
        FIXED: Added None check for stop_event.
        """
        self.stop_event.set()
        # FIXED: Ensure stop_event exists before setting
        if hasattr(self, 'stop_event') and self.stop_event:
            self.stop_event.set()

        if self.session_id:
            try:

@@ -143,7 +165,8 @@ class Scanner:
            'rate_limiter',
            'logger',
            'status_logger_thread',
            'status_logger_stop_event'
            'status_logger_stop_event',
            'socketio'
        ]

        for attr in unpicklable_attrs:

@@ -161,16 +184,21 @@ class Scanner:
        """Restore object after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)

        self.stop_event = threading.Event()
        # FIXED: Ensure all threading objects are properly initialized
        self._initialize_threading_objects()

        # Re-initialize other objects
        self.scan_thread = None
        self.executor = None
        self.processing_lock = threading.Lock()
        self.task_queue = PriorityQueue()
        self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
        self.logger = get_forensic_logger()
        self.status_logger_thread = None
        self.status_logger_stop_event = threading.Event()

        # FIXED: Initialize socketio as None but preserve ability to set it
        if not hasattr(self, 'socketio'):
            self.socketio = None

        # Initialize missing attributes with defaults
        if not hasattr(self, 'providers') or not self.providers:
            self._initialize_providers()

@@ -180,11 +208,36 @@ class Scanner:
        if not hasattr(self, 'currently_processing_display'):
            self.currently_processing_display = []

        if not hasattr(self, 'target_retries'):
            self.target_retries = defaultdict(int)

        if not hasattr(self, 'scan_failed_due_to_retries'):
            self.scan_failed_due_to_retries = False

        if not hasattr(self, 'initial_targets'):
            self.initial_targets = set()

        # Ensure providers have stop events
        if hasattr(self, 'providers'):
            for provider in self.providers:
                if hasattr(provider, 'set_stop_event'):
                if hasattr(provider, 'set_stop_event') and self.stop_event:
                    provider.set_stop_event(self.stop_event)

    def _ensure_threading_objects_exist(self):
        """
        FIXED: Utility method to ensure threading objects exist before use.
        Call this before any method that might use threading objects.
        """
        if not hasattr(self, 'stop_event') or self.stop_event is None:
            print("WARNING: Threading objects not initialized, recreating...")
            self._initialize_threading_objects()

        if not hasattr(self, 'processing_lock') or self.processing_lock is None:
            self.processing_lock = threading.Lock()

        if not hasattr(self, 'task_queue') or self.task_queue is None:
            self.task_queue = PriorityQueue()

    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration."""
        self.providers = []

@@ -222,7 +275,9 @@ class Scanner:
                    print(f"  Available: {is_available}")

                    if is_available:
                        provider.set_stop_event(self.stop_event)
                        # FIXED: Ensure stop_event exists before setting it
                        if hasattr(self, 'stop_event') and self.stop_event:
                            provider.set_stop_event(self.stop_event)
                        if isinstance(provider, CorrelationProvider):
                            provider.set_graph_manager(self.graph)
                        self.providers.append(provider)

@@ -252,15 +307,25 @@ class Scanner:
        BOLD = "\033[1m"

        last_status_str = ""
        while not self.status_logger_stop_event.is_set():

        # FIXED: Ensure threading objects exist
        self._ensure_threading_objects_exist()

        while not (hasattr(self, 'status_logger_stop_event') and
                   self.status_logger_stop_event and
                   self.status_logger_stop_event.is_set()):
            try:
                with self.processing_lock:
                    in_flight_tasks = list(self.currently_processing)
                    self.currently_processing_display = in_flight_tasks.copy()
                # FIXED: Check if processing_lock exists before using
                if hasattr(self, 'processing_lock') and self.processing_lock:
                    with self.processing_lock:
                        in_flight_tasks = list(self.currently_processing)
                        self.currently_processing_display = in_flight_tasks.copy()
                else:
                    in_flight_tasks = list(getattr(self, 'currently_processing', []))

                status_str = (
                    f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
                    f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
                    f"{CYAN}Queued: {self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0}{ENDC} | "
                    f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
                    f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
                    f"Skipped: {self.tasks_skipped} | "

@@ -288,22 +353,30 @@ class Scanner:
            time.sleep(2)

    def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
        """
        FIXED: Enhanced start_scan with proper threading object initialization and socketio management.
        """
        # FIXED: Ensure threading objects exist before proceeding
        self._ensure_threading_objects_exist()

        if self.scan_thread and self.scan_thread.is_alive():
            self.logger.logger.info("Stopping existing scan before starting new one")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            # Clean up processing state
            with self.processing_lock:
                self.currently_processing.clear()
            if hasattr(self, 'processing_lock') and self.processing_lock:
                with self.processing_lock:
                    self.currently_processing.clear()
            self.currently_processing_display = []

            # Clear task queue
            while not self.task_queue.empty():
                try:
                    self.task_queue.get_nowait()
                except:
                    break
            if hasattr(self, 'task_queue') and self.task_queue:
                while not self.task_queue.empty():
                    try:
                        self.task_queue.get_nowait()
                    except:
                        break

            # Shutdown executor
            if self.executor:

@@ -320,14 +393,26 @@ class Scanner:
                self.logger.logger.warning("Previous scan thread did not terminate cleanly")

        self.status = ScanStatus.IDLE
        self.stop_event.clear()

        # FIXED: Ensure stop_event exists before clearing
        if hasattr(self, 'stop_event') and self.stop_event:
            self.stop_event.clear()

        if self.session_id:
            from core.session_manager import session_manager
            session_manager.clear_stop_signal(self.session_id)

            # FIXED: Restore socketio connection if missing
            if not hasattr(self, 'socketio') or not self.socketio:
                registered_socketio = session_manager.get_socketio_connection(self.session_id)
                if registered_socketio:
                    self.socketio = registered_socketio
                    print(f"✓ Restored socketio connection for scan start")

        with self.processing_lock:
            self.currently_processing.clear()
        # FIXED: Safe cleanup with existence checks
        if hasattr(self, 'processing_lock') and self.processing_lock:
            with self.processing_lock:
                self.currently_processing.clear()
        self.currently_processing_display = []

        self.task_queue = PriorityQueue()

@@ -395,7 +480,10 @@ class Scanner:
        )
        self.scan_thread.start()

        self.status_logger_stop_event.clear()
        # FIXED: Ensure status_logger_stop_event exists before clearing
        if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
            self.status_logger_stop_event.clear()

        self.status_logger_thread = threading.Thread(
            target=self._status_logger_thread,
            daemon=True,

@@ -449,6 +537,13 @@ class Scanner:
            return 10  # Very low rate limit = very low priority

    def _execute_scan(self, target: str, max_depth: int) -> None:
        """
        FIXED: Enhanced execute_scan with proper threading object handling.
        """
        # FIXED: Ensure threading objects exist
        self._ensure_threading_objects_exist()
        update_counter = 0  # Track updates for throttling
        last_update_time = time.time()
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        processed_tasks = set()  # FIXED: Now includes depth to avoid incorrect skipping

@@ -480,8 +575,13 @@ class Scanner:
            print(f"\n=== PHASE 1: Running non-correlation providers ===")
            while not self._is_stop_requested():
                queue_empty = self.task_queue.empty()
                with self.processing_lock:
                    no_active_processing = len(self.currently_processing) == 0

                # FIXED: Safe processing lock usage
                if hasattr(self, 'processing_lock') and self.processing_lock:
                    with self.processing_lock:
                        no_active_processing = len(self.currently_processing) == 0
                else:
                    no_active_processing = len(getattr(self, 'currently_processing', [])) == 0

                if queue_empty and no_active_processing:
                    consecutive_empty_iterations += 1

@@ -534,10 +634,23 @@ class Scanner:
                    continue

                # Thread-safe processing state management
                with self.processing_lock:
                    processing_key = (provider_name, target_item)

                # FIXED: Safe processing lock usage
                if hasattr(self, 'processing_lock') and self.processing_lock:
                    with self.processing_lock:
                        if self._is_stop_requested():
                            break
                        if processing_key in self.currently_processing:
                            self.tasks_skipped += 1
                            self.indicators_completed += 1
                            continue
                        self.currently_processing.add(processing_key)
                else:
                    if self._is_stop_requested():
                        break
                    processing_key = (provider_name, target_item)
                    if not hasattr(self, 'currently_processing'):
                        self.currently_processing = set()
                    if processing_key in self.currently_processing:
                        self.tasks_skipped += 1
                        self.indicators_completed += 1

@@ -556,7 +669,12 @@ class Scanner:

                if provider and not isinstance(provider, CorrelationProvider):
                    new_targets, _, success = self._process_provider_task(provider, target_item, depth)

                    update_counter += 1
                    current_time = time.time()
                    if (update_counter % 5 == 0) or (current_time - last_update_time > 3.0):
                        self._update_session_state()
                        last_update_time = current_time
                        update_counter = 0
                    if self._is_stop_requested():
                        break

@@ -601,9 +719,13 @@ class Scanner:
                        self.indicators_completed += 1

                finally:
                    with self.processing_lock:
                        processing_key = (provider_name, target_item)
                        self.currently_processing.discard(processing_key)
                    # FIXED: Safe processing lock usage for cleanup
                    if hasattr(self, 'processing_lock') and self.processing_lock:
                        with self.processing_lock:
                            self.currently_processing.discard(processing_key)
                    else:
                        if hasattr(self, 'currently_processing'):
                            self.currently_processing.discard(processing_key)

            # PHASE 2: Run correlations on all discovered nodes
            if not self._is_stop_requested():

@@ -616,8 +738,9 @@ class Scanner:
            self.logger.logger.error(f"Scan failed: {e}")
        finally:
            # Comprehensive cleanup (same as before)
            with self.processing_lock:
                self.currently_processing.clear()
            if hasattr(self, 'processing_lock') and self.processing_lock:
                with self.processing_lock:
                    self.currently_processing.clear()
            self.currently_processing_display = []

            while not self.task_queue.empty():

@@ -633,7 +756,9 @@ class Scanner:
            else:
                self.status = ScanStatus.COMPLETED

            self.status_logger_stop_event.set()
            # FIXED: Safe stop event handling
            if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
                self.status_logger_stop_event.set()
            if self.status_logger_thread and self.status_logger_thread.is_alive():
                self.status_logger_thread.join(timeout=2.0)

@@ -687,8 +812,13 @@ class Scanner:

            while not self._is_stop_requested() and correlation_tasks:
                queue_empty = self.task_queue.empty()
                with self.processing_lock:
                    no_active_processing = len(self.currently_processing) == 0

                # FIXED: Safe processing check
                if hasattr(self, 'processing_lock') and self.processing_lock:
                    with self.processing_lock:
                        no_active_processing = len(self.currently_processing) == 0
                else:
                    no_active_processing = len(getattr(self, 'currently_processing', [])) == 0

                if queue_empty and no_active_processing:
                    consecutive_empty_iterations += 1

@@ -720,10 +850,23 @@ class Scanner:
                    correlation_tasks.remove(task_tuple)
                    continue

                with self.processing_lock:
                    processing_key = (provider_name, target_item)

                # FIXED: Safe processing management
                if hasattr(self, 'processing_lock') and self.processing_lock:
                    with self.processing_lock:
                        if self._is_stop_requested():
                            break
                        if processing_key in self.currently_processing:
                            self.tasks_skipped += 1
                            self.indicators_completed += 1
                            continue
                        self.currently_processing.add(processing_key)
                else:
                    if self._is_stop_requested():
                        break
                    processing_key = (provider_name, target_item)
                    if not hasattr(self, 'currently_processing'):
                        self.currently_processing = set()
                    if processing_key in self.currently_processing:
                        self.tasks_skipped += 1
                        self.indicators_completed += 1

@@ -752,51 +895,214 @@ class Scanner:
                    correlation_tasks.remove(task_tuple)

                finally:
                    with self.processing_lock:
                        processing_key = (provider_name, target_item)
                        self.currently_processing.discard(processing_key)
                    # FIXED: Safe cleanup
                    if hasattr(self, 'processing_lock') and self.processing_lock:
                        with self.processing_lock:
                            self.currently_processing.discard(processing_key)
                    else:
                        if hasattr(self, 'currently_processing'):
                            self.currently_processing.discard(processing_key)

            print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")

    # Rest of the methods remain the same but with similar threading object safety checks...
    # I'll include the key methods that need fixes:

    def stop_scan(self) -> bool:
        """Request immediate scan termination with proper cleanup."""
        try:
            self.logger.logger.info("Scan termination requested by user")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            # FIXED: Safe cleanup
            if hasattr(self, 'processing_lock') and self.processing_lock:
                with self.processing_lock:
                    self.currently_processing.clear()
            self.currently_processing_display = []

            self.task_queue = PriorityQueue()

            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass

            self._update_session_state()
            return True

        except Exception as e:
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False

    def get_scan_status(self) -> Dict[str, Any]:
        """Get current scan status with comprehensive graph data for real-time updates."""
        try:
            # FIXED: Safe processing state access
            if hasattr(self, 'processing_lock') and self.processing_lock:
                with self.processing_lock:
                    currently_processing_count = len(self.currently_processing)
                    currently_processing_list = list(self.currently_processing)
            else:
                currently_processing_count = len(getattr(self, 'currently_processing', []))
                currently_processing_list = list(getattr(self, 'currently_processing', []))

            # FIXED: Always include complete graph data for real-time updates
            graph_data = self.get_graph_data()

            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'indicators_processed': self.indicators_processed,
                'indicators_completed': self.indicators_completed,
                'tasks_re_enqueued': self.tasks_re_enqueued,
                'progress_percentage': self._calculate_progress(),
                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph': graph_data,  # FIXED: Always include complete graph data
                'task_queue_size': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
                'currently_processing_count': currently_processing_count,
                'currently_processing': currently_processing_list[:5],
                'tasks_in_queue': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
                'tasks_completed': self.indicators_completed,
                'tasks_skipped': self.tasks_skipped,
                'tasks_rescheduled': self.tasks_re_enqueued,
            }
        except Exception as e:
            traceback.print_exc()
            return {
                'status': 'error',
                'message': 'Failed to get status',
                'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
            }

    def _update_session_state(self) -> None:
        """
        FIXED: Update the scanner state in Redis and emit real-time WebSocket updates.
        Enhanced with better error handling and socketio management.
        """
        if self.session_id:
            try:
                # Get current status for WebSocket emission
                current_status = self.get_scan_status()

                # FIXED: Emit real-time update via WebSocket with better error handling
                socketio_available = False
                if hasattr(self, 'socketio') and self.socketio:
                    try:
                        print(f"📡 EMITTING WebSocket update: {current_status.get('status', 'unknown')} - "
                              f"Nodes: {len(current_status.get('graph', {}).get('nodes', []))}, "
                              f"Edges: {len(current_status.get('graph', {}).get('edges', []))}")

                        self.socketio.emit('scan_update', current_status)
                        print("✅ WebSocket update emitted successfully")
                        socketio_available = True

                    except Exception as ws_error:
                        print(f"⚠️ WebSocket emission failed: {ws_error}")
                        # Try to get socketio from session manager
                        try:
                            from core.session_manager import session_manager
                            registered_socketio = session_manager.get_socketio_connection(self.session_id)
                            if registered_socketio:
                                print("🔄 Attempting to use registered socketio connection...")
                                registered_socketio.emit('scan_update', current_status)
                                self.socketio = registered_socketio  # Update our reference
                                print("✅ WebSocket update emitted via registered connection")
                                socketio_available = True
                            else:
                                print("⚠️ No registered socketio connection found")
                        except Exception as fallback_error:
                            print(f"⚠️ Fallback socketio emission also failed: {fallback_error}")
                else:
                    # Try to restore socketio from session manager
                    try:
                        from core.session_manager import session_manager
                        registered_socketio = session_manager.get_socketio_connection(self.session_id)
                        if registered_socketio:
                            print(f"🔄 Restoring socketio connection for session {self.session_id}")
                            self.socketio = registered_socketio
                            self.socketio.emit('scan_update', current_status)
                            print("✅ WebSocket update emitted via restored connection")
                            socketio_available = True
                        else:
                            print(f"⚠️ No socketio connection available for session {self.session_id}")
                    except Exception as restore_error:
                        print(f"⚠️ Failed to restore socketio connection: {restore_error}")

                if not socketio_available:
                    print(f"⚠️ Real-time updates unavailable for session {self.session_id}")

                # Update session in Redis for persistence (always do this)
                try:
                    from core.session_manager import session_manager
                    session_manager.update_session_scanner(self.session_id, self)
                except Exception as redis_error:
                    print(f"⚠️ Failed to update session in Redis: {redis_error}")

            except Exception as e:
                print(f"⚠️ Error updating session state: {e}")
                import traceback
                traceback.print_exc()

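    # Editor's sketch (not in this PR): the restore-then-emit dance above is
    # repeated in _process_provider_task and extract_node_from_large_entity
    # below. Under the same assumptions as the diff (session_manager.
    # get_socketio_connection returns a SocketIO instance or None), it could
    # be factored into one helper:
    #
    # def _emit_update(self, event: str, payload: dict) -> bool:
    #     """Emit via the held socketio, restoring it from session_manager if needed."""
    #     if not getattr(self, 'socketio', None) and self.session_id:
    #         from core.session_manager import session_manager
    #         self.socketio = session_manager.get_socketio_connection(self.session_id)
    #     if not self.socketio:
    #         return False
    #     try:
    #         self.socketio.emit(event, payload)
    #         return True
    #     except Exception:
    #         return False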
    def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
        """
        Manages the entire process for a given target and provider.
        This version is generalized to handle all relationships dynamically.
        FIXED: Manages the entire process for a given target and provider with enhanced real-time updates.
        """
        if self._is_stop_requested():
            return set(), set(), False

        is_ip = _is_valid_ip(target)
        target_type = NodeType.IP if is_ip else NodeType.DOMAIN

        self.graph.add_node(target, target_type)
        self._initialize_provider_states(target)

        new_targets = set()
        provider_successful = True

        try:
            provider_result = self._execute_provider_query(provider, target, is_ip)

            if provider_result is None:
                provider_successful = False
            elif not self._is_stop_requested():
                # Pass all relationships to be processed
                discovered, is_large_entity = self._process_provider_result_unified(
                    target, provider, provider_result, depth
                )
                new_targets.update(discovered)

                # FIXED: Emit real-time update after processing provider result
                if discovered or provider_result.get_relationship_count() > 0:
                    # Ensure we have socketio connection for real-time updates
                    if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
                        try:
                            from core.session_manager import session_manager
                            registered_socketio = session_manager.get_socketio_connection(self.session_id)
                            if registered_socketio:
                                self.socketio = registered_socketio
                                print(f"🔄 Restored socketio connection during provider processing")
                        except Exception as restore_error:
                            print(f"⚠️ Failed to restore socketio during provider processing: {restore_error}")

                    self._update_session_state()

        except Exception as e:
            provider_successful = False
            self._log_provider_error(target, provider.get_name(), str(e))

        return new_targets, set(), provider_successful

    def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
        """
        The "worker" function that directly communicates with the provider to fetch data.
        """
        """The "worker" function that directly communicates with the provider to fetch data."""
        provider_name = provider.get_name()
        start_time = datetime.now(timezone.utc)

@@ -823,9 +1129,7 @@ class Scanner:

    def _create_large_entity_from_result(self, source_node: str, provider_name: str,
                                         provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
        """
        Creates a large entity node, tags all member nodes, and returns its ID and members.
        """
        """Creates a large entity node, tags all member nodes, and returns its ID and members."""
        members = {rel.target_node for rel in provider_result.relationships
                   if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}

@@ -858,7 +1162,7 @@ class Scanner:

    def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
        """
        Removes a node from a large entity, allowing it to be processed normally.
        FIXED: Removes a node from a large entity with immediate real-time update.
        """
        if not self.graph.graph.has_node(node_id):
            return False

@@ -877,7 +1181,6 @@ class Scanner:
            for provider in eligible_providers:
                provider_name = provider.get_name()
                priority = self._get_priority(provider_name)
                # Use current depth of the large entity if available, else 0
                depth = 0
                if self.graph.graph.has_node(large_entity_id):
                    le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])

@@ -888,6 +1191,19 @@ class Scanner:
                self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
                self.total_tasks_ever_enqueued += 1

            # FIXED: Emit real-time update after extraction with socketio management
            if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
                try:
                    from core.session_manager import session_manager
                    registered_socketio = session_manager.get_socketio_connection(self.session_id)
                    if registered_socketio:
                        self.socketio = registered_socketio
                        print(f"🔄 Restored socketio for node extraction update")
                except Exception as restore_error:
                    print(f"⚠️ Failed to restore socketio for extraction: {restore_error}")

            self._update_session_state()

            return True

        return False

@@ -895,8 +1211,7 @@ class Scanner:

    def _process_provider_result_unified(self, target: str, provider: BaseProvider,
                                         provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
        """
        Process a unified ProviderResult object to update the graph.
        This version dynamically re-routes edges to a large entity container.
        FIXED: Process a unified ProviderResult object to update the graph with enhanced real-time updates.
        """
        provider_name = provider.get_name()
        discovered_targets = set()

@@ -916,6 +1231,10 @@ class Scanner:
                target, provider_name, provider_result, current_depth
            )

        # Track if we added anything significant
        nodes_added = 0
        edges_added = 0

        for i, relationship in enumerate(provider_result.relationships):
            if i % 5 == 0 and self._is_stop_requested():
                break

@@ -947,17 +1266,20 @@ class Scanner:
            max_depth_reached = current_depth >= self.max_depth

            # Add actual nodes to the graph (they might be hidden by the UI)
            self.graph.add_node(source_node_id, source_type)
            self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached})
            if self.graph.add_node(source_node_id, source_type):
                nodes_added += 1
            if self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}):
                nodes_added += 1

            # Add the visual edge to the graph
            self.graph.add_edge(
            if self.graph.add_edge(
                visual_source, visual_target,
                relationship.relationship_type,
                relationship.confidence,
                provider_name,
                relationship.raw_data
            )
            ):
                edges_added += 1

            if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
                if target_node_id not in large_entity_members:

@@ -985,86 +1307,32 @@ class Scanner:
            if not self.graph.graph.has_node(node_id):
                node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
                self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
                nodes_added += 1
            else:
                existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
                self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list

        return discovered_targets, is_large_entity

    def stop_scan(self) -> bool:
        """Request immediate scan termination with proper cleanup."""
        try:
            self.logger.logger.info("Scan termination requested by user")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED
        # FIXED: Emit real-time update if we added anything significant
        if nodes_added > 0 or edges_added > 0:
            print(f"🔄 Added {nodes_added} nodes, {edges_added} edges - triggering real-time update")

            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []

            self.task_queue = PriorityQueue()

            if self.executor:
            # Ensure we have socketio connection for immediate update
            if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass

                    from core.session_manager import session_manager
                    registered_socketio = session_manager.get_socketio_connection(self.session_id)
                    if registered_socketio:
                        self.socketio = registered_socketio
                        print(f"🔄 Restored socketio for immediate update")
                except Exception as restore_error:
                    print(f"⚠️ Failed to restore socketio for immediate update: {restore_error}")

            self._update_session_state()
            return True

        except Exception as e:
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False
|
||||
|
||||
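# Illustrative sketch (not part of the commit): the scanner restores its SocketIO
# handle from the session manager whenever the copy unpickled from Redis has lost
# it. The recurring pattern above, condensed into one helper (the helper name is
# an assumption):
def ensure_socketio(scanner):
    """Re-attach the registered SocketIO connection to a scanner if it is missing."""
    if scanner.session_id and not getattr(scanner, 'socketio', None):
        from core.session_manager import session_manager
        conn = session_manager.get_socketio_connection(scanner.session_id)
        if conn:
            scanner.socketio = conn
    return getattr(scanner, 'socketio', None)
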
    def _update_session_state(self) -> None:
        """
        Update the scanner state in Redis for GUI updates.
        """
        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.update_session_scanner(self.session_id, self)
            except Exception:
                pass

    def get_scan_status(self) -> Dict[str, Any]:
        """Get current scan status with comprehensive processing information."""
        try:
            with self.processing_lock:
                currently_processing_count = len(self.currently_processing)
                currently_processing_list = list(self.currently_processing)

            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'indicators_processed': self.indicators_processed,
                'indicators_completed': self.indicators_completed,
                'tasks_re_enqueued': self.tasks_re_enqueued,
                'progress_percentage': self._calculate_progress(),
                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics(),
                'task_queue_size': self.task_queue.qsize(),
                'currently_processing_count': currently_processing_count,
                'currently_processing': currently_processing_list[:5],
                'tasks_in_queue': self.task_queue.qsize(),
                'tasks_completed': self.indicators_completed,
                'tasks_skipped': self.tasks_skipped,
                'tasks_rescheduled': self.tasks_re_enqueued,
            }
        except Exception:
            traceback.print_exc()
            return { 'status': 'error', 'message': 'Failed to get status' }

        return discovered_targets, is_large_entity

    def _initialize_provider_states(self, target: str) -> None:
        """
        FIXED: Safer provider state initialization with error handling.
        """
        """FIXED: Safer provider state initialization with error handling."""
        try:
            if not self.graph.graph.has_node(target):
                return
@ -1077,11 +1345,8 @@ class Scanner:
        except Exception as e:
            self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")


    def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
        """
        FIXED: Improved provider eligibility checking with better filtering.
        """
        """FIXED: Improved provider eligibility checking with better filtering."""
        if dns_only:
            return [p for p in self.providers if p.get_name() == 'dns']

@ -1120,9 +1385,7 @@ class Scanner:
        return eligible

    def _already_queried_provider(self, target: str, provider_name: str) -> bool:
        """
        FIXED: More robust check for already queried providers with proper error handling.
        """
        """FIXED: More robust check for already queried providers with proper error handling."""
        try:
            if not self.graph.graph.has_node(target):
                return False
@ -1141,9 +1404,7 @@ class Scanner:

    def _update_provider_state(self, target: str, provider_name: str, status: str,
                               results_count: int, error: Optional[str], start_time: datetime) -> None:
        """
        FIXED: More robust provider state updates with validation.
        """
        """FIXED: More robust provider state updates with validation."""
        try:
            if not self.graph.graph.has_node(target):
                self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
@ -1170,7 +1431,8 @@ class Scanner:
            }

            # Update last modified time for forensic integrity
            self.last_modified = datetime.now(timezone.utc).isoformat()
            if hasattr(self, 'last_modified'):
                self.last_modified = datetime.now(timezone.utc).isoformat()

        except Exception as e:
            self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
@ -1187,9 +1449,14 @@ class Scanner:
            return 0.0

        # Add small buffer for tasks still in queue to avoid showing 100% too early
        queue_size = max(0, self.task_queue.qsize())
        with self.processing_lock:
            active_tasks = len(self.currently_processing)
        queue_size = max(0, self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0)

        # FIXED: Safe processing count
        if hasattr(self, 'processing_lock') and self.processing_lock:
            with self.processing_lock:
                active_tasks = len(self.currently_processing)
        else:
            active_tasks = len(getattr(self, 'currently_processing', []))

        # Adjust total to account for remaining work
        adjusted_total = max(self.total_tasks_ever_enqueued,
@ -1206,12 +1473,13 @@ class Scanner:
        return 0.0

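# Illustrative sketch (not part of the commit): the buffered progress calculation
# above avoids reporting 100% while tasks are still queued or in flight. The idea
# in isolation (the exact adjusted_total expression is truncated by the hunk, so
# this formula is an assumption):
def calculate_progress(completed: int, total_enqueued: int, queued: int, active: int) -> float:
    """Percentage of work done, never reaching 100 while work remains."""
    adjusted_total = max(total_enqueued, completed + queued + active)
    if adjusted_total == 0:
        return 0.0
    return min(100.0, (completed / adjusted_total) * 100.0)

# calculate_progress(50, 50, queued=3, active=2) -> ~90.9, not 100.0
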
    def get_graph_data(self) -> Dict[str, Any]:
        """Get current graph data formatted for frontend visualization."""
        graph_data = self.graph.get_graph_data()
        graph_data['initial_targets'] = list(self.initial_targets)
        return graph_data


    def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
        """Get comprehensive information about all available providers."""
        info = {}
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
        for filename in os.listdir(provider_dir):

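# Illustrative sketch (not part of the commit): the socketio handle the scanner
# keeps restoring is what carries the 'scan_update' events consumed by the
# frontend listener in static/js/main.js further down. The emit itself is not
# shown in these hunks; with Flask-SocketIO it would plausibly look like this
# (payload keys inferred from the status dictionary above):
def emit_scan_update(scanner):
    """Push current scan state to connected WebSocket clients, if a handle exists."""
    if getattr(scanner, 'socketio', None):
        scanner.socketio.emit('scan_update', {
            'status': scanner.status,
            'target_domain': scanner.current_target,
            'progress_percentage': scanner._calculate_progress(),
            'graph': scanner.get_graph_data(),
        })
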
@ -6,6 +6,7 @@ import uuid
import redis
import pickle
from typing import Dict, Optional, Any
import copy

from core.scanner import Scanner
from config import config
@ -13,7 +14,7 @@ from config import config
class SessionManager:
    """
    FIXED: Manages multiple scanner instances for concurrent user sessions using Redis.
    Now more conservative about session creation to preserve API keys and configuration.
    Enhanced to properly maintain WebSocket connections throughout scan lifecycle.
    """

    def __init__(self, session_timeout_minutes: int = 0):
@ -30,6 +31,9 @@ class SessionManager:
        # FIXED: Add a creation lock to prevent race conditions
        self.creation_lock = threading.Lock()

        # Track active socketio connections per session
        self.active_socketio_connections = {}

        # Start cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()
@ -40,7 +44,7 @@ class SessionManager:
        """Prepare SessionManager for pickling."""
        state = self.__dict__.copy()
        # Exclude unpickleable attributes - Redis client and threading objects
        unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock']
        unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
@ -53,6 +57,7 @@ class SessionManager:
        self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
        self.lock = threading.Lock()
        self.creation_lock = threading.Lock()
        self.active_socketio_connections = {}
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

@ -64,22 +69,70 @@ class SessionManager:
        """Generates the Redis key for a session's stop signal."""
        return f"dnsrecon:stop:{session_id}"

    def create_session(self) -> str:
    def register_socketio_connection(self, session_id: str, socketio) -> None:
        """
        FIXED: Create a new user session with thread-safe creation to prevent duplicates.
        FIXED: Register a socketio connection for a session.
        This ensures the connection is maintained throughout the session lifecycle.
        """
        with self.lock:
            self.active_socketio_connections[session_id] = socketio
            print(f"Registered socketio connection for session {session_id}")

    def get_socketio_connection(self, session_id: str):
        """
        FIXED: Get the active socketio connection for a session.
        """
        with self.lock:
            return self.active_socketio_connections.get(session_id)

    def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner:
        """
        FIXED: Prepare scanner for storage by ensuring proper cleanup of unpicklable objects.
        Now preserves socketio connection info for restoration.
        """
        # Set the session ID on the scanner for cross-process stop signal management
        scanner.session_id = session_id

        # FIXED: Don't set socketio to None if we want to preserve real-time updates
        # Instead, we'll restore it when loading the scanner
        scanner.socketio = None

        # Force cleanup of any threading objects that might cause issues
        if hasattr(scanner, 'stop_event'):
            scanner.stop_event = None
        if hasattr(scanner, 'scan_thread'):
            scanner.scan_thread = None
        if hasattr(scanner, 'executor'):
            scanner.executor = None
        if hasattr(scanner, 'status_logger_thread'):
            scanner.status_logger_thread = None
        if hasattr(scanner, 'status_logger_stop_event'):
            scanner.status_logger_stop_event = None

        return scanner

    def create_session(self, socketio=None) -> str:
        """
        FIXED: Create a new user session with enhanced WebSocket management.
        """
        # FIXED: Use creation lock to prevent race conditions
        with self.creation_lock:
            session_id = str(uuid.uuid4())
            print(f"=== CREATING SESSION {session_id} IN REDIS ===")

            # FIXED: Register socketio connection first
            if socketio:
                self.register_socketio_connection(session_id, socketio)

            try:
                from core.session_config import create_session_config
                session_config = create_session_config()
                scanner_instance = Scanner(session_config=session_config)

                # Set the session ID on the scanner for cross-process stop signal management
                scanner_instance.session_id = session_id
                # Create scanner WITHOUT socketio to avoid weakref issues
                scanner_instance = Scanner(session_config=session_config, socketio=None)

                # Prepare scanner for storage (removes problematic objects)
                scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id)

                session_data = {
                    'scanner': scanner_instance,
@ -89,12 +142,24 @@ class SessionManager:
                    'status': 'active'
                }

                # Serialize the entire session data dictionary using pickle
                serialized_data = pickle.dumps(session_data)
                # Test serialization before storing to catch issues early
                try:
                    test_serialization = pickle.dumps(session_data)
                    print(f"Session serialization test successful ({len(test_serialization)} bytes)")
                except Exception as pickle_error:
                    print(f"PICKLE TEST FAILED: {pickle_error}")
                    # Try to identify the problematic object
                    for key, value in session_data.items():
                        try:
                            pickle.dumps(value)
                            print(f"  {key}: OK")
                        except Exception as item_error:
                            print(f"  {key}: FAILED - {item_error}")
                    raise pickle_error

                # Store in Redis
                session_key = self._get_session_key(session_id)
                self.redis_client.setex(session_key, self.session_timeout, serialized_data)
                self.redis_client.setex(session_key, self.session_timeout, test_serialization)

                # Initialize stop signal as False
                stop_key = self._get_stop_signal_key(session_id)
@ -106,6 +171,8 @@ class SessionManager:

            except Exception as e:
                print(f"ERROR: Failed to create session {session_id}: {e}")
                import traceback
                traceback.print_exc()
                raise

    def set_stop_signal(self, session_id: str) -> bool:
@ -175,31 +242,63 @@ class SessionManager:
            # Ensure the scanner has the correct session ID for stop signal checking
            if 'scanner' in session_data and session_data['scanner']:
                session_data['scanner'].session_id = session_id
                # FIXED: Restore socketio connection from our registry
                socketio_conn = self.get_socketio_connection(session_id)
                if socketio_conn:
                    session_data['scanner'].socketio = socketio_conn
                    print(f"Restored socketio connection for session {session_id}")
                else:
                    print(f"No socketio connection found for session {session_id}")
                    session_data['scanner'].socketio = None
            return session_data
            return None
        except Exception as e:
            print(f"ERROR: Failed to get session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Serializes and saves session data back to Redis with updated TTL.
        FIXED: Now preserves socketio connection during storage.

        Returns:
            bool: True if save was successful
        """
        try:
            session_key = self._get_session_key(session_id)
            serialized_data = pickle.dumps(session_data)

            # Create a deep copy to avoid modifying the original scanner object
            session_data_to_save = copy.deepcopy(session_data)

            # Prepare scanner for storage if it exists
            if 'scanner' in session_data_to_save and session_data_to_save['scanner']:
                # FIXED: Preserve the original socketio connection before preparing for storage
                original_socketio = session_data_to_save['scanner'].socketio

                session_data_to_save['scanner'] = self._prepare_scanner_for_storage(
                    session_data_to_save['scanner'],
                    session_id
                )

                # FIXED: If we had a socketio connection, make sure it's registered
                if original_socketio and session_id not in self.active_socketio_connections:
                    self.register_socketio_connection(session_id, original_socketio)

            serialized_data = pickle.dumps(session_data_to_save)
            result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
            return result
        except Exception as e:
            print(f"ERROR: Failed to save session data for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
        """
        Updates just the scanner object in a session with immediate persistence.
        FIXED: Updates just the scanner object in a session with immediate persistence.
        Now maintains socketio connection throughout the update process.

        Returns:
            bool: True if update was successful
@ -207,21 +306,27 @@ class SessionManager:
        try:
            session_data = self._get_session_data(session_id)
            if session_data:
                # Ensure scanner has the session ID
                scanner.session_id = session_id
                # FIXED: Preserve socketio connection before preparing for storage
                original_socketio = scanner.socketio

                # Prepare scanner for storage
                scanner = self._prepare_scanner_for_storage(scanner, session_id)
                session_data['scanner'] = scanner
                session_data['last_activity'] = time.time()

                # FIXED: Restore socketio connection after preparation
                if original_socketio:
                    self.register_socketio_connection(session_id, original_socketio)
                    session_data['scanner'].socketio = original_socketio

                # Immediately save to Redis for GUI updates
                success = self._save_session_data(session_id, session_data)
                if success:
                    # Only log occasionally to reduce noise
                    if hasattr(self, '_last_update_log'):
                        if time.time() - self._last_update_log > 5:  # Log every 5 seconds max
                            #print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
                            self._last_update_log = time.time()
                    else:
                        #print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
                        self._last_update_log = time.time()
                else:
                    print(f"WARNING: Failed to save scanner state for session {session_id}")
@ -231,6 +336,8 @@ class SessionManager:
            return False
        except Exception as e:
            print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def update_scanner_status(self, session_id: str, status: str) -> bool:
@ -263,7 +370,7 @@ class SessionManager:

    def get_session(self, session_id: str) -> Optional[Scanner]:
        """
        Get scanner instance for a session from Redis with session ID management.
        FIXED: Get scanner instance for a session from Redis with proper socketio restoration.
        """
        if not session_id:
            return None
@ -281,6 +388,15 @@ class SessionManager:
        if scanner:
            # Ensure the scanner can check the Redis-based stop signal
            scanner.session_id = session_id

            # FIXED: Restore socketio connection from our registry
            socketio_conn = self.get_socketio_connection(session_id)
            if socketio_conn:
                scanner.socketio = socketio_conn
                print(f"✓ Restored socketio connection for session {session_id}")
            else:
                scanner.socketio = None
                print(f"⚠️ No socketio connection found for session {session_id}")

        return scanner

@ -333,6 +449,12 @@ class SessionManager:
            # Wait a moment for graceful shutdown
            time.sleep(0.5)

            # FIXED: Clean up socketio connection
            with self.lock:
                if session_id in self.active_socketio_connections:
                    del self.active_socketio_connections[session_id]
                    print(f"Cleaned up socketio connection for session {session_id}")

            # Delete session data and stop signal from Redis
            session_key = self._get_session_key(session_id)
            stop_key = self._get_stop_signal_key(session_id)
@ -344,6 +466,8 @@ class SessionManager:

        except Exception as e:
            print(f"ERROR: Failed to terminate session {session_id}: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _cleanup_loop(self) -> None:
@ -364,6 +488,12 @@ class SessionManager:
                        self.redis_client.delete(stop_key)
                        print(f"Cleaned up orphaned stop signal for session {session_id}")

                        # Also clean up socketio connection
                        with self.lock:
                            if session_id in self.active_socketio_connections:
                                del self.active_socketio_connections[session_id]
                                print(f"Cleaned up orphaned socketio for session {session_id}")

            except Exception as e:
                print(f"Error in cleanup loop: {e}")

@ -387,14 +517,16 @@ class SessionManager:
            return {
                'total_active_sessions': active_sessions,
                'running_scans': running_scans,
                'total_stop_signals': len(stop_keys)
                'total_stop_signals': len(stop_keys),
                'active_socketio_connections': len(self.active_socketio_connections)
            }
        except Exception as e:
            print(f"ERROR: Failed to get statistics: {e}")
            return {
                'total_active_sessions': 0,
                'running_scans': 0,
                'total_stop_signals': 0
                'total_stop_signals': 0,
                'active_socketio_connections': 0
            }

# Global session manager instance

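# Illustrative sketch (not part of the commit): the active_socketio_connections
# registry keeps live (unpicklable) SocketIO handles out of Redis and re-attaches
# them when a session is loaded. The same pattern stripped to its core:
import threading

class SocketIORegistry:
    """In-process map of session_id -> live SocketIO handle."""

    def __init__(self):
        self._lock = threading.Lock()
        self._connections = {}

    def register(self, session_id, socketio):
        with self._lock:
            self._connections[session_id] = socketio

    def get(self, session_id):
        with self._lock:
            return self._connections.get(session_id)

    def drop(self, session_id):
        with self._lock:
            self._connections.pop(session_id, None)
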
@ -15,6 +15,7 @@ class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.
    Now supports session-specific configuration and returns standardized ProviderResult objects.
    FIXED: Enhanced pickle support to prevent weakref serialization errors.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
@ -53,22 +54,57 @@ class BaseProvider(ABC):
    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpickleable '_local' attribute and stop event
        unpicklable_attrs = ['_local', '_stop_event']

        # Exclude unpickleable attributes that may contain weakrefs
        unpicklable_attrs = [
            '_local',       # Thread-local storage (contains requests.Session)
            '_stop_event',  # Threading event
            'logger',       # Logger may contain weakrefs in handlers
        ]

        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]

        # Also handle any potential weakrefs in the config object
        if 'config' in state and hasattr(state['config'], '__getstate__'):
            # If config has its own pickle support, let it handle itself
            pass
        elif 'config' in state:
            # Otherwise, ensure config doesn't contain unpicklable objects
            try:
                # Test if config can be pickled
                import pickle
                pickle.dumps(state['config'])
            except (TypeError, AttributeError):
                # If config can't be pickled, we'll recreate it during unpickling
                state['_config_class'] = type(state['config']).__name__
                del state['config']

        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_local' attribute and stop event

        # Re-initialize unpickleable attributes
        self._local = threading.local()
        self._stop_event = None
        self.logger = get_forensic_logger()

        # Recreate config if it was removed during pickling
        if not hasattr(self, 'config') and hasattr(self, '_config_class'):
            if self._config_class == 'Config':
                from config import config as global_config
                self.config = global_config
            elif self._config_class == 'SessionConfig':
                from core.session_config import create_session_config
                self.config = create_session_config()
            del self._config_class

    @property
    def session(self):
        """Get or create thread-local requests session."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({

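# Illustrative sketch (not part of the commit): the __getstate__/__setstate__ pair
# above lets providers survive a pickle round trip by dropping live handles and
# rebuilding them on load. The same idea in miniature (class and attribute names
# here are assumptions, not the real provider):
import pickle
import threading

class PicklableWorker:
    def __init__(self):
        self._local = threading.local()    # per-thread state; cannot be pickled

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('_local', None)          # drop the unpicklable handle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._local = threading.local()    # rebuild it after unpickling

restored = pickle.loads(pickle.dumps(PicklableWorker()))
assert hasattr(restored, '_local')         # handle was reconstructed, not copied
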
@ -10,6 +10,7 @@ from core.graph_manager import NodeType, GraphManager
class CorrelationProvider(BaseProvider):
    """
    A provider that finds correlations between nodes in the graph.
    FIXED: Enhanced pickle support to prevent weakref issues with graph references.
    """

    def __init__(self, name: str = "correlation", session_config=None):
@ -26,6 +27,7 @@ class CorrelationProvider(BaseProvider):
            'cert_common_name',
            'cert_validity_period_days',
            'cert_issuer_name',
            'cert_serial_number',
            'cert_entry_timestamp',
            'cert_not_before',
            'cert_not_after',
@ -37,6 +39,38 @@ class CorrelationProvider(BaseProvider):
            'query_timestamp',
        ]

    def __getstate__(self):
        """
        FIXED: Prepare CorrelationProvider for pickling by excluding graph reference.
        """
        state = super().__getstate__()

        # Remove graph reference to prevent circular dependencies and weakrefs
        if 'graph' in state:
            del state['graph']

        # Also handle correlation_index which might contain complex objects
        if 'correlation_index' in state:
            # Clear correlation index as it will be rebuilt when needed
            state['correlation_index'] = {}

        return state

    def __setstate__(self, state):
        """
        FIXED: Restore CorrelationProvider after unpickling.
        """
        super().__setstate__(state)

        # Re-initialize graph reference (will be set by scanner)
        self.graph = None

        # Re-initialize correlation index
        self.correlation_index = {}

        # Re-compile regex pattern
        self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')

    def get_name(self) -> str:
        """Return the provider name."""
        return "correlation"
@ -78,13 +112,20 @@ class CorrelationProvider(BaseProvider):
    def _find_correlations(self, node_id: str) -> ProviderResult:
        """
        Find correlations for a given node.
        FIXED: Added safety checks to prevent issues when graph is None.
        """
        result = ProviderResult()
        # FIXED: Ensure self.graph is not None before proceeding.

        # FIXED: Ensure self.graph is not None before proceeding
        if not self.graph or not self.graph.graph.has_node(node_id):
            return result

        node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
        try:
            node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
        except Exception as e:
            # If there's any issue accessing the graph, return empty result
            print(f"Warning: Could not access graph for correlation analysis: {e}")
            return result

        for attr in node_attributes:
            attr_name = attr.get('name')
@ -133,6 +174,7 @@ class CorrelationProvider(BaseProvider):

        if len(self.correlation_index[attr_value]['nodes']) > 1:
            self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)

        return result

    def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):

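# Illustrative sketch (not part of the commit): the correlation index groups nodes
# by shared attribute values and links any value seen on more than one node. A
# standalone version of that idea (data shapes are assumptions based on the
# provider above):
from collections import defaultdict

def correlate(nodes: dict) -> list:
    """nodes: {node_id: {attr_name: value}} -> [(value, [node_ids])] for shared values."""
    index = defaultdict(set)
    for node_id, attrs in nodes.items():
        for name, value in attrs.items():
            index[(name, value)].add(node_id)
    return [(value, sorted(ids)) for (name, value), ids in index.items() if len(ids) > 1]

# Two domains sharing a certificate issuer yield one correlation:
# correlate({'a.example': {'cert_issuer_name': 'R3'},
#            'b.example': {'cert_issuer_name': 'R3'}}) -> [('R3', ['a.example', 'b.example'])]
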
@ -2,38 +2,17 @@

import json
import re
import psycopg2
from pathlib import Path
from typing import List, Dict, Any, Set, Optional
from urllib.parse import quote
from datetime import datetime, timezone
import requests
from psycopg2 import pool

from .base_provider import BaseProvider
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_domain
from core.logger import get_forensic_logger

# --- Global Instance for PostgreSQL Connection Pool ---
# This pool will be created once per worker process and is not part of the
# CrtShProvider instance, thus avoiding pickling errors.
db_pool = None
try:
    db_pool = psycopg2.pool.SimpleConnectionPool(
        1, 5,
        host='crt.sh',
        port=5432,
        user='guest',
        dbname='certwatch',
        sslmode='prefer',
        connect_timeout=60
    )
    # Use a generic logger here as this is at the module level
    get_forensic_logger().logger.info("crt.sh: Global PostgreSQL connection pool created successfully.")
except Exception as e:
    get_forensic_logger().logger.warning(f"crt.sh: Failed to create global DB connection pool: {e}. Will fall back to HTTP API.")


class CrtShProvider(BaseProvider):
    """
@ -136,51 +115,42 @@ class CrtShProvider(BaseProvider):

        result = ProviderResult()

        try:
            if cache_status == "fresh":
                result = self._load_from_cache(cache_file)
                self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
        if cache_status == "fresh":
            result = self._load_from_cache(cache_file)
            self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")

        else:  # "stale" or "not_found"
            # Query the API for the latest certificates
            new_raw_certs = self._query_crtsh_api(domain)

            else:  # "stale" or "not_found"
                # Query the API for the latest certificates
                new_raw_certs = self._query_crtsh(domain)
            if self._stop_event and self._stop_event.is_set():
                return ProviderResult()

            # Combine with old data if cache is stale
            if cache_status == "stale":
                old_raw_certs = self._load_raw_data_from_cache(cache_file)
                combined_certs = old_raw_certs + new_raw_certs

                if self._stop_event and self._stop_event.is_set():
                    return ProviderResult()

                # Combine with old data if cache is stale
                if cache_status == "stale":
                    old_raw_certs = self._load_raw_data_from_cache(cache_file)
                    combined_certs = old_raw_certs + new_raw_certs

                # Deduplicate the combined list
                seen_ids = set()
                unique_certs = []
                for cert in combined_certs:
                    cert_id = cert.get('id')
                    if cert_id not in seen_ids:
                        unique_certs.append(cert)
                        seen_ids.add(cert_id)

                raw_certificates_to_process = unique_certs
                self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
            else:  # "not_found"
                raw_certificates_to_process = new_raw_certs
                # Deduplicate the combined list
                seen_ids = set()
                unique_certs = []
                for cert in combined_certs:
                    cert_id = cert.get('id')
                    if cert_id not in seen_ids:
                        unique_certs.append(cert)
                        seen_ids.add(cert_id)

            # FIXED: Process certificates to create proper domain and CA nodes
            result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
            self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
                raw_certificates_to_process = unique_certs
                self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
            else:  # "not_found"
                raw_certificates_to_process = new_raw_certs

            # FIXED: Process certificates to create proper domain and CA nodes
            result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
            self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")

            # Save the new result and the raw data to the cache
            self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)

        except (requests.exceptions.RequestException, psycopg2.Error) as e:
            self.logger.logger.error(f"Upstream query failed for {domain}: {e}")
            if cache_status != "not_found":
                result = self._load_from_cache(cache_file)
                self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.")
            else:
                raise e  # Re-raise if there's no cache to fall back on
        # Save the new result and the raw data to the cache
        self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)

        return result

@ -277,58 +247,6 @@ class CrtShProvider:
            json.dump(cache_data, f, separators=(',', ':'), default=str)
        except Exception as e:
            self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}")

    def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
        """Query crt.sh, trying the database first and falling back to the API."""
        global db_pool
        if db_pool:
            try:
                self.logger.logger.info(f"crt.sh: Attempting DB query for {domain}")
                return self._query_crtsh_db(domain)
            except psycopg2.Error as e:
                self.logger.logger.warning(f"crt.sh: DB query failed for {domain}: {e}. Falling back to HTTP API.")
                return self._query_crtsh_api(domain)
        else:
            self.logger.logger.info(f"crt.sh: No DB connection pool. Using HTTP API for {domain}")
            return self._query_crtsh_api(domain)

    def _query_crtsh_db(self, domain: str) -> List[Dict[str, Any]]:
        """Query crt.sh database for raw certificate data."""
        global db_pool
        conn = db_pool.getconn()
        try:
            with conn.cursor() as cursor:
                query = """
                    SELECT
                        c.id,
                        x509_serialnumber(c.certificate) as serial_number,
                        x509_notbefore(c.certificate) as not_before,
                        x509_notafter(c.certificate) as not_after,
                        c.issuer_ca_id,
                        ca.name as issuer_name,
                        x509_commonname(c.certificate) as common_name,
                        identities(c.certificate)::text as name_value
                    FROM certificate c
                    LEFT JOIN ca ON c.issuer_ca_id = ca.id
                    WHERE identities(c.certificate) @@ plainto_tsquery(%s)
                    ORDER BY c.id DESC
                    LIMIT 5000;
                """
                cursor.execute(query, (domain,))

                results = []
                columns = [desc[0] for desc in cursor.description]
                for row in cursor.fetchall():
                    row_dict = dict(zip(columns, row))
                    if row_dict.get('not_before'):
                        row_dict['not_before'] = row_dict['not_before'].isoformat()
                    if row_dict.get('not_after'):
                        row_dict['not_after'] = row_dict['not_after'].isoformat()
                    results.append(row_dict)
                self.logger.logger.info(f"crt.sh: DB query for {domain} returned {len(results)} records.")
                return results
        finally:
            db_pool.putconn(conn)

    def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]:
        """Query crt.sh API for raw certificate data."""

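# Illustrative sketch (not part of the commit): the stale-cache path above merges
# the cached and freshly fetched certificate lists and deduplicates by certificate
# id. The same merge as a standalone helper (record shape taken from the code above):
def merge_cert_lists(cached: list, fresh: list) -> list:
    """Combine two crt.sh result lists, keeping the first record seen per id."""
    seen_ids = set()
    unique = []
    for cert in cached + fresh:
        cert_id = cert.get('id')
        if cert_id not in seen_ids:
            unique.append(cert)
            seen_ids.add(cert_id)
    return unique

# merge_cert_lists([{'id': 1}], [{'id': 1}, {'id': 2}]) -> [{'id': 1}, {'id': 2}]
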
@ -11,6 +11,7 @@ class DNSProvider(BaseProvider):
    """
    Provider for standard DNS resolution and reverse DNS lookups.
    Now returns standardized ProviderResult objects with IPv4 and IPv6 support.
    FIXED: Enhanced pickle support to prevent resolver serialization issues.
    """

    def __init__(self, name=None, session_config=None):
@ -27,6 +28,22 @@ class DNSProvider(BaseProvider):
        self.resolver.timeout = 5
        self.resolver.lifetime = 10

    def __getstate__(self):
        """Prepare the object for pickling by excluding resolver."""
        state = super().__getstate__()
        # Remove the unpickleable 'resolver' attribute
        if 'resolver' in state:
            del state['resolver']
        return state

    def __setstate__(self, state):
        """Restore the object after unpickling by reconstructing resolver."""
        super().__setstate__(state)
        # Re-initialize the 'resolver' attribute
        self.resolver = resolver.Resolver()
        self.resolver.timeout = 5
        self.resolver.lifetime = 10

    def get_name(self) -> str:
        """Return the provider name."""
        return "dns"
@ -106,10 +123,10 @@ class DNSProvider(BaseProvider):
        if _is_valid_domain(hostname):
            # Determine appropriate forward relationship type based on IP version
            if ip_version == 6:
                relationship_type = 'dns_aaaa_record'
                relationship_type = 'shodan_aaaa_record'
                record_prefix = 'AAAA'
            else:
                relationship_type = 'dns_a_record'
                relationship_type = 'shodan_a_record'
                record_prefix = 'A'

            # Add the relationship

@ -36,6 +36,15 @@ class ShodanProvider(BaseProvider):
        self.cache_dir = Path('cache') / 'shodan'
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def __getstate__(self):
        """Prepare the object for pickling."""
        state = super().__getstate__()
        return state

    def __setstate__(self, state):
        """Restore the object after unpickling."""
        super().__setstate__(state)

    def _check_api_connection(self) -> bool:
        """
        FIXED: Lazy connection checking - only test when actually needed.

@ -8,4 +8,6 @@ dnspython
gunicorn
redis
python-dotenv
psycopg2-binary
psycopg2-binary
Flask-SocketIO
eventlet

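# Illustrative sketch (not part of the commit): Flask-SocketIO prefers eventlet as
# its async server when it is installed, which is presumably why both packages are
# added together here. Serving the app then goes through socketio.run() rather than
# app.run(); a minimal launcher under that assumption:
#
#     from app import app, socketio
#     socketio.run(app, host='127.0.0.1', port=5000)
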
@ -1,15 +1,15 @@
|
||||
/**
|
||||
* Main application logic for DNSRecon web interface
|
||||
* Handles UI interactions, API communication, and data flow
|
||||
* UPDATED: Now compatible with a strictly flat, unified data model for attributes.
|
||||
* FIXED: Enhanced real-time WebSocket graph updates
|
||||
*/
|
||||
|
||||
class DNSReconApp {
|
||||
constructor() {
|
||||
console.log('DNSReconApp constructor called');
|
||||
this.graphManager = null;
|
||||
this.socket = null;
|
||||
this.scanStatus = 'idle';
|
||||
this.pollInterval = null;
|
||||
this.currentSessionId = null;
|
||||
|
||||
this.elements = {};
|
||||
@ -17,6 +17,14 @@ class DNSReconApp {
|
||||
this.isScanning = false;
|
||||
this.lastGraphUpdate = null;
|
||||
|
||||
// FIXED: Add connection state tracking
|
||||
this.isConnected = false;
|
||||
this.reconnectAttempts = 0;
|
||||
this.maxReconnectAttempts = 5;
|
||||
|
||||
// FIXED: Track last graph data for debugging
|
||||
this.lastGraphData = null;
|
||||
|
||||
this.init();
|
||||
}
|
||||
|
||||
@ -31,13 +39,11 @@ class DNSReconApp {
|
||||
this.initializeElements();
|
||||
this.setupEventHandlers();
|
||||
this.initializeGraph();
|
||||
this.updateStatus();
|
||||
this.initializeSocket();
|
||||
this.loadProviders();
|
||||
this.initializeEnhancedModals();
|
||||
this.addCheckboxStyling();
|
||||
|
||||
this.updateGraph();
|
||||
|
||||
console.log('DNSRecon application initialized successfully');
|
||||
} catch (error) {
|
||||
console.error('Failed to initialize DNSRecon application:', error);
|
||||
@ -45,6 +51,162 @@ class DNSReconApp {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
initializeSocket() {
|
||||
console.log('🔌 Initializing WebSocket connection...');
|
||||
|
||||
try {
|
||||
this.socket = io({
|
||||
transports: ['websocket', 'polling'],
|
||||
timeout: 10000,
|
||||
reconnection: true,
|
||||
reconnectionAttempts: 5,
|
||||
reconnectionDelay: 2000
|
||||
});
|
||||
|
||||
this.socket.on('connect', () => {
|
||||
console.log('✅ WebSocket connected successfully');
|
||||
this.isConnected = true;
|
||||
this.reconnectAttempts = 0;
|
||||
this.updateConnectionStatus('idle');
|
||||
|
||||
console.log('📡 Requesting initial status...');
|
||||
this.socket.emit('get_status');
|
||||
});
|
||||
|
||||
this.socket.on('disconnect', (reason) => {
|
||||
console.log('❌ WebSocket disconnected:', reason);
|
||||
this.isConnected = false;
|
||||
this.updateConnectionStatus('error');
|
||||
});
|
||||
|
||||
this.socket.on('connect_error', (error) => {
|
||||
console.error('❌ WebSocket connection error:', error);
|
||||
this.reconnectAttempts++;
|
||||
this.updateConnectionStatus('error');
|
||||
|
||||
if (this.reconnectAttempts >= 5) {
|
||||
this.showError('WebSocket connection failed. Please refresh the page.');
|
||||
}
|
||||
});
|
||||
|
||||
this.socket.on('reconnect', (attemptNumber) => {
|
||||
console.log('✅ WebSocket reconnected after', attemptNumber, 'attempts');
|
||||
this.isConnected = true;
|
||||
this.reconnectAttempts = 0;
|
||||
this.updateConnectionStatus('idle');
|
||||
this.socket.emit('get_status');
|
||||
});
|
||||
|
||||
// FIXED: Enhanced scan_update handler with detailed graph processing and debugging
|
||||
this.socket.on('scan_update', (data) => {
|
||||
console.log('📨 WebSocket update received:', {
|
||||
status: data.status,
|
||||
target: data.target_domain,
|
||||
progress: data.progress_percentage,
|
||||
graphNodes: data.graph?.nodes?.length || 0,
|
||||
graphEdges: data.graph?.edges?.length || 0,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
|
||||
try {
|
||||
// Handle status change
|
||||
if (data.status !== this.scanStatus) {
|
||||
console.log(`📄 Status change: ${this.scanStatus} → ${data.status}`);
|
||||
this.handleStatusChange(data.status, data.task_queue_size);
|
||||
}
|
||||
this.scanStatus = data.status;
|
||||
|
||||
// Update status display
|
||||
this.updateStatusDisplay(data);
|
||||
|
||||
// FIXED: Always update graph if data is present and graph manager exists
|
||||
if (data.graph && this.graphManager) {
|
||||
console.log('📊 Processing graph update:', {
|
||||
nodes: data.graph.nodes?.length || 0,
|
||||
edges: data.graph.edges?.length || 0,
|
||||
hasNodes: Array.isArray(data.graph.nodes),
|
||||
hasEdges: Array.isArray(data.graph.edges),
|
||||
isInitialized: this.graphManager.isInitialized
|
||||
});
|
||||
|
||||
// FIXED: Initialize graph manager if not already done
|
||||
if (!this.graphManager.isInitialized) {
|
||||
console.log('🎯 Initializing graph manager...');
|
||||
this.graphManager.initialize();
|
||||
}
|
||||
|
||||
// FIXED: Force graph update and verify it worked
|
||||
const previousNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0;
|
||||
const previousEdgeCount = this.graphManager.edges ? this.graphManager.edges.length : 0;
|
||||
|
||||
console.log('🔄 Before update - Nodes:', previousNodeCount, 'Edges:', previousEdgeCount);
|
||||
|
||||
// Store the data for debugging
|
||||
this.lastGraphData = data.graph;
|
||||
|
||||
// Update the graph
|
||||
this.graphManager.updateGraph(data.graph);
|
||||
this.lastGraphUpdate = Date.now();
|
||||
|
||||
// Verify the update worked
|
||||
const newNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0;
|
||||
const newEdgeCount = this.graphManager.edges ? this.graphManager.edges.length : 0;
|
||||
|
||||
console.log('🔄 After update - Nodes:', newNodeCount, 'Edges:', newEdgeCount);
|
||||
|
||||
if (newNodeCount !== data.graph.nodes.length || newEdgeCount !== data.graph.edges.length) {
|
||||
console.warn('⚠️ Graph update mismatch!', {
|
||||
expectedNodes: data.graph.nodes.length,
|
||||
actualNodes: newNodeCount,
|
||||
expectedEdges: data.graph.edges.length,
|
||||
actualEdges: newEdgeCount
|
||||
});
|
||||
|
||||
// Force a complete rebuild if there's a mismatch
|
||||
console.log('🔧 Force rebuilding graph...');
|
||||
this.graphManager.clear();
|
||||
this.graphManager.updateGraph(data.graph);
|
||||
}
|
||||
|
||||
console.log('✅ Graph updated successfully');
|
||||
|
||||
// FIXED: Force network redraw if we're using vis.js
|
||||
if (this.graphManager.network) {
|
||||
try {
|
||||
this.graphManager.network.redraw();
|
||||
console.log('🎨 Network redrawn');
|
||||
} catch (redrawError) {
|
||||
console.warn('⚠️ Network redraw failed:', redrawError);
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
if (!data.graph) {
|
||||
console.log('⚠️ No graph data in WebSocket update');
|
||||
}
|
||||
if (!this.graphManager) {
|
||||
console.log('⚠️ Graph manager not available');
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('❌ Error processing WebSocket update:', error);
|
||||
console.error('Update data:', data);
|
||||
console.error('Stack trace:', error.stack);
|
||||
}
|
||||
});
|
||||
|
||||
this.socket.on('error', (error) => {
|
||||
console.error('❌ WebSocket error:', error);
|
||||
this.showError('WebSocket communication error');
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
console.error('❌ Failed to initialize WebSocket:', error);
|
||||
this.showError('Failed to establish real-time connection');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize DOM element references
|
||||
@ -263,12 +425,36 @@ class DNSReconApp {
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize graph visualization
|
||||
* FIXED: Initialize graph visualization with enhanced debugging
|
||||
*/
|
||||
initializeGraph() {
|
||||
try {
|
||||
console.log('Initializing graph manager...');
|
||||
this.graphManager = new GraphManager('network-graph');
|
||||
|
||||
// FIXED: Add debugging hooks to graph manager
|
||||
if (this.graphManager) {
|
||||
// Override updateGraph to add debugging
|
||||
const originalUpdateGraph = this.graphManager.updateGraph.bind(this.graphManager);
|
||||
this.graphManager.updateGraph = (graphData) => {
|
||||
console.log('🔧 GraphManager.updateGraph called with:', {
|
||||
nodes: graphData?.nodes?.length || 0,
|
||||
edges: graphData?.edges?.length || 0,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
|
||||
const result = originalUpdateGraph(graphData);
|
||||
|
||||
console.log('🔧 GraphManager.updateGraph completed, network state:', {
|
||||
networkExists: !!this.graphManager.network,
|
||||
nodeDataSetLength: this.graphManager.nodes?.length || 0,
|
||||
edgeDataSetLength: this.graphManager.edges?.length || 0
|
||||
});
|
||||
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
console.log('Graph manager initialized successfully');
|
||||
} catch (error) {
|
||||
console.error('Failed to initialize graph manager:', error);
|
||||
@ -288,7 +474,6 @@ class DNSReconApp {
|
||||
|
||||
console.log(`Target: "${target}", Max depth: ${maxDepth}`);
|
||||
|
||||
// Validation
|
||||
if (!target) {
|
||||
console.log('Validation failed: empty target');
|
||||
this.showError('Please enter a target domain or IP');
|
||||
@ -303,6 +488,19 @@ class DNSReconApp {
|
||||
return;
|
||||
}
|
||||
|
||||
// FIXED: Ensure WebSocket connection before starting scan
|
||||
if (!this.isConnected) {
|
||||
console.log('WebSocket not connected, attempting to connect...');
|
||||
this.socket.connect();
|
||||
|
||||
// Wait a moment for connection
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
|
||||
if (!this.isConnected) {
|
||||
this.showWarning('WebSocket connection not established. Updates may be delayed.');
|
||||
}
|
||||
}
|
||||
|
||||
console.log('Validation passed, setting UI state to scanning...');
|
||||
this.setUIState('scanning');
|
||||
this.showInfo('Starting reconnaissance scan...');
|
||||
@ -320,23 +518,28 @@ class DNSReconApp {
|
||||
|
||||
if (response.success) {
|
||||
this.currentSessionId = response.scan_id;
|
||||
this.showSuccess('Reconnaissance scan started successfully');
|
||||
this.showSuccess('Reconnaissance scan started - watching for real-time updates');
|
||||
|
||||
if (clearGraph) {
|
||||
if (clearGraph && this.graphManager) {
|
||||
console.log('🧹 Clearing graph for new scan');
|
||||
this.graphManager.clear();
|
||||
}
|
||||
|
||||
console.log(`Scan started for ${target} with depth ${maxDepth}`);
|
||||
console.log(`✅ Scan started for ${target} with depth ${maxDepth}`);
|
||||
|
||||
// Start polling immediately with faster interval for responsiveness
|
||||
this.startPolling(1000);
|
||||
|
||||
// Force an immediate status update
|
||||
console.log('Forcing immediate status update...');
|
||||
setTimeout(() => {
|
||||
this.updateStatus();
|
||||
this.updateGraph();
|
||||
}, 100);
|
||||
// FIXED: Immediately start listening for updates
|
||||
if (this.socket && this.isConnected) {
|
||||
console.log('📡 Requesting initial status update...');
|
||||
this.socket.emit('get_status');
|
||||
|
||||
// Set up periodic status requests as backup (every 5 seconds during scan)
|
||||
/*this.statusRequestInterval = setInterval(() => {
|
||||
if (this.isScanning && this.socket && this.isConnected) {
|
||||
console.log('📡 Periodic status request...');
|
||||
this.socket.emit('get_status');
|
||||
}
|
||||
}, 5000);*/
|
||||
}
|
||||
|
||||
} else {
|
||||
throw new Error(response.error || 'Failed to start scan');
|
||||
@ -348,20 +551,23 @@ class DNSReconApp {
|
||||
this.setUIState('idle');
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Scan stop with immediate UI feedback
|
||||
*/
|
||||
|
||||
// FIXED: Enhanced stop scan with interval cleanup
|
||||
async stopScan() {
|
||||
try {
|
||||
console.log('Stopping scan...');
|
||||
|
||||
// Immediately disable stop button and show stopping state
|
||||
// Clear status request interval
|
||||
/*if (this.statusRequestInterval) {
|
||||
clearInterval(this.statusRequestInterval);
|
||||
this.statusRequestInterval = null;
|
||||
}*/
|
||||
|
||||
if (this.elements.stopScan) {
|
||||
this.elements.stopScan.disabled = true;
|
||||
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOPPING]</span><span>Stopping...</span>';
|
||||
}
|
||||
|
||||
// Show immediate feedback
|
||||
this.showInfo('Stopping scan...');
|
||||
|
||||
const response = await this.apiCall('/api/scan/stop', 'POST');
|
||||
@ -369,21 +575,10 @@ class DNSReconApp {
|
||||
if (response.success) {
|
||||
this.showSuccess('Scan stop requested');
|
||||
|
||||
// Force immediate status update
|
||||
setTimeout(() => {
|
||||
this.updateStatus();
|
||||
}, 100);
|
||||
|
||||
// Continue polling for a bit to catch the status change
|
||||
this.startPolling(500); // Fast polling to catch status change
|
||||
|
||||
// Stop fast polling after 10 seconds
|
||||
setTimeout(() => {
|
||||
if (this.scanStatus === 'stopped' || this.scanStatus === 'idle') {
|
||||
this.stopPolling();
|
||||
}
|
||||
}, 10000);
|
||||
|
||||
// Request final status update
|
||||
if (this.socket && this.isConnected) {
|
||||
setTimeout(() => this.socket.emit('get_status'), 500);
|
||||
}
|
||||
} else {
|
||||
throw new Error(response.error || 'Failed to stop scan');
|
||||
}
|
||||
@ -392,7 +587,6 @@ class DNSReconApp {
|
||||
console.error('Failed to stop scan:', error);
|
||||
this.showError(`Failed to stop scan: ${error.message}`);
|
||||
|
||||
// Re-enable stop button on error
|
||||
if (this.elements.stopScan) {
|
||||
this.elements.stopScan.disabled = false;
|
||||
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOP]</span><span>Terminate Scan</span>';
|
||||
@ -549,85 +743,24 @@ class DNSReconApp {
|
||||
}
|
||||
|
||||
/**
|
||||
* Start polling for scan updates with configurable interval
|
||||
*/
|
||||
startPolling(interval = 2000) {
|
||||
console.log('=== STARTING POLLING ===');
|
||||
|
||||
if (this.pollInterval) {
|
||||
console.log('Clearing existing poll interval');
|
||||
clearInterval(this.pollInterval);
|
||||
}
|
||||
|
||||
this.pollInterval = setInterval(() => {
|
||||
this.updateStatus();
|
||||
this.updateGraph();
|
||||
this.loadProviders();
|
||||
}, interval);
|
||||
|
||||
console.log(`Polling started with ${interval}ms interval`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop polling for updates
|
||||
*/
|
||||
stopPolling() {
|
||||
console.log('=== STOPPING POLLING ===');
|
||||
if (this.pollInterval) {
|
||||
clearInterval(this.pollInterval);
|
||||
this.pollInterval = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Status update with better error handling
|
||||
*/
|
||||
async updateStatus() {
|
||||
try {
|
||||
const response = await this.apiCall('/api/scan/status');
|
||||
|
||||
|
||||
if (response.success && response.status) {
|
||||
const status = response.status;
|
||||
|
||||
this.updateStatusDisplay(status);
|
||||
|
||||
// Handle status changes
|
||||
if (status.status !== this.scanStatus) {
|
||||
console.log(`*** STATUS CHANGED: ${this.scanStatus} -> ${status.status} ***`);
|
||||
this.handleStatusChange(status.status, status.task_queue_size);
|
||||
}
|
||||
|
||||
this.scanStatus = status.status;
|
||||
} else {
|
||||
console.error('Status update failed:', response);
|
||||
// Don't show error for status updates to avoid spam
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Failed to update status:', error);
|
||||
this.showConnectionError();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update graph from server
|
||||
* FIXED: Update graph from server with enhanced debugging
|
||||
*/
|
||||
async updateGraph() {
|
||||
try {
|
||||
console.log('Updating graph...');
|
||||
console.log('Updating graph via API call...');
|
||||
const response = await this.apiCall('/api/graph');
|
||||
|
||||
|
||||
if (response.success) {
|
||||
const graphData = response.graph;
|
||||
|
||||
console.log('Graph data received:');
|
||||
console.log('Graph data received from API:');
|
||||
console.log('- Nodes:', graphData.nodes ? graphData.nodes.length : 0);
|
||||
console.log('- Edges:', graphData.edges ? graphData.edges.length : 0);
|
||||
|
||||
// FIXED: Always update graph, even if empty - let GraphManager handle placeholder
|
||||
if (this.graphManager) {
|
||||
console.log('🔧 Calling GraphManager.updateGraph from API response...');
|
||||
this.graphManager.updateGraph(graphData);
|
||||
this.lastGraphUpdate = Date.now();
|
||||
|
||||
@ -636,6 +769,8 @@ class DNSReconApp {
|
||||
if (this.elements.relationshipsDisplay) {
|
||||
this.elements.relationshipsDisplay.textContent = edgeCount;
|
||||
}
|
||||
|
||||
console.log('✅ Manual graph update completed');
|
||||
}
|
||||
} else {
|
||||
console.error('Graph update failed:', response);

@@ -731,48 +866,70 @@ class DNSReconApp {
     * @param {string} newStatus - New scan status
     */
    handleStatusChange(newStatus, task_queue_size) {
-       console.log(`=== STATUS CHANGE: ${this.scanStatus} -> ${newStatus} ===`);
+       console.log(`📄 Status change handler: ${this.scanStatus} → ${newStatus}`);

        switch (newStatus) {
            case 'running':
                this.setUIState('scanning', task_queue_size);
-               this.showSuccess('Scan is running');
-               // Increase polling frequency for active scans
-               this.startPolling(1000); // Poll every 1 second for running scans
+               this.showSuccess('Scan is running - updates in real-time');
+               this.updateConnectionStatus('active');
                break;

            case 'completed':
                this.setUIState('completed', task_queue_size);
                this.stopPolling();
                this.showSuccess('Scan completed successfully');
                this.updateConnectionStatus('completed');
                this.loadProviders();
-               // Force a final graph update
-               console.log('Scan completed - forcing final graph update');
-               setTimeout(() => this.updateGraph(), 100);
+               console.log('✅ Scan completed - requesting final graph update');
+               // Request final status to ensure we have the complete graph
+               setTimeout(() => {
+                   if (this.socket && this.isConnected) {
+                       this.socket.emit('get_status');
+                   }
+               }, 1000);

                // Clear status request interval
                /*if (this.statusRequestInterval) {
                    clearInterval(this.statusRequestInterval);
                    this.statusRequestInterval = null;
                }*/
                break;

            case 'failed':
                this.setUIState('failed', task_queue_size);
                this.stopPolling();
                this.showError('Scan failed');
                this.updateConnectionStatus('error');
                this.loadProviders();

                // Clear status request interval
                /*if (this.statusRequestInterval) {
                    clearInterval(this.statusRequestInterval);
                    this.statusRequestInterval = null;
                }*/
                break;

            case 'stopped':
                this.setUIState('stopped', task_queue_size);
                this.stopPolling();
                this.showSuccess('Scan stopped');
                this.updateConnectionStatus('stopped');
                this.loadProviders();

                // Clear status request interval
                if (this.statusRequestInterval) {
                    clearInterval(this.statusRequestInterval);
                    this.statusRequestInterval = null;
                }
                break;

            case 'idle':
                this.setUIState('idle', task_queue_size);
                this.stopPolling();
                this.updateConnectionStatus('idle');

                // Clear status request interval
                /*if (this.statusRequestInterval) {
                    clearInterval(this.statusRequestInterval);
                    this.statusRequestInterval = null;
                }*/
                break;

            default:
@@ -824,6 +981,7 @@ class DNSReconApp {
        if (this.graphManager) {
            this.graphManager.isScanning = true;
        }

        if (this.elements.startScan) {
            this.elements.startScan.disabled = true;
            this.elements.startScan.classList.add('loading');
@@ -851,6 +1009,7 @@ class DNSReconApp {
        if (this.graphManager) {
            this.graphManager.isScanning = false;
        }

        if (this.elements.startScan) {
            this.elements.startScan.disabled = !isQueueEmpty;
            this.elements.startScan.classList.remove('loading');
@@ -1093,7 +1252,7 @@ class DNSReconApp {
        } else {
            // API key not configured - ALWAYS show input field
            const statusClass = info.enabled ? 'enabled' : 'api-key-required';
-           const statusText = info.enabled ? '○ Ready for API Key' : '⚠️ API Key Required';
+           const statusText = info.enabled ? '◯ Ready for API Key' : '⚠️ API Key Required';

            inputGroup.innerHTML = `
                <div class="provider-header">
@@ -2033,10 +2192,10 @@ class DNSReconApp {

            // If the scanner was idle, it's now running. Start polling to see the new node appear.
            if (this.scanStatus === 'idle') {
-               this.startPolling(1000);
+               this.socket.emit('get_status');
            } else {
                // If already scanning, force a quick graph update to see the change sooner.
-               setTimeout(() => this.updateGraph(), 500);
+               setTimeout(() => this.socket.emit('get_status'), 500);
            }

        } else {
@@ -2075,8 +2234,8 @@ class DNSReconApp {
     */
    getNodeTypeIcon(nodeType) {
        const icons = {
-           'domain': '🌍',
-           'ip': '📍',
+           'domain': '🌐',
+           'ip': '🔢',
            'asn': '🏢',
            'large_entity': '📦',
            'correlation_object': '🔗'
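
The hunk ends before getNodeTypeIcon() closes. As a hedged sketch, a lookup map like the one above is typically finished along these lines (the fallback icon is an assumption, not the branch's code):

    // Sketch only - the real function tail is not shown in this diff.
    getNodeTypeIconSketch(nodeType) {
        const icons = { 'domain': '🌐', 'ip': '🔢', 'asn': '🏢' };
        return icons[nodeType] || '❓';  // assumed fallback for unknown types
    }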

@@ -7,6 +7,7 @@
    <title>DNSRecon - Infrastructure Reconnaissance</title>
    <link rel="stylesheet" href="{{ url_for('static', filename='css/main.css') }}">
    <script src="https://cdnjs.cloudflare.com/ajax/libs/vis/4.21.0/vis.min.js"></script>
+   <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.7.2/socket.io.js"></script>
    <link href="https://cdnjs.cloudflare.com/ajax/libs/vis/4.21.0/vis.min.css" rel="stylesheet" type="text/css">
    <link
        href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@300;400;500;700&family=Special+Elite&display=swap"