Compare commits

...

2 Commits

Author SHA1 Message Date
overcuriousity  c4e6a8998a  iteration on ws implementation  2025-09-20 16:52:05 +02:00
overcuriousity  75a595c9cb  try to implement websockets  2025-09-20 14:17:17 +02:00
13 changed files with 1306 additions and 507 deletions

app.py (169 changed lines)

@ -3,9 +3,9 @@
"""
Flask application entry point for DNSRecon web interface.
Provides REST API endpoints and serves the web interface with user session support.
FIXED: Enhanced WebSocket integration with proper connection management.
"""
import json
import traceback
from flask import Flask, render_template, request, jsonify, send_file, session
from datetime import datetime, timezone, timedelta
@ -13,6 +13,7 @@ import io
import os
from core.session_manager import session_manager
from flask_socketio import SocketIO
from config import config
from core.graph_manager import NodeType
from utils.helpers import is_valid_target
@ -21,29 +22,38 @@ from decimal import Decimal
app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*")
app.config['SECRET_KEY'] = config.flask_secret_key
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours)
def get_user_scanner():
"""
Retrieves the scanner for the current session, or creates a new one if none exists.
FIXED: Retrieves the scanner for the current session with proper socketio management.
"""
current_flask_session_id = session.get('dnsrecon_session_id')
if current_flask_session_id:
existing_scanner = session_manager.get_session(current_flask_session_id)
if existing_scanner:
# FIXED: Ensure socketio is properly maintained
existing_scanner.socketio = socketio
print(f"✓ Retrieved existing scanner for session {current_flask_session_id[:8]}... with socketio restored")
return current_flask_session_id, existing_scanner
new_session_id = session_manager.create_session()
# FIXED: Register socketio connection when creating new session
new_session_id = session_manager.create_session(socketio)
new_scanner = session_manager.get_session(new_session_id)
if not new_scanner:
raise Exception("Failed to create new scanner session")
# FIXED: Ensure new scanner has socketio reference and register the connection
new_scanner.socketio = socketio
session_manager.register_socketio_connection(new_session_id, socketio)
session['dnsrecon_session_id'] = new_session_id
session.permanent = True
print(f"✓ Created new scanner for session {new_session_id[:8]}... with socketio registered")
return new_session_id, new_scanner
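For illustration, any new endpoint would follow the same pattern; a hypothetical route sketch (the route name and payload are not part of the diff):

@app.route('/api/session/info', methods=['GET'])
def session_info():
    # Re-attaches socketio to the scanner as a side effect of the helper above.
    user_session_id, scanner = get_user_scanner()
    return jsonify({'success': True, 'user_session_id': user_session_id})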
@ -56,7 +66,7 @@ def index():
@app.route('/api/scan/start', methods=['POST'])
def start_scan():
"""
Starts a new reconnaissance scan.
FIXED: Starts a new reconnaissance scan with proper socketio management.
"""
try:
data = request.get_json()
@ -80,9 +90,17 @@ def start_scan():
if not scanner:
return jsonify({'success': False, 'error': 'Failed to get scanner instance.'}), 500
# FIXED: Ensure scanner has socketio reference and is registered
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
print(f"🚀 Starting scan for {target} with socketio enabled and registered")
success = scanner.start_scan(target, max_depth, clear_graph=clear_graph, force_rescan_target=force_rescan_target)
if success:
# Update session with socketio-enabled scanner
session_manager.update_session_scanner(user_session_id, scanner)
return jsonify({
'success': True,
'message': 'Reconnaissance scan started successfully',
@ -111,6 +129,10 @@ def stop_scan():
if not scanner.session_id:
scanner.session_id = user_session_id
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
scanner.stop_scan()
session_manager.set_stop_signal(user_session_id)
session_manager.update_scanner_status(user_session_id, 'stopped')
@ -127,37 +149,83 @@ def stop_scan():
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.route('/api/scan/status', methods=['GET'])
@socketio.on('connect')
def handle_connect():
"""
FIXED: Handle WebSocket connection with proper session management.
"""
print(f'✓ WebSocket client connected: {request.sid}')
# Try to restore existing session connection
current_flask_session_id = session.get('dnsrecon_session_id')
if current_flask_session_id:
# Register this socketio connection for the existing session
session_manager.register_socketio_connection(current_flask_session_id, socketio)
print(f'✓ Registered WebSocket for existing session: {current_flask_session_id[:8]}...')
# Immediately send current status to new connection
get_scan_status()
@socketio.on('disconnect')
def handle_disconnect():
"""
FIXED: Handle WebSocket disconnection gracefully.
"""
print(f'✗ WebSocket client disconnected: {request.sid}')
# Note: We don't immediately remove the socketio connection from session_manager
# because the user might reconnect. The cleanup will happen during session cleanup.
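A quick way to exercise these handlers end to end is the python-socketio client package; a minimal sketch (host and port are assumptions based on typical Flask defaults):

import socketio  # pip install "python-socketio[client]"

sio = socketio.Client()

@sio.on('scan_update')
def on_scan_update(status):
    graph = status.get('graph', {})
    print(f"status={status.get('status')} nodes={len(graph.get('nodes', []))}")

sio.connect('http://127.0.0.1:5000')  # assumed host/port
sio.emit('get_status')                # triggers the handler defined below
sio.wait()                            # block and keep receiving scan_update events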
@socketio.on('get_status')
def get_scan_status():
"""Get current scan status."""
"""
FIXED: Get current scan status and emit real-time update with proper error handling.
"""
try:
user_session_id, scanner = get_user_scanner()
if not scanner:
return jsonify({
'success': True,
'status': {
'status': 'idle', 'target_domain': None, 'current_depth': 0,
'max_depth': 0, 'progress_percentage': 0.0,
'user_session_id': user_session_id
}
})
status = {
'status': 'idle',
'target_domain': None,
'current_depth': 0,
'max_depth': 0,
'progress_percentage': 0.0,
'user_session_id': user_session_id,
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
}
print(f"📡 Emitting idle status for session {user_session_id[:8] if user_session_id else 'none'}...")
else:
if not scanner.session_id:
scanner.session_id = user_session_id
if not scanner.session_id:
scanner.session_id = user_session_id
# FIXED: Ensure scanner has socketio reference for future updates
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
status = scanner.get_scan_status()
status['user_session_id'] = user_session_id
status = scanner.get_scan_status()
status['user_session_id'] = user_session_id
return jsonify({'success': True, 'status': status})
print(f"📡 Emitting status update: {status['status']} - "
f"Nodes: {len(status.get('graph', {}).get('nodes', []))}, "
f"Edges: {len(status.get('graph', {}).get('edges', []))}")
# Update session with socketio-enabled scanner
session_manager.update_session_scanner(user_session_id, scanner)
socketio.emit('scan_update', status)
except Exception as e:
traceback.print_exc()
return jsonify({
'success': False, 'error': f'Internal server error: {str(e)}',
'fallback_status': {'status': 'error', 'progress_percentage': 0.0}
}), 500
error_status = {
'status': 'error',
'message': 'Failed to get status',
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
}
print(f"⚠️ Error getting status, emitting error status")
socketio.emit('scan_update', error_status)
@app.route('/api/graph', methods=['GET'])
@ -174,6 +242,10 @@ def get_graph_data():
if not scanner:
return jsonify({'success': True, 'graph': empty_graph, 'user_session_id': user_session_id})
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
graph_data = scanner.get_graph_data() or empty_graph
return jsonify({'success': True, 'graph': graph_data, 'user_session_id': user_session_id})
@ -200,6 +272,10 @@ def extract_from_large_entity():
if not scanner:
return jsonify({'success': False, 'error': 'No active session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
success = scanner.extract_node_from_large_entity(large_entity_id, node_id)
if success:
@ -220,6 +296,10 @@ def delete_graph_node(node_id):
if not scanner:
return jsonify({'success': False, 'error': 'No active session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
success = scanner.graph.remove_node(node_id)
if success:
@ -245,6 +325,10 @@ def revert_graph_action():
if not scanner:
return jsonify({'success': False, 'error': 'No active session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
action_type = data['type']
action_data = data['data']
@ -289,6 +373,10 @@ def export_results():
if not scanner:
return jsonify({'success': False, 'error': 'No active scanner session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Get export data using the new export manager
try:
results = export_manager.export_scan_results(scanner)
@ -340,6 +428,10 @@ def export_targets():
if not scanner:
return jsonify({'success': False, 'error': 'No active scanner session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Use export manager for targets export
targets_txt = export_manager.export_targets_list(scanner)
@ -370,6 +462,10 @@ def export_summary():
if not scanner:
return jsonify({'success': False, 'error': 'No active scanner session found'}), 404
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Use export manager for summary generation
summary_txt = export_manager.generate_executive_summary(scanner)
@ -402,6 +498,10 @@ def set_api_keys():
user_session_id, scanner = get_user_scanner()
session_config = scanner.config
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
updated_providers = []
for provider_name, api_key in data.items():
@ -434,6 +534,10 @@ def get_providers():
user_session_id, scanner = get_user_scanner()
base_provider_info = scanner.get_provider_info()
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
# Enhance provider info with API key source information
enhanced_provider_info = {}
@ -498,6 +602,10 @@ def configure_providers():
user_session_id, scanner = get_user_scanner()
session_config = scanner.config
# FIXED: Ensure scanner has socketio reference
scanner.socketio = socketio
session_manager.register_socketio_connection(user_session_id, socketio)
updated_providers = []
for provider_name, settings in data.items():
@ -526,7 +634,6 @@ def configure_providers():
return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
@app.errorhandler(404)
def not_found(error):
"""Handle 404 errors."""
@ -542,9 +649,9 @@ def internal_error(error):
if __name__ == '__main__':
config.load_from_env()
app.run(
host=config.flask_host,
port=config.flask_port,
debug=config.flask_debug,
threaded=True
)
print("🚀 Starting DNSRecon with enhanced WebSocket support...")
print(f" Host: {config.flask_host}")
print(f" Port: {config.flask_port}")
print(f" Debug: {config.flask_debug}")
print(" WebSocket: Enhanced connection management enabled")
socketio.run(app, host=config.flask_host, port=config.flask_port, debug=config.flask_debug)
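For automated tests, flask_socketio ships a test client that exercises these handlers in-process, without a running server; a sketch assuming this module is importable as app:

from app import app, socketio

client = socketio.test_client(app)
assert client.is_connected()
client.emit('get_status')
received = client.get_received()
print([event['name'] for event in received])  # expect at least one 'scan_update'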

core/graph_manager.py

@ -4,8 +4,7 @@
Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
Now fully compatible with the unified ProviderResult data model.
UPDATED: Fixed correlation exclusion keys to match actual attribute names.
UPDATED: Removed export_json() method - now handled by ExportManager.
FIXED: Added proper pickle support to prevent weakref serialization errors.
"""
import re
from datetime import datetime, timezone
@ -33,6 +32,7 @@ class GraphManager:
Thread-safe graph manager for DNSRecon infrastructure mapping.
Uses NetworkX for in-memory graph storage with confidence scoring.
Compatible with unified ProviderResult data model.
FIXED: Added proper pickle support to handle NetworkX graph serialization.
"""
def __init__(self):
@ -41,6 +41,57 @@ class GraphManager:
self.creation_time = datetime.now(timezone.utc).isoformat()
self.last_modified = self.creation_time
def __getstate__(self):
"""Prepare GraphManager for pickling by converting NetworkX graph to serializable format."""
state = self.__dict__.copy()
# Convert NetworkX graph to a serializable format
if hasattr(self, 'graph') and self.graph:
# Extract all nodes with their data
nodes_data = {}
for node_id, attrs in self.graph.nodes(data=True):
nodes_data[node_id] = dict(attrs)
# Extract all edges with their data
edges_data = []
for source, target, attrs in self.graph.edges(data=True):
edges_data.append({
'source': source,
'target': target,
'attributes': dict(attrs)
})
# Replace the NetworkX graph with serializable data
state['_graph_nodes'] = nodes_data
state['_graph_edges'] = edges_data
del state['graph']
return state
def __setstate__(self, state):
"""Restore GraphManager after unpickling by reconstructing NetworkX graph."""
# Restore basic attributes
self.__dict__.update(state)
# Reconstruct NetworkX graph from serializable data
self.graph = nx.DiGraph()
# Restore nodes
if hasattr(self, '_graph_nodes'):
for node_id, attrs in self._graph_nodes.items():
self.graph.add_node(node_id, **attrs)
del self._graph_nodes
# Restore edges
if hasattr(self, '_graph_edges'):
for edge_data in self._graph_edges:
self.graph.add_edge(
edge_data['source'],
edge_data['target'],
**edge_data['attributes']
)
del self._graph_edges
def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None,
description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:
"""

ForensicLogger

@ -40,6 +40,7 @@ class ForensicLogger:
"""
Thread-safe forensic logging system for DNSRecon.
Maintains detailed audit trail of all reconnaissance activities.
FIXED: Enhanced pickle support to prevent weakref issues in logging handlers.
"""
def __init__(self, session_id: str = ""):
@ -65,45 +66,74 @@ class ForensicLogger:
'target_domains': set()
}
# Configure standard logger
# Configure standard logger with simple setup to avoid weakrefs
self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
self.logger.setLevel(logging.INFO)
# Create formatter for structured logging
# Create minimal formatter
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Add console handler if not already present
# Add console handler only if not already present (avoid duplicate handlers)
if not self.logger.handlers:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
def __getstate__(self):
"""Prepare ForensicLogger for pickling by excluding unpicklable objects."""
"""
FIXED: Prepare ForensicLogger for pickling by excluding problematic objects.
"""
state = self.__dict__.copy()
# Remove the unpickleable 'logger' attribute
if 'logger' in state:
del state['logger']
if 'lock' in state:
del state['lock']
# Remove potentially unpickleable attributes that may contain weakrefs
unpicklable_attrs = ['logger', 'lock']
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
# Convert sets to lists for JSON serialization compatibility
if 'session_metadata' in state:
metadata = state['session_metadata'].copy()
if 'providers_used' in metadata and isinstance(metadata['providers_used'], set):
metadata['providers_used'] = list(metadata['providers_used'])
if 'target_domains' in metadata and isinstance(metadata['target_domains'], set):
metadata['target_domains'] = list(metadata['target_domains'])
state['session_metadata'] = metadata
return state
def __setstate__(self, state):
"""Restore ForensicLogger after unpickling by reconstructing logger."""
"""
FIXED: Restore ForensicLogger after unpickling by reconstructing components.
"""
self.__dict__.update(state)
# Re-initialize the 'logger' attribute
# Re-initialize threading lock
self.lock = threading.Lock()
# Re-initialize logger with minimal setup
self.logger = logging.getLogger(f'dnsrecon.{self.session_id}')
self.logger.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Only add handler if not already present
if not self.logger.handlers:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
self.lock = threading.Lock()
# Convert lists back to sets if needed
if 'session_metadata' in self.__dict__:
metadata = self.session_metadata
if 'providers_used' in metadata and isinstance(metadata['providers_used'], list):
metadata['providers_used'] = set(metadata['providers_used'])
if 'target_domains' in metadata and isinstance(metadata['target_domains'], list):
metadata['target_domains'] = set(metadata['target_domains'])
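The logger supports the same round trip, with sets restored and the handler rebuilt; a sketch assuming the constructor and metadata fields shown above:

import pickle

original = ForensicLogger(session_id='roundtrip-test')
original.session_metadata['providers_used'].add('dns')

clone = pickle.loads(pickle.dumps(original))
assert isinstance(clone.session_metadata['providers_used'], set)  # lists converted back to sets
assert clone.logger is not None and clone.lock is not None        # rebuilt in __setstate__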
def _generate_session_id(self) -> str:
"""Generate unique session identifier."""
@ -143,18 +173,23 @@ class ForensicLogger:
discovery_context=discovery_context
)
self.api_requests.append(api_request)
self.session_metadata['total_requests'] += 1
self.session_metadata['providers_used'].add(provider)
with self.lock:
self.api_requests.append(api_request)
self.session_metadata['total_requests'] += 1
self.session_metadata['providers_used'].add(provider)
if target_indicator:
self.session_metadata['target_domains'].add(target_indicator)
if target_indicator:
self.session_metadata['target_domains'].add(target_indicator)
# Log to standard logger
if error:
self.logger.error(f"API Request Failed.")
else:
self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
# Log to standard logger with error handling
try:
if error:
self.logger.error(f"API Request Failed - {provider}: {url}")
else:
self.logger.info(f"API Request - {provider}: {url} - Status: {status_code}")
except Exception:
# If logging fails, continue without breaking the application
pass
def log_relationship_discovery(self, source_node: str, target_node: str,
relationship_type: str, confidence_score: float,
@ -183,29 +218,44 @@ class ForensicLogger:
discovery_method=discovery_method
)
self.relationships.append(relationship)
self.session_metadata['total_relationships'] += 1
with self.lock:
self.relationships.append(relationship)
self.session_metadata['total_relationships'] += 1
self.logger.info(
f"Relationship Discovered - {source_node} -> {target_node} "
f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
)
# Log to standard logger with error handling
try:
self.logger.info(
f"Relationship Discovered - {source_node} -> {target_node} "
f"({relationship_type}) - Confidence: {confidence_score:.2f} - Provider: {provider}"
)
except Exception:
# If logging fails, continue without breaking the application
pass
def log_scan_start(self, target_domain: str, recursion_depth: int,
enabled_providers: List[str]) -> None:
"""Log the start of a reconnaissance scan."""
self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
try:
self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
self.session_metadata['target_domains'].update(target_domain)
with self.lock:
self.session_metadata['target_domains'].add(target_domain)
except Exception:
pass
def log_scan_complete(self) -> None:
"""Log the completion of a reconnaissance scan."""
self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
with self.lock:
self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
# Convert sets to lists for serialization
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
self.logger.info(f"Scan Complete - Session: {self.session_id}")
try:
self.logger.info(f"Scan Complete - Session: {self.session_id}")
except Exception:
pass
def export_audit_trail(self) -> Dict[str, Any]:
"""
@ -214,12 +264,13 @@ class ForensicLogger:
Returns:
Dictionary containing complete session audit trail
"""
return {
'session_metadata': self.session_metadata.copy(),
'api_requests': [asdict(req) for req in self.api_requests],
'relationships': [asdict(rel) for rel in self.relationships],
'export_timestamp': datetime.now(timezone.utc).isoformat()
}
with self.lock:
return {
'session_metadata': self.session_metadata.copy(),
'api_requests': [asdict(req) for req in self.api_requests],
'relationships': [asdict(rel) for rel in self.relationships],
'export_timestamp': datetime.now(timezone.utc).isoformat()
}
def get_forensic_summary(self) -> Dict[str, Any]:
"""
@ -229,7 +280,13 @@ class ForensicLogger:
Dictionary containing summary statistics
"""
provider_stats = {}
for provider in self.session_metadata['providers_used']:
# Ensure providers_used is a set for iteration
providers_used = self.session_metadata['providers_used']
if isinstance(providers_used, list):
providers_used = set(providers_used)
for provider in providers_used:
provider_requests = [req for req in self.api_requests if req.provider == provider]
provider_relationships = [rel for rel in self.relationships if rel.provider == provider]

core/scanner.py

@ -6,7 +6,6 @@ import os
import importlib
import redis
import time
import math
import random # Imported for jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor
@ -36,9 +35,10 @@ class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
UNIFIED: Combines comprehensive features with improved display formatting.
FIXED: Enhanced threading object initialization to prevent None references.
"""
def __init__(self, session_config=None):
def __init__(self, session_config=None, socketio=None):
"""Initialize scanner with session-specific configuration."""
try:
# Use provided session config or create default
@ -46,6 +46,12 @@ class Scanner:
from core.session_config import create_session_config
session_config = create_session_config()
# FIXED: Initialize all threading objects first
self._initialize_threading_objects()
# Keep the live socketio reference for real-time updates (stripped to None before storage)
self.socketio = socketio
self.config = session_config
self.graph = GraphManager()
self.providers = []
@ -53,17 +59,12 @@ class Scanner:
self.current_target = None
self.current_depth = 0
self.max_depth = 2
self.stop_event = threading.Event()
self.scan_thread = None
self.session_id: Optional[str] = None # Will be set by session manager
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
self.initial_targets = set()
# Thread-safe processing tracking (from Document 1)
self.currently_processing = set()
self.processing_lock = threading.Lock()
# Display-friendly processing list (from Document 2)
self.currently_processing_display = []
@ -81,9 +82,10 @@ class Scanner:
self.max_workers = self.config.max_concurrent_requests
self.executor = None
# Status logger thread with improved formatting
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
# Initialize collections that will be recreated during unpickling
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
# Initialize providers with session config
self._initialize_providers()
@ -99,12 +101,24 @@ class Scanner:
traceback.print_exc()
raise
def _initialize_threading_objects(self):
"""
FIXED: Initialize all threading objects with proper error handling.
This method can be called during both __init__ and __setstate__.
"""
self.stop_event = threading.Event()
self.processing_lock = threading.Lock()
self.status_logger_stop_event = threading.Event()
self.status_logger_thread = None
def _is_stop_requested(self) -> bool:
"""
Check if stop is requested using both local and Redis-based signals.
This ensures reliable termination across process boundaries.
FIXED: Added None check for stop_event.
"""
if self.stop_event.is_set():
# FIXED: Ensure stop_event exists before checking
if hasattr(self, 'stop_event') and self.stop_event and self.stop_event.is_set():
return True
if self.session_id:
@ -112,16 +126,24 @@ class Scanner:
from core.session_manager import session_manager
return session_manager.is_stop_requested(self.session_id)
except Exception as e:
# Fall back to local event
return self.stop_event.is_set()
# Fall back to local event if it exists
if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
return self.stop_event.is_set()
# Final fallback
if hasattr(self, 'stop_event') and self.stop_event:
return self.stop_event.is_set()
return False
def _set_stop_signal(self) -> None:
"""
Set stop signal both locally and in Redis.
FIXED: Added None check for stop_event.
"""
self.stop_event.set()
# FIXED: Ensure stop_event exists before setting
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.set()
if self.session_id:
try:
@ -143,7 +165,8 @@ class Scanner:
'rate_limiter',
'logger',
'status_logger_thread',
'status_logger_stop_event'
'status_logger_stop_event',
'socketio'
]
for attr in unpicklable_attrs:
@ -161,16 +184,21 @@ class Scanner:
"""Restore object after unpickling by reconstructing threading objects."""
self.__dict__.update(state)
self.stop_event = threading.Event()
# FIXED: Ensure all threading objects are properly initialized
self._initialize_threading_objects()
# Re-initialize other objects
self.scan_thread = None
self.executor = None
self.processing_lock = threading.Lock()
self.task_queue = PriorityQueue()
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
self.logger = get_forensic_logger()
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
# FIXED: Initialize socketio as None but preserve ability to set it
if not hasattr(self, 'socketio'):
self.socketio = None
# Initialize missing attributes with defaults
if not hasattr(self, 'providers') or not self.providers:
self._initialize_providers()
@ -180,11 +208,36 @@ class Scanner:
if not hasattr(self, 'currently_processing_display'):
self.currently_processing_display = []
if not hasattr(self, 'target_retries'):
self.target_retries = defaultdict(int)
if not hasattr(self, 'scan_failed_due_to_retries'):
self.scan_failed_due_to_retries = False
if not hasattr(self, 'initial_targets'):
self.initial_targets = set()
# Ensure providers have stop events
if hasattr(self, 'providers'):
for provider in self.providers:
if hasattr(provider, 'set_stop_event'):
if hasattr(provider, 'set_stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event)
def _ensure_threading_objects_exist(self):
"""
FIXED: Utility method to ensure threading objects exist before use.
Call this before any method that might use threading objects.
"""
if not hasattr(self, 'stop_event') or self.stop_event is None:
print("WARNING: Threading objects not initialized, recreating...")
self._initialize_threading_objects()
if not hasattr(self, 'processing_lock') or self.processing_lock is None:
self.processing_lock = threading.Lock()
if not hasattr(self, 'task_queue') or self.task_queue is None:
self.task_queue = PriorityQueue()
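With these guards in place, a scanner should survive serialization with its threading objects rebuilt; a sketch (assumes a reachable local Redis, since __setstate__ reconnects the rate limiter, and that all configured providers pickle cleanly):

import pickle

scanner = Scanner()
clone = pickle.loads(pickle.dumps(scanner))

assert clone.stop_event is not None and clone.processing_lock is not None
assert clone.socketio is None  # dropped by __getstate__; re-attached by the session manager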
def _initialize_providers(self) -> None:
"""Initialize all available providers based on session configuration."""
self.providers = []
@ -222,7 +275,9 @@ class Scanner:
print(f" Available: {is_available}")
if is_available:
provider.set_stop_event(self.stop_event)
# FIXED: Ensure stop_event exists before setting it
if hasattr(self, 'stop_event') and self.stop_event:
provider.set_stop_event(self.stop_event)
if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph)
self.providers.append(provider)
@ -252,15 +307,25 @@ class Scanner:
BOLD = "\033[1m"
last_status_str = ""
while not self.status_logger_stop_event.is_set():
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
while not (hasattr(self, 'status_logger_stop_event') and
self.status_logger_stop_event and
self.status_logger_stop_event.is_set()):
try:
with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
self.currently_processing_display = in_flight_tasks.copy()
# FIXED: Check if processing_lock exists before using
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
self.currently_processing_display = in_flight_tasks.copy()
else:
in_flight_tasks = list(getattr(self, 'currently_processing', []))
status_str = (
f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
f"{CYAN}Queued: {self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0}{ENDC} | "
f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
f"Skipped: {self.tasks_skipped} | "
@ -288,22 +353,30 @@ class Scanner:
time.sleep(2)
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
"""
FIXED: Enhanced start_scan with proper threading object initialization and socketio management.
"""
# FIXED: Ensure threading objects exist before proceeding
self._ensure_threading_objects_exist()
if self.scan_thread and self.scan_thread.is_alive():
self.logger.logger.info("Stopping existing scan before starting new one")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# Clean up processing state
with self.processing_lock:
self.currently_processing.clear()
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
# Clear task queue
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
except:
break
if hasattr(self, 'task_queue') and self.task_queue:
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
except Exception:
break
# Shutdown executor
if self.executor:
@ -320,14 +393,26 @@ class Scanner:
self.logger.logger.warning("Previous scan thread did not terminate cleanly")
self.status = ScanStatus.IDLE
self.stop_event.clear()
# FIXED: Ensure stop_event exists before clearing
if hasattr(self, 'stop_event') and self.stop_event:
self.stop_event.clear()
if self.session_id:
from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id)
with self.processing_lock:
self.currently_processing.clear()
# FIXED: Restore socketio connection if missing
if not hasattr(self, 'socketio') or not self.socketio:
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"✓ Restored socketio connection for scan start")
# FIXED: Safe cleanup with existence checks
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
@ -395,7 +480,10 @@ class Scanner:
)
self.scan_thread.start()
self.status_logger_stop_event.clear()
# FIXED: Ensure status_logger_stop_event exists before clearing
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.clear()
self.status_logger_thread = threading.Thread(
target=self._status_logger_thread,
daemon=True,
@ -449,6 +537,13 @@ class Scanner:
return 10 # Very low rate limit = very low priority
def _execute_scan(self, target: str, max_depth: int) -> None:
"""
FIXED: Enhanced execute_scan with proper threading object handling.
"""
# FIXED: Ensure threading objects exist
self._ensure_threading_objects_exist()
update_counter = 0 # Track updates for throttling
last_update_time = time.time()
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping
@ -480,8 +575,13 @@ class Scanner:
print(f"\n=== PHASE 1: Running non-correlation providers ===")
while not self._is_stop_requested():
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
# FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
@ -534,10 +634,23 @@ class Scanner:
continue
# Thread-safe processing state management
with self.processing_lock:
processing_key = (provider_name, target_item)
# FIXED: Safe processing lock usage
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested():
break
processing_key = (provider_name, target_item)
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
@ -556,7 +669,12 @@ class Scanner:
if provider and not isinstance(provider, CorrelationProvider):
new_targets, _, success = self._process_provider_task(provider, target_item, depth)
update_counter += 1
current_time = time.time()
if (update_counter % 5 == 0) or (current_time - last_update_time > 3.0):
self._update_session_state()
last_update_time = current_time
update_counter = 0
if self._is_stop_requested():
break
@ -601,9 +719,13 @@ class Scanner:
self.indicators_completed += 1
finally:
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
# FIXED: Safe processing lock usage for cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
# PHASE 2: Run correlations on all discovered nodes
if not self._is_stop_requested():
@ -616,8 +738,9 @@ class Scanner:
self.logger.logger.error(f"Scan failed: {e}")
finally:
# Comprehensive cleanup (same as before)
with self.processing_lock:
self.currently_processing.clear()
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
while not self.task_queue.empty():
@ -633,7 +756,9 @@ class Scanner:
else:
self.status = ScanStatus.COMPLETED
self.status_logger_stop_event.set()
# FIXED: Safe stop event handling
if hasattr(self, 'status_logger_stop_event') and self.status_logger_stop_event:
self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0)
@ -687,8 +812,13 @@ class Scanner:
while not self._is_stop_requested() and correlation_tasks:
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
# FIXED: Safe processing check
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
else:
no_active_processing = len(getattr(self, 'currently_processing', [])) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
@ -720,10 +850,23 @@ class Scanner:
correlation_tasks.remove(task_tuple)
continue
with self.processing_lock:
processing_key = (provider_name, target_item)
# FIXED: Safe processing management
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
if self._is_stop_requested():
break
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
else:
if self._is_stop_requested():
break
processing_key = (provider_name, target_item)
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
@ -752,16 +895,165 @@ class Scanner:
correlation_tasks.remove(task_tuple)
finally:
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
# FIXED: Safe cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.discard(processing_key)
else:
if hasattr(self, 'currently_processing'):
self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
# The remaining methods are unchanged apart from similar threading-object safety checks.
# Key methods that needed fixes:
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# FIXED: Safe cleanup
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
self._update_session_state()
return True
except Exception as e:
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive graph data for real-time updates."""
try:
# FIXED: Safe processing state access
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
else:
currently_processing_count = len(getattr(self, 'currently_processing', []))
currently_processing_list = list(getattr(self, 'currently_processing', []))
# FIXED: Always include complete graph data for real-time updates
graph_data = self.get_graph_data()
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph': graph_data, # FIXED: Always include complete graph data
'task_queue_size': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0,
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception as e:
traceback.print_exc()
return {
'status': 'error',
'message': 'Failed to get status',
'graph': {'nodes': [], 'edges': [], 'statistics': {'node_count': 0, 'edge_count': 0}}
}
def _update_session_state(self) -> None:
"""
FIXED: Update the scanner state in Redis and emit real-time WebSocket updates.
Enhanced with better error handling and socketio management.
"""
if self.session_id:
try:
# Get current status for WebSocket emission
current_status = self.get_scan_status()
# FIXED: Emit real-time update via WebSocket with better error handling
socketio_available = False
if hasattr(self, 'socketio') and self.socketio:
try:
print(f"📡 EMITTING WebSocket update: {current_status.get('status', 'unknown')} - "
f"Nodes: {len(current_status.get('graph', {}).get('nodes', []))}, "
f"Edges: {len(current_status.get('graph', {}).get('edges', []))}")
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted successfully")
socketio_available = True
except Exception as ws_error:
print(f"⚠️ WebSocket emission failed: {ws_error}")
# Try to get socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print("🔄 Attempting to use registered socketio connection...")
registered_socketio.emit('scan_update', current_status)
self.socketio = registered_socketio # Update our reference
print("✅ WebSocket update emitted via registered connection")
socketio_available = True
else:
print("⚠️ No registered socketio connection found")
except Exception as fallback_error:
print(f"⚠️ Fallback socketio emission also failed: {fallback_error}")
else:
# Try to restore socketio from session manager
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
print(f"🔄 Restoring socketio connection for session {self.session_id}")
self.socketio = registered_socketio
self.socketio.emit('scan_update', current_status)
print("✅ WebSocket update emitted via restored connection")
socketio_available = True
else:
print(f"⚠️ No socketio connection available for session {self.session_id}")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio connection: {restore_error}")
if not socketio_available:
print(f"⚠️ Real-time updates unavailable for session {self.session_id}")
# Update session in Redis for persistence (always do this)
try:
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception as redis_error:
print(f"⚠️ Failed to update session in Redis: {redis_error}")
except Exception as e:
print(f"⚠️ Error updating session state: {e}")
import traceback
traceback.print_exc()
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
"""
Manages the entire process for a given target and provider.
This version is generalized to handle all relationships dynamically.
FIXED: Manages the entire process for a given target and provider with enhanced real-time updates.
"""
if self._is_stop_requested():
return set(), set(), False
@ -781,22 +1073,36 @@ class Scanner:
if provider_result is None:
provider_successful = False
elif not self._is_stop_requested():
# Pass all relationships to be processed
discovered, is_large_entity = self._process_provider_result_unified(
target, provider, provider_result, depth
)
new_targets.update(discovered)
# FIXED: Emit real-time update after processing provider result
if discovered or provider_result.get_relationship_count() > 0:
# Ensure we have socketio connection for real-time updates
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio connection during provider processing")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio during provider processing: {restore_error}")
self._update_session_state()
except Exception as e:
provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e))
return new_targets, set(), provider_successful
def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
"""
The "worker" function that directly communicates with the provider to fetch data.
"""
"""The "worker" function that directly communicates with the provider to fetch data."""
provider_name = provider.get_name()
start_time = datetime.now(timezone.utc)
@ -823,9 +1129,7 @@ class Scanner:
def _create_large_entity_from_result(self, source_node: str, provider_name: str,
provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
"""
Creates a large entity node, tags all member nodes, and returns its ID and members.
"""
"""Creates a large entity node, tags all member nodes, and returns its ID and members."""
members = {rel.target_node for rel in provider_result.relationships
if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}
@ -858,7 +1162,7 @@ class Scanner:
def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
"""
Removes a node from a large entity, allowing it to be processed normally.
FIXED: Removes a node from a large entity with immediate real-time update.
"""
if not self.graph.graph.has_node(node_id):
return False
@ -877,7 +1181,6 @@ class Scanner:
for provider in eligible_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
# Use current depth of the large entity if available, else 0
depth = 0
if self.graph.graph.has_node(large_entity_id):
le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
@ -888,6 +1191,19 @@ class Scanner:
self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
self.total_tasks_ever_enqueued += 1
# FIXED: Emit real-time update after extraction with socketio management
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for node extraction update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for extraction: {restore_error}")
self._update_session_state()
return True
return False
@ -895,8 +1211,7 @@ class Scanner:
def _process_provider_result_unified(self, target: str, provider: BaseProvider,
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
"""
Process a unified ProviderResult object to update the graph.
This version dynamically re-routes edges to a large entity container.
FIXED: Process a unified ProviderResult object to update the graph with enhanced real-time updates.
"""
provider_name = provider.get_name()
discovered_targets = set()
@ -916,6 +1231,10 @@ class Scanner:
target, provider_name, provider_result, current_depth
)
# Track if we added anything significant
nodes_added = 0
edges_added = 0
for i, relationship in enumerate(provider_result.relationships):
if i % 5 == 0 and self._is_stop_requested():
break
@ -947,17 +1266,20 @@ class Scanner:
max_depth_reached = current_depth >= self.max_depth
# Add actual nodes to the graph (they might be hidden by the UI)
self.graph.add_node(source_node_id, source_type)
self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached})
if self.graph.add_node(source_node_id, source_type):
nodes_added += 1
if self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached}):
nodes_added += 1
# Add the visual edge to the graph
self.graph.add_edge(
if self.graph.add_edge(
visual_source, visual_target,
relationship.relationship_type,
relationship.confidence,
provider_name,
relationship.raw_data
)
):
edges_added += 1
if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
if target_node_id not in large_entity_members:
@ -985,86 +1307,32 @@ class Scanner:
if not self.graph.graph.has_node(node_id):
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
nodes_added += 1
else:
existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list
return discovered_targets, is_large_entity
# FIXED: Emit real-time update if we added anything significant
if nodes_added > 0 or edges_added > 0:
print(f"🔄 Added {nodes_added} nodes, {edges_added} edges - triggering real-time update")
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
# Ensure we have socketio connection for immediate update
if self.session_id and (not hasattr(self, 'socketio') or not self.socketio):
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
from core.session_manager import session_manager
registered_socketio = session_manager.get_socketio_connection(self.session_id)
if registered_socketio:
self.socketio = registered_socketio
print(f"🔄 Restored socketio for immediate update")
except Exception as restore_error:
print(f"⚠️ Failed to restore socketio for immediate update: {restore_error}")
self._update_session_state()
return True
except Exception as e:
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def _update_session_state(self) -> None:
"""
Update the scanner state in Redis for GUI updates.
"""
if self.session_id:
try:
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception:
pass
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive processing information."""
try:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph_statistics': self.graph.get_statistics(),
'task_queue_size': self.task_queue.qsize(),
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
'tasks_in_queue': self.task_queue.qsize(),
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception:
traceback.print_exc()
return { 'status': 'error', 'message': 'Failed to get status' }
return discovered_targets, is_large_entity
def _initialize_provider_states(self, target: str) -> None:
"""
FIXED: Safer provider state initialization with error handling.
"""
"""FIXED: Safer provider state initialization with error handling."""
try:
if not self.graph.graph.has_node(target):
return
@ -1077,11 +1345,8 @@ class Scanner:
except Exception as e:
self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")
def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
"""
FIXED: Improved provider eligibility checking with better filtering.
"""
"""FIXED: Improved provider eligibility checking with better filtering."""
if dns_only:
return [p for p in self.providers if p.get_name() == 'dns']
@ -1120,9 +1385,7 @@ class Scanner:
return eligible
def _already_queried_provider(self, target: str, provider_name: str) -> bool:
"""
FIXED: More robust check for already queried providers with proper error handling.
"""
"""FIXED: More robust check for already queried providers with proper error handling."""
try:
if not self.graph.graph.has_node(target):
return False
@ -1141,9 +1404,7 @@ class Scanner:
def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: Optional[str], start_time: datetime) -> None:
"""
FIXED: More robust provider state updates with validation.
"""
"""FIXED: More robust provider state updates with validation."""
try:
if not self.graph.graph.has_node(target):
self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
@ -1170,7 +1431,8 @@ class Scanner:
}
# Update last modified time for forensic integrity
self.last_modified = datetime.now(timezone.utc).isoformat()
if hasattr(self, 'last_modified'):
self.last_modified = datetime.now(timezone.utc).isoformat()
except Exception as e:
self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
@ -1187,9 +1449,14 @@ class Scanner:
return 0.0
# Add small buffer for tasks still in queue to avoid showing 100% too early
queue_size = max(0, self.task_queue.qsize())
with self.processing_lock:
active_tasks = len(self.currently_processing)
queue_size = max(0, self.task_queue.qsize() if hasattr(self, 'task_queue') and self.task_queue else 0)
# FIXED: Safe processing count
if hasattr(self, 'processing_lock') and self.processing_lock:
with self.processing_lock:
active_tasks = len(self.currently_processing)
else:
active_tasks = len(getattr(self, 'currently_processing', []))
# Adjust total to account for remaining work
adjusted_total = max(self.total_tasks_ever_enqueued,
@ -1206,12 +1473,13 @@ class Scanner:
return 0.0
def get_graph_data(self) -> Dict[str, Any]:
"""Get current graph data formatted for frontend visualization."""
graph_data = self.graph.get_graph_data()
graph_data['initial_targets'] = list(self.initial_targets)
return graph_data
def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
"""Get comprehensive information about all available providers."""
info = {}
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
for filename in os.listdir(provider_dir):

core/session_manager.py

@ -6,6 +6,7 @@ import uuid
import redis
import pickle
from typing import Dict, Optional, Any
import copy
from core.scanner import Scanner
from config import config
@ -13,7 +14,7 @@ from config import config
class SessionManager:
"""
FIXED: Manages multiple scanner instances for concurrent user sessions using Redis.
Now more conservative about session creation to preserve API keys and configuration.
Enhanced to properly maintain WebSocket connections throughout scan lifecycle.
"""
def __init__(self, session_timeout_minutes: int = 0):
@ -30,6 +31,9 @@ class SessionManager:
# FIXED: Add a creation lock to prevent race conditions
self.creation_lock = threading.Lock()
# Track active socketio connections per session
self.active_socketio_connections = {}
# Start cleanup thread
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start()
@ -40,7 +44,7 @@ class SessionManager:
"""Prepare SessionManager for pickling."""
state = self.__dict__.copy()
# Exclude unpickleable attributes - Redis client and threading objects
unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock']
unpicklable_attrs = ['lock', 'cleanup_thread', 'redis_client', 'creation_lock', 'active_socketio_connections']
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
@ -53,6 +57,7 @@ class SessionManager:
self.redis_client = redis.StrictRedis(db=0, decode_responses=False)
self.lock = threading.Lock()
self.creation_lock = threading.Lock()
self.active_socketio_connections = {}
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start()
@ -64,22 +69,70 @@ class SessionManager:
"""Generates the Redis key for a session's stop signal."""
return f"dnsrecon:stop:{session_id}"
def create_session(self) -> str:
def register_socketio_connection(self, session_id: str, socketio) -> None:
"""
FIXED: Create a new user session with thread-safe creation to prevent duplicates.
FIXED: Register a socketio connection for a session.
This ensures the connection is maintained throughout the session lifecycle.
"""
with self.lock:
self.active_socketio_connections[session_id] = socketio
print(f"Registered socketio connection for session {session_id}")
def get_socketio_connection(self, session_id: str):
"""
FIXED: Get the active socketio connection for a session.
"""
with self.lock:
return self.active_socketio_connections.get(session_id)
def _prepare_scanner_for_storage(self, scanner: Scanner, session_id: str) -> Scanner:
"""
FIXED: Prepare scanner for storage by ensuring proper cleanup of unpicklable objects.
Now preserves socketio connection info for restoration.
"""
# Set the session ID on the scanner for cross-process stop signal management
scanner.session_id = session_id
# The live socketio handle cannot be pickled, so clear it before storage;
# it is restored from the connection registry when the scanner is loaded.
scanner.socketio = None
# Force cleanup of any threading objects that might cause issues
if hasattr(scanner, 'stop_event'):
scanner.stop_event = None
if hasattr(scanner, 'scan_thread'):
scanner.scan_thread = None
if hasattr(scanner, 'executor'):
scanner.executor = None
if hasattr(scanner, 'status_logger_thread'):
scanner.status_logger_thread = None
if hasattr(scanner, 'status_logger_stop_event'):
scanner.status_logger_stop_event = None
return scanner
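A minimal round-trip sketch of the storage pattern above, using stand-ins rather than project classes (a real Socket.IO server object is not picklable, which is the point of the registry):

    import pickle

    class FakeScanner:  # stand-in for Scanner
        def __init__(self, socketio=None):
            self.socketio = socketio
            self.session_id = None

    registry = {}            # stand-in for active_socketio_connections
    live_handle = object()   # pretend socketio server

    scanner = FakeScanner(live_handle)
    registry['sess-1'] = scanner.socketio       # register_socketio_connection
    scanner.socketio = None                     # _prepare_scanner_for_storage
    blob = pickle.dumps(scanner)                # now safe to store in Redis

    restored = pickle.loads(blob)
    restored.socketio = registry.get('sess-1')  # get_session restoration path
    assert restored.socketio is live_handle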
def create_session(self, socketio=None) -> str:
"""
FIXED: Create a new user session with enhanced WebSocket management.
"""
# FIXED: Use creation lock to prevent race conditions
with self.creation_lock:
session_id = str(uuid.uuid4())
print(f"=== CREATING SESSION {session_id} IN REDIS ===")
# FIXED: Register socketio connection first
if socketio:
self.register_socketio_connection(session_id, socketio)
try:
from core.session_config import create_session_config
session_config = create_session_config()
scanner_instance = Scanner(session_config=session_config)
# Set the session ID on the scanner for cross-process stop signal management
scanner_instance.session_id = session_id
# Create scanner WITHOUT socketio to avoid weakref issues
scanner_instance = Scanner(session_config=session_config, socketio=None)
# Prepare scanner for storage (removes problematic objects)
scanner_instance = self._prepare_scanner_for_storage(scanner_instance, session_id)
session_data = {
'scanner': scanner_instance,
@ -89,12 +142,24 @@ class SessionManager:
'status': 'active'
}
# Serialize the entire session data dictionary using pickle
serialized_data = pickle.dumps(session_data)
# Test serialization before storing to catch issues early
try:
test_serialization = pickle.dumps(session_data)
print(f"Session serialization test successful ({len(test_serialization)} bytes)")
except Exception as pickle_error:
print(f"PICKLE TEST FAILED: {pickle_error}")
# Try to identify the problematic object
for key, value in session_data.items():
try:
pickle.dumps(value)
print(f" {key}: OK")
except Exception as item_error:
print(f" {key}: FAILED - {item_error}")
raise pickle_error
# Store in Redis
session_key = self._get_session_key(session_id)
self.redis_client.setex(session_key, self.session_timeout, serialized_data)
self.redis_client.setex(session_key, self.session_timeout, test_serialization)
# Initialize stop signal as False
stop_key = self._get_stop_signal_key(session_id)
@ -106,6 +171,8 @@ class SessionManager:
except Exception as e:
print(f"ERROR: Failed to create session {session_id}: {e}")
import traceback
traceback.print_exc()
raise
def set_stop_signal(self, session_id: str) -> bool:
@ -175,31 +242,63 @@ class SessionManager:
# Ensure the scanner has the correct session ID for stop signal checking
if 'scanner' in session_data and session_data['scanner']:
session_data['scanner'].session_id = session_id
# FIXED: Restore socketio connection from our registry
socketio_conn = self.get_socketio_connection(session_id)
if socketio_conn:
session_data['scanner'].socketio = socketio_conn
print(f"Restored socketio connection for session {session_id}")
else:
print(f"No socketio connection found for session {session_id}")
session_data['scanner'].socketio = None
return session_data
return None
except Exception as e:
print(f"ERROR: Failed to get session data for {session_id}: {e}")
import traceback
traceback.print_exc()
return None
def _save_session_data(self, session_id: str, session_data: Dict[str, Any]) -> bool:
"""
Serializes and saves session data back to Redis with updated TTL.
FIXED: Now preserves socketio connection during storage.
Returns:
bool: True if save was successful
"""
try:
session_key = self._get_session_key(session_id)
serialized_data = pickle.dumps(session_data)
# Create a deep copy to avoid modifying the original scanner object
session_data_to_save = copy.deepcopy(session_data)
# Prepare scanner for storage if it exists
if 'scanner' in session_data_to_save and session_data_to_save['scanner']:
# FIXED: Preserve the original socketio connection before preparing for storage
original_socketio = session_data_to_save['scanner'].socketio
session_data_to_save['scanner'] = self._prepare_scanner_for_storage(
session_data_to_save['scanner'],
session_id
)
# FIXED: If we had a socketio connection, make sure it's registered
if original_socketio and session_id not in self.active_socketio_connections:
self.register_socketio_connection(session_id, original_socketio)
serialized_data = pickle.dumps(session_data_to_save)
result = self.redis_client.setex(session_key, self.session_timeout, serialized_data)
return result
except Exception as e:
print(f"ERROR: Failed to save session data for {session_id}: {e}")
import traceback
traceback.print_exc()
return False
def update_session_scanner(self, session_id: str, scanner: 'Scanner') -> bool:
"""
Updates just the scanner object in a session with immediate persistence.
FIXED: Updates just the scanner object in a session with immediate persistence.
Now maintains socketio connection throughout the update process.
Returns:
bool: True if update was successful
@ -207,21 +306,27 @@ class SessionManager:
try:
session_data = self._get_session_data(session_id)
if session_data:
# Ensure scanner has the session ID
scanner.session_id = session_id
# FIXED: Preserve socketio connection before preparing for storage
original_socketio = scanner.socketio
# Prepare scanner for storage
scanner = self._prepare_scanner_for_storage(scanner, session_id)
session_data['scanner'] = scanner
session_data['last_activity'] = time.time()
# FIXED: Restore socketio connection after preparation
if original_socketio:
self.register_socketio_connection(session_id, original_socketio)
session_data['scanner'].socketio = original_socketio
# Immediately save to Redis for GUI updates
success = self._save_session_data(session_id, session_data)
if success:
# Only log occasionally to reduce noise
if hasattr(self, '_last_update_log'):
if time.time() - self._last_update_log > 5: # Log every 5 seconds max
#print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
self._last_update_log = time.time()
else:
#print(f"Scanner state updated for session {session_id} (status: {scanner.status})")
self._last_update_log = time.time()
else:
print(f"WARNING: Failed to save scanner state for session {session_id}")
@ -231,6 +336,8 @@ class SessionManager:
return False
except Exception as e:
print(f"ERROR: Failed to update scanner for session {session_id}: {e}")
import traceback
traceback.print_exc()
return False
def update_scanner_status(self, session_id: str, status: str) -> bool:
@ -263,7 +370,7 @@ class SessionManager:
def get_session(self, session_id: str) -> Optional[Scanner]:
"""
Get scanner instance for a session from Redis with session ID management.
FIXED: Get scanner instance for a session from Redis with proper socketio restoration.
"""
if not session_id:
return None
@ -282,6 +389,15 @@ class SessionManager:
# Ensure the scanner can check the Redis-based stop signal
scanner.session_id = session_id
# FIXED: Restore socketio connection from our registry
socketio_conn = self.get_socketio_connection(session_id)
if socketio_conn:
scanner.socketio = socketio_conn
print(f"✓ Restored socketio connection for session {session_id}")
else:
scanner.socketio = None
print(f"⚠️ No socketio connection found for session {session_id}")
return scanner
def get_session_status_only(self, session_id: str) -> Optional[str]:
@ -333,6 +449,12 @@ class SessionManager:
# Wait a moment for graceful shutdown
time.sleep(0.5)
# FIXED: Clean up socketio connection
with self.lock:
if session_id in self.active_socketio_connections:
del self.active_socketio_connections[session_id]
print(f"Cleaned up socketio connection for session {session_id}")
# Delete session data and stop signal from Redis
session_key = self._get_session_key(session_id)
stop_key = self._get_stop_signal_key(session_id)
@ -344,6 +466,8 @@ class SessionManager:
except Exception as e:
print(f"ERROR: Failed to terminate session {session_id}: {e}")
import traceback
traceback.print_exc()
return False
def _cleanup_loop(self) -> None:
@ -364,6 +488,12 @@ class SessionManager:
self.redis_client.delete(stop_key)
print(f"Cleaned up orphaned stop signal for session {session_id}")
# Also clean up socketio connection
with self.lock:
if session_id in self.active_socketio_connections:
del self.active_socketio_connections[session_id]
print(f"Cleaned up orphaned socketio for session {session_id}")
except Exception as e:
print(f"Error in cleanup loop: {e}")
@ -387,14 +517,16 @@ class SessionManager:
return {
'total_active_sessions': active_sessions,
'running_scans': running_scans,
'total_stop_signals': len(stop_keys)
'total_stop_signals': len(stop_keys),
'active_socketio_connections': len(self.active_socketio_connections)
}
except Exception as e:
print(f"ERROR: Failed to get statistics: {e}")
return {
'total_active_sessions': 0,
'running_scans': 0,
'total_stop_signals': 0
'total_stop_signals': 0,
'active_socketio_connections': 0
}
# Global session manager instance

View File

@ -15,6 +15,7 @@ class BaseProvider(ABC):
"""
Abstract base class for all DNSRecon data providers.
Now supports session-specific configuration and returns standardized ProviderResult objects.
FIXED: Enhanced pickle support to prevent weakref serialization errors.
"""
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
@ -53,22 +54,57 @@ class BaseProvider(ABC):
def __getstate__(self):
"""Prepare BaseProvider for pickling by excluding unpicklable objects."""
state = self.__dict__.copy()
# Exclude the unpickleable '_local' attribute and stop event
unpicklable_attrs = ['_local', '_stop_event']
# Exclude unpickleable attributes that may contain weakrefs
unpicklable_attrs = [
'_local', # Thread-local storage (contains requests.Session)
'_stop_event', # Threading event
'logger', # Logger may contain weakrefs in handlers
]
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
# Also handle any potential weakrefs in the config object
if 'config' in state and hasattr(state['config'], '__getstate__'):
# If config has its own pickle support, let it handle itself
pass
elif 'config' in state:
# Otherwise, ensure config doesn't contain unpicklable objects
try:
# Test if config can be pickled
import pickle
pickle.dumps(state['config'])
except (TypeError, AttributeError):
# If config can't be pickled, we'll recreate it during unpickling
state['_config_class'] = type(state['config']).__name__
del state['config']
return state
def __setstate__(self, state):
"""Restore BaseProvider after unpickling by reconstructing threading objects."""
self.__dict__.update(state)
# Re-initialize the '_local' attribute and stop event
# Re-initialize unpickleable attributes
self._local = threading.local()
self._stop_event = None
self.logger = get_forensic_logger()
# Recreate config if it was removed during pickling
if not hasattr(self, 'config') and hasattr(self, '_config_class'):
if self._config_class == 'Config':
from config import config as global_config
self.config = global_config
elif self._config_class == 'SessionConfig':
from core.session_config import create_session_config
self.config = create_session_config()
del self._config_class
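A quick sanity check of what this `__getstate__`/`__setstate__` pair guarantees, assuming any concrete provider subclass is importable (class name illustrative):

    import pickle

    provider = DNSProvider()                      # any BaseProvider subclass
    clone = pickle.loads(pickle.dumps(provider))  # logger/_local/_stop_event dropped, then rebuilt
    assert clone.logger is not None               # fresh forensic logger
    assert clone.session is not provider.session  # fresh thread-local requests.Session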
@property
def session(self):
"""Get or create thread-local requests session."""
if not hasattr(self._local, 'session'):
self._local.session = requests.Session()
self._local.session.headers.update({

View File

@ -10,6 +10,7 @@ from core.graph_manager import NodeType, GraphManager
class CorrelationProvider(BaseProvider):
"""
A provider that finds correlations between nodes in the graph.
FIXED: Enhanced pickle support to prevent weakref issues with graph references.
"""
def __init__(self, name: str = "correlation", session_config=None):
@ -26,6 +27,7 @@ class CorrelationProvider(BaseProvider):
'cert_common_name',
'cert_validity_period_days',
'cert_issuer_name',
'cert_serial_number',
'cert_entry_timestamp',
'cert_not_before',
'cert_not_after',
@ -37,6 +39,38 @@ class CorrelationProvider(BaseProvider):
'query_timestamp',
]
def __getstate__(self):
"""
FIXED: Prepare CorrelationProvider for pickling by excluding graph reference.
"""
state = super().__getstate__()
# Remove graph reference to prevent circular dependencies and weakrefs
if 'graph' in state:
del state['graph']
# Also handle correlation_index which might contain complex objects
if 'correlation_index' in state:
# Clear correlation index as it will be rebuilt when needed
state['correlation_index'] = {}
return state
def __setstate__(self, state):
"""
FIXED: Restore CorrelationProvider after unpickling.
"""
super().__setstate__(state)
# Re-initialize graph reference (will be set by scanner)
self.graph = None
# Re-initialize correlation index
self.correlation_index = {}
# Re-compile regex pattern
self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
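Hedged aside on the recompiled regex: together with the excluded-attribute list above, it appears to keep timestamp-shaped values out of correlation; the pattern behaves like this:

    import re

    # Pattern copied verbatim from __setstate__ above; the intent (skipping
    # timestamps during correlation) is inferred, not stated in the diff.
    date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
    assert date_pattern.match('2025-09-20T16:52:05')
    assert date_pattern.match('2025-09-20 16:52:05')
    assert not date_pattern.match('cloudflare.com')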
def get_name(self) -> str:
"""Return the provider name."""
return "correlation"
@ -78,13 +112,20 @@ class CorrelationProvider(BaseProvider):
def _find_correlations(self, node_id: str) -> ProviderResult:
"""
Find correlations for a given node.
FIXED: Added safety checks to prevent issues when graph is None.
"""
result = ProviderResult()
# FIXED: Ensure self.graph is not None before proceeding.
# FIXED: Ensure self.graph is not None before proceeding
if not self.graph or not self.graph.graph.has_node(node_id):
return result
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
try:
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
except Exception as e:
# If there's any issue accessing the graph, return empty result
print(f"Warning: Could not access graph for correlation analysis: {e}")
return result
for attr in node_attributes:
attr_name = attr.get('name')
@ -133,6 +174,7 @@ class CorrelationProvider(BaseProvider):
if len(self.correlation_index[attr_value]['nodes']) > 1:
self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
return result
def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):

View File

@ -2,38 +2,17 @@
import json
import re
import psycopg2
from pathlib import Path
from typing import List, Dict, Any, Set, Optional
from urllib.parse import quote
from datetime import datetime, timezone
import requests
from psycopg2 import pool
from .base_provider import BaseProvider
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_domain
from core.logger import get_forensic_logger
# --- Global Instance for PostgreSQL Connection Pool ---
# This pool will be created once per worker process and is not part of the
# CrtShProvider instance, thus avoiding pickling errors.
db_pool = None
try:
db_pool = psycopg2.pool.SimpleConnectionPool(
1, 5,
host='crt.sh',
port=5432,
user='guest',
dbname='certwatch',
sslmode='prefer',
connect_timeout=60
)
# Use a generic logger here as this is at the module level
get_forensic_logger().logger.info("crt.sh: Global PostgreSQL connection pool created successfully.")
except Exception as e:
get_forensic_logger().logger.warning(f"crt.sh: Failed to create global DB connection pool: {e}. Will fall back to HTTP API.")
class CrtShProvider(BaseProvider):
"""
@ -136,51 +115,42 @@ class CrtShProvider(BaseProvider):
result = ProviderResult()
try:
if cache_status == "fresh":
result = self._load_from_cache(cache_file)
self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
if cache_status == "fresh":
result = self._load_from_cache(cache_file)
self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
else: # "stale" or "not_found"
# Query the API for the latest certificates
new_raw_certs = self._query_crtsh(domain)
else: # "stale" or "not_found"
# Query the API for the latest certificates
new_raw_certs = self._query_crtsh_api(domain)
if self._stop_event and self._stop_event.is_set():
return ProviderResult()
# Combine with old data if cache is stale
if cache_status == "stale":
old_raw_certs = self._load_raw_data_from_cache(cache_file)
combined_certs = old_raw_certs + new_raw_certs
# Deduplicate the combined list
seen_ids = set()
unique_certs = []
for cert in combined_certs:
cert_id = cert.get('id')
if cert_id not in seen_ids:
unique_certs.append(cert)
seen_ids.add(cert_id)
raw_certificates_to_process = unique_certs
self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
else: # "not_found"
raw_certificates_to_process = new_raw_certs
# FIXED: Process certificates to create proper domain and CA nodes
result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
# Save the new result and the raw data to the cache
self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)
except (requests.exceptions.RequestException, psycopg2.Error) as e:
self.logger.logger.error(f"Upstream query failed for {domain}: {e}")
if cache_status != "not_found":
result = self._load_from_cache(cache_file)
self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.")
else:
raise e # Re-raise if there's no cache to fall back on
return result
@ -278,58 +248,6 @@ class CrtShProvider(BaseProvider):
except Exception as e:
self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}")
def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
"""Query crt.sh, trying the database first and falling back to the API."""
global db_pool
if db_pool:
try:
self.logger.logger.info(f"crt.sh: Attempting DB query for {domain}")
return self._query_crtsh_db(domain)
except psycopg2.Error as e:
self.logger.logger.warning(f"crt.sh: DB query failed for {domain}: {e}. Falling back to HTTP API.")
return self._query_crtsh_api(domain)
else:
self.logger.logger.info(f"crt.sh: No DB connection pool. Using HTTP API for {domain}")
return self._query_crtsh_api(domain)
def _query_crtsh_db(self, domain: str) -> List[Dict[str, Any]]:
"""Query crt.sh database for raw certificate data."""
global db_pool
conn = db_pool.getconn()
try:
with conn.cursor() as cursor:
query = """
SELECT
c.id,
x509_serialnumber(c.certificate) as serial_number,
x509_notbefore(c.certificate) as not_before,
x509_notafter(c.certificate) as not_after,
c.issuer_ca_id,
ca.name as issuer_name,
x509_commonname(c.certificate) as common_name,
identities(c.certificate)::text as name_value
FROM certificate c
LEFT JOIN ca ON c.issuer_ca_id = ca.id
WHERE identities(c.certificate) @@ plainto_tsquery(%s)
ORDER BY c.id DESC
LIMIT 5000;
"""
cursor.execute(query, (domain,))
results = []
columns = [desc[0] for desc in cursor.description]
for row in cursor.fetchall():
row_dict = dict(zip(columns, row))
if row_dict.get('not_before'):
row_dict['not_before'] = row_dict['not_before'].isoformat()
if row_dict.get('not_after'):
row_dict['not_after'] = row_dict['not_after'].isoformat()
results.append(row_dict)
self.logger.logger.info(f"crt.sh: DB query for {domain} returned {len(results)} records.")
return results
finally:
db_pool.putconn(conn)
def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]:
"""Query crt.sh API for raw certificate data."""
url = f"{self.base_url}?q={quote(domain)}&output=json"

View File

@ -11,6 +11,7 @@ class DNSProvider(BaseProvider):
"""
Provider for standard DNS resolution and reverse DNS lookups.
Now returns standardized ProviderResult objects with IPv4 and IPv6 support.
FIXED: Enhanced pickle support to prevent resolver serialization issues.
"""
def __init__(self, name=None, session_config=None):
@ -27,6 +28,22 @@ class DNSProvider(BaseProvider):
self.resolver.timeout = 5
self.resolver.lifetime = 10
def __getstate__(self):
"""Prepare the object for pickling by excluding resolver."""
state = super().__getstate__()
# Remove the unpickleable 'resolver' attribute
if 'resolver' in state:
del state['resolver']
return state
def __setstate__(self, state):
"""Restore the object after unpickling by reconstructing resolver."""
super().__setstate__(state)
# Re-initialize the 'resolver' attribute
self.resolver = resolver.Resolver()
self.resolver.timeout = 5
self.resolver.lifetime = 10
def get_name(self) -> str:
"""Return the provider name."""
return "dns"
@ -106,10 +123,10 @@ class DNSProvider(BaseProvider):
if _is_valid_domain(hostname):
# Determine appropriate forward relationship type based on IP version
if ip_version == 6:
relationship_type = 'dns_aaaa_record'
record_prefix = 'AAAA'
else:
relationship_type = 'dns_a_record'
record_prefix = 'A'
# Add the relationship

View File

@ -36,6 +36,15 @@ class ShodanProvider(BaseProvider):
self.cache_dir = Path('cache') / 'shodan'
self.cache_dir.mkdir(parents=True, exist_ok=True)
def __getstate__(self):
"""Prepare the object for pickling."""
state = super().__getstate__()
return state
def __setstate__(self, state):
"""Restore the object after unpickling."""
super().__setstate__(state)
def _check_api_connection(self) -> bool:
"""
FIXED: Lazy connection checking - only test when actually needed.

View File

@ -9,3 +9,5 @@ gunicorn
redis
python-dotenv
psycopg2-binary
Flask-SocketIO
eventlet
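Flask-SocketIO falls back to long polling without an async server, so the eventlet pin matters in production. A minimal launch sketch, assuming app.py exposes `app` and `socketio` as shown earlier in this diff:

    # run.py (hypothetical); eventlet must monkey-patch before other imports
    import eventlet
    eventlet.monkey_patch()

    from app import app, socketio  # names as created in app.py

    if __name__ == '__main__':
        # with eventlet installed, socketio.run() uses its WSGI server
        socketio.run(app, host='127.0.0.1', port=5000)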

View File

@ -1,15 +1,15 @@
/**
* Main application logic for DNSRecon web interface
* Handles UI interactions, API communication, and data flow
* UPDATED: Now compatible with a strictly flat, unified data model for attributes.
* FIXED: Enhanced real-time WebSocket graph updates
*/
class DNSReconApp {
constructor() {
console.log('DNSReconApp constructor called');
this.graphManager = null;
this.socket = null;
this.scanStatus = 'idle';
this.pollInterval = null;
this.currentSessionId = null;
this.elements = {};
@ -17,6 +17,14 @@ class DNSReconApp {
this.isScanning = false;
this.lastGraphUpdate = null;
// FIXED: Add connection state tracking
this.isConnected = false;
this.reconnectAttempts = 0;
this.maxReconnectAttempts = 5;
// FIXED: Track last graph data for debugging
this.lastGraphData = null;
this.init();
}
@ -31,13 +39,11 @@ class DNSReconApp {
this.initializeElements();
this.setupEventHandlers();
this.initializeGraph();
this.updateStatus();
this.initializeSocket();
this.loadProviders();
this.initializeEnhancedModals();
this.addCheckboxStyling();
this.updateGraph();
console.log('DNSRecon application initialized successfully');
} catch (error) {
console.error('Failed to initialize DNSRecon application:', error);
@ -46,6 +52,162 @@ class DNSReconApp {
});
}
initializeSocket() {
console.log('🔌 Initializing WebSocket connection...');
try {
this.socket = io({
transports: ['websocket', 'polling'],
timeout: 10000,
reconnection: true,
reconnectionAttempts: 5,
reconnectionDelay: 2000
});
this.socket.on('connect', () => {
console.log('✅ WebSocket connected successfully');
this.isConnected = true;
this.reconnectAttempts = 0;
this.updateConnectionStatus('idle');
console.log('📡 Requesting initial status...');
this.socket.emit('get_status');
});
this.socket.on('disconnect', (reason) => {
console.log('❌ WebSocket disconnected:', reason);
this.isConnected = false;
this.updateConnectionStatus('error');
});
this.socket.on('connect_error', (error) => {
console.error('❌ WebSocket connection error:', error);
this.reconnectAttempts++;
this.updateConnectionStatus('error');
if (this.reconnectAttempts >= this.maxReconnectAttempts) {
this.showError('WebSocket connection failed. Please refresh the page.');
}
});
this.socket.on('reconnect', (attemptNumber) => {
console.log('✅ WebSocket reconnected after', attemptNumber, 'attempts');
this.isConnected = true;
this.reconnectAttempts = 0;
this.updateConnectionStatus('idle');
this.socket.emit('get_status');
});
// FIXED: Enhanced scan_update handler with detailed graph processing and debugging
this.socket.on('scan_update', (data) => {
console.log('📨 WebSocket update received:', {
status: data.status,
target: data.target_domain,
progress: data.progress_percentage,
graphNodes: data.graph?.nodes?.length || 0,
graphEdges: data.graph?.edges?.length || 0,
timestamp: new Date().toISOString()
});
try {
// Handle status change
if (data.status !== this.scanStatus) {
console.log(`📄 Status change: ${this.scanStatus} -> ${data.status}`);
this.handleStatusChange(data.status, data.task_queue_size);
}
this.scanStatus = data.status;
// Update status display
this.updateStatusDisplay(data);
// FIXED: Always update graph if data is present and graph manager exists
if (data.graph && this.graphManager) {
console.log('📊 Processing graph update:', {
nodes: data.graph.nodes?.length || 0,
edges: data.graph.edges?.length || 0,
hasNodes: Array.isArray(data.graph.nodes),
hasEdges: Array.isArray(data.graph.edges),
isInitialized: this.graphManager.isInitialized
});
// FIXED: Initialize graph manager if not already done
if (!this.graphManager.isInitialized) {
console.log('🎯 Initializing graph manager...');
this.graphManager.initialize();
}
// FIXED: Force graph update and verify it worked
const previousNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0;
const previousEdgeCount = this.graphManager.edges ? this.graphManager.edges.length : 0;
console.log('🔄 Before update - Nodes:', previousNodeCount, 'Edges:', previousEdgeCount);
// Store the data for debugging
this.lastGraphData = data.graph;
// Update the graph
this.graphManager.updateGraph(data.graph);
this.lastGraphUpdate = Date.now();
// Verify the update worked
const newNodeCount = this.graphManager.nodes ? this.graphManager.nodes.length : 0;
const newEdgeCount = this.graphManager.edges ? this.graphManager.edges.length : 0;
console.log('🔄 After update - Nodes:', newNodeCount, 'Edges:', newEdgeCount);
if (newNodeCount !== data.graph.nodes.length || newEdgeCount !== data.graph.edges.length) {
console.warn('⚠️ Graph update mismatch!', {
expectedNodes: data.graph.nodes.length,
actualNodes: newNodeCount,
expectedEdges: data.graph.edges.length,
actualEdges: newEdgeCount
});
// Force a complete rebuild if there's a mismatch
console.log('🔧 Force rebuilding graph...');
this.graphManager.clear();
this.graphManager.updateGraph(data.graph);
}
console.log('✅ Graph updated successfully');
// FIXED: Force network redraw if we're using vis.js
if (this.graphManager.network) {
try {
this.graphManager.network.redraw();
console.log('🎨 Network redrawn');
} catch (redrawError) {
console.warn('⚠️ Network redraw failed:', redrawError);
}
}
} else {
if (!data.graph) {
console.log('⚠️ No graph data in WebSocket update');
}
if (!this.graphManager) {
console.log('⚠️ Graph manager not available');
}
}
} catch (error) {
console.error('❌ Error processing WebSocket update:', error);
console.error('Update data:', data);
console.error('Stack trace:', error.stack);
}
});
this.socket.on('error', (error) => {
console.error('❌ WebSocket error:', error);
this.showError('WebSocket communication error');
});
} catch (error) {
console.error('❌ Failed to initialize WebSocket:', error);
this.showError('Failed to establish real-time connection');
}
}
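For context, the `scan_update` events consumed above would be emitted server-side. A hedged sketch of the emitting call — payload keys mirror what the client reads; `current_target` and `get_progress` are assumed names, since the actual emit site is not part of this diff:

    # Hypothetical server-side counterpart to the client handler above.
    def emit_scan_update(socketio, scanner):
        socketio.emit('scan_update', {
            'status': scanner.status,
            'target_domain': scanner.current_target,        # attribute name assumed
            'progress_percentage': scanner.get_progress(),  # method name assumed
            'graph': scanner.get_graph_data(),              # shown earlier in this diff
        })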
/**
* Initialize DOM element references
*/
@ -263,12 +425,36 @@ class DNSReconApp {
}
/**
* Initialize graph visualization
* FIXED: Initialize graph visualization with enhanced debugging
*/
initializeGraph() {
try {
console.log('Initializing graph manager...');
this.graphManager = new GraphManager('network-graph');
// FIXED: Add debugging hooks to graph manager
if (this.graphManager) {
// Override updateGraph to add debugging
const originalUpdateGraph = this.graphManager.updateGraph.bind(this.graphManager);
this.graphManager.updateGraph = (graphData) => {
console.log('🔧 GraphManager.updateGraph called with:', {
nodes: graphData?.nodes?.length || 0,
edges: graphData?.edges?.length || 0,
timestamp: new Date().toISOString()
});
const result = originalUpdateGraph(graphData);
console.log('🔧 GraphManager.updateGraph completed, network state:', {
networkExists: !!this.graphManager.network,
nodeDataSetLength: this.graphManager.nodes?.length || 0,
edgeDataSetLength: this.graphManager.edges?.length || 0
});
return result;
};
}
console.log('Graph manager initialized successfully');
} catch (error) {
console.error('Failed to initialize graph manager:', error);
@ -288,7 +474,6 @@ class DNSReconApp {
console.log(`Target: "${target}", Max depth: ${maxDepth}`);
// Validation
if (!target) {
console.log('Validation failed: empty target');
this.showError('Please enter a target domain or IP');
@ -303,6 +488,19 @@ class DNSReconApp {
return;
}
// FIXED: Ensure WebSocket connection before starting scan
if (!this.isConnected) {
console.log('WebSocket not connected, attempting to connect...');
this.socket.connect();
// Wait a moment for connection
await new Promise(resolve => setTimeout(resolve, 1000));
if (!this.isConnected) {
this.showWarning('WebSocket connection not established. Updates may be delayed.');
}
}
console.log('Validation passed, setting UI state to scanning...');
this.setUIState('scanning');
this.showInfo('Starting reconnaissance scan...');
@ -320,23 +518,28 @@ class DNSReconApp {
if (response.success) {
this.currentSessionId = response.scan_id;
this.showSuccess('Reconnaissance scan started successfully');
this.showSuccess('Reconnaissance scan started - watching for real-time updates');
if (clearGraph) {
if (clearGraph && this.graphManager) {
console.log('🧹 Clearing graph for new scan');
this.graphManager.clear();
}
console.log(`Scan started for ${target} with depth ${maxDepth}`);
// Start polling immediately with faster interval for responsiveness
this.startPolling(1000);
// FIXED: Immediately start listening for updates
if (this.socket && this.isConnected) {
console.log('📡 Requesting initial status update...');
this.socket.emit('get_status');
// Force an immediate status update
console.log('Forcing immediate status update...');
setTimeout(() => {
this.updateStatus();
this.updateGraph();
}, 100);
// Set up periodic status requests as backup (every 5 seconds during scan)
/*this.statusRequestInterval = setInterval(() => {
if (this.isScanning && this.socket && this.isConnected) {
console.log('📡 Periodic status request...');
this.socket.emit('get_status');
}
}, 5000);*/
}
} else {
throw new Error(response.error || 'Failed to start scan');
@ -348,20 +551,23 @@ class DNSReconApp {
this.setUIState('idle');
}
}
/**
* Scan stop with immediate UI feedback
*/
// FIXED: Enhanced stop scan with interval cleanup
async stopScan() {
try {
console.log('Stopping scan...');
// Immediately disable stop button and show stopping state
// Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
if (this.elements.stopScan) {
this.elements.stopScan.disabled = true;
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOPPING]</span><span>Stopping...</span>';
}
// Show immediate feedback
this.showInfo('Stopping scan...');
const response = await this.apiCall('/api/scan/stop', 'POST');
@ -369,21 +575,10 @@ class DNSReconApp {
if (response.success) {
this.showSuccess('Scan stop requested');
// Force immediate status update
setTimeout(() => {
this.updateStatus();
}, 100);
// Continue polling for a bit to catch the status change
this.startPolling(500); // Fast polling to catch status change
// Stop fast polling after 10 seconds
setTimeout(() => {
if (this.scanStatus === 'stopped' || this.scanStatus === 'idle') {
this.stopPolling();
}
}, 10000);
// Request final status update
if (this.socket && this.isConnected) {
setTimeout(() => this.socket.emit('get_status'), 500);
}
} else {
throw new Error(response.error || 'Failed to stop scan');
}
@ -392,7 +587,6 @@ class DNSReconApp {
console.error('Failed to stop scan:', error);
this.showError(`Failed to stop scan: ${error.message}`);
// Re-enable stop button on error
if (this.elements.stopScan) {
this.elements.stopScan.disabled = false;
this.elements.stopScan.innerHTML = '<span class="btn-icon">[STOP]</span><span>Terminate Scan</span>';
@ -549,85 +743,24 @@ class DNSReconApp {
}
/**
* Start polling for scan updates with configurable interval
*/
startPolling(interval = 2000) {
console.log('=== STARTING POLLING ===');
if (this.pollInterval) {
console.log('Clearing existing poll interval');
clearInterval(this.pollInterval);
}
this.pollInterval = setInterval(() => {
this.updateStatus();
this.updateGraph();
this.loadProviders();
}, interval);
console.log(`Polling started with ${interval}ms interval`);
}
/**
* Stop polling for updates
*/
stopPolling() {
console.log('=== STOPPING POLLING ===');
if (this.pollInterval) {
clearInterval(this.pollInterval);
this.pollInterval = null;
}
}
/**
* Status update with better error handling
*/
async updateStatus() {
try {
const response = await this.apiCall('/api/scan/status');
if (response.success && response.status) {
const status = response.status;
this.updateStatusDisplay(status);
// Handle status changes
if (status.status !== this.scanStatus) {
console.log(`*** STATUS CHANGED: ${this.scanStatus} -> ${status.status} ***`);
this.handleStatusChange(status.status, status.task_queue_size);
}
this.scanStatus = status.status;
} else {
console.error('Status update failed:', response);
// Don't show error for status updates to avoid spam
}
} catch (error) {
console.error('Failed to update status:', error);
this.showConnectionError();
}
}
/**
* Update graph from server
* FIXED: Update graph from server with enhanced debugging
*/
async updateGraph() {
try {
console.log('Updating graph...');
console.log('Updating graph via API call...');
const response = await this.apiCall('/api/graph');
if (response.success) {
const graphData = response.graph;
console.log('Graph data received:');
console.log('Graph data received from API:');
console.log('- Nodes:', graphData.nodes ? graphData.nodes.length : 0);
console.log('- Edges:', graphData.edges ? graphData.edges.length : 0);
// FIXED: Always update graph, even if empty - let GraphManager handle placeholder
if (this.graphManager) {
console.log('🔧 Calling GraphManager.updateGraph from API response...');
this.graphManager.updateGraph(graphData);
this.lastGraphUpdate = Date.now();
@ -636,6 +769,8 @@ class DNSReconApp {
if (this.elements.relationshipsDisplay) {
this.elements.relationshipsDisplay.textContent = edgeCount;
}
console.log('✅ Manual graph update completed');
}
} else {
console.error('Graph update failed:', response);
@ -731,48 +866,70 @@ class DNSReconApp {
* @param {string} newStatus - New scan status
*/
handleStatusChange(newStatus, task_queue_size) {
console.log(`=== STATUS CHANGE: ${this.scanStatus} -> ${newStatus} ===`);
console.log(`📄 Status change handler: ${this.scanStatus} -> ${newStatus}`);
switch (newStatus) {
case 'running':
this.setUIState('scanning', task_queue_size);
this.showSuccess('Scan is running');
// Increase polling frequency for active scans
this.startPolling(1000); // Poll every 1 second for running scans
this.showSuccess('Scan is running - updates in real-time');
this.updateConnectionStatus('active');
break;
case 'completed':
this.setUIState('completed', task_queue_size);
this.stopPolling();
this.showSuccess('Scan completed successfully');
this.updateConnectionStatus('completed');
this.loadProviders();
// Force a final graph update
console.log('Scan completed - forcing final graph update');
setTimeout(() => this.updateGraph(), 100);
console.log('✅ Scan completed - requesting final graph update');
// Request final status to ensure we have the complete graph
setTimeout(() => {
if (this.socket && this.isConnected) {
this.socket.emit('get_status');
}
}, 1000);
// Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
break;
case 'failed':
this.setUIState('failed', task_queue_size);
this.stopPolling();
this.showError('Scan failed');
this.updateConnectionStatus('error');
this.loadProviders();
// Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
break;
case 'stopped':
this.setUIState('stopped', task_queue_size);
this.stopPolling();
this.showSuccess('Scan stopped');
this.updateConnectionStatus('stopped');
this.loadProviders();
// Clear status request interval
if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}
break;
case 'idle':
this.setUIState('idle', task_queue_size);
this.stopPolling();
this.updateConnectionStatus('idle');
// Clear status request interval
/*if (this.statusRequestInterval) {
clearInterval(this.statusRequestInterval);
this.statusRequestInterval = null;
}*/
break;
default:
@ -824,6 +981,7 @@ class DNSReconApp {
if (this.graphManager) {
this.graphManager.isScanning = true;
}
if (this.elements.startScan) {
this.elements.startScan.disabled = true;
this.elements.startScan.classList.add('loading');
@ -851,6 +1009,7 @@ class DNSReconApp {
if (this.graphManager) {
this.graphManager.isScanning = false;
}
if (this.elements.startScan) {
this.elements.startScan.disabled = !isQueueEmpty;
this.elements.startScan.classList.remove('loading');
@ -1093,7 +1252,7 @@ class DNSReconApp {
} else {
// API key not configured - ALWAYS show input field
const statusClass = info.enabled ? 'enabled' : 'api-key-required';
const statusText = info.enabled ? ' Ready for API Key' : '⚠️ API Key Required';
inputGroup.innerHTML = `
<div class="provider-header">
@ -2033,10 +2192,10 @@ class DNSReconApp {
// If the scanner was idle, it's now running. Start polling to see the new node appear.
if (this.scanStatus === 'idle') {
this.startPolling(1000);
this.socket.emit('get_status');
} else {
// If already scanning, force a quick graph update to see the change sooner.
setTimeout(() => this.updateGraph(), 500);
setTimeout(() => this.socket.emit('get_status'), 500);
}
} else {
@ -2075,8 +2234,8 @@ class DNSReconApp {
*/
getNodeTypeIcon(nodeType) {
const icons = {
'domain': '🌍',
'ip': '📍',
'domain': '🌐',
'ip': '🔢',
'asn': '🏢',
'large_entity': '📦',
'correlation_object': '🔗'

View File

@ -7,6 +7,7 @@
<title>DNSRecon - Infrastructure Reconnaissance</title>
<link rel="stylesheet" href="{{ url_for('static', filename='css/main.css') }}">
<script src="https://cdnjs.cloudflare.com/ajax/libs/vis/4.21.0/vis.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.7.2/socket.io.js"></script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/vis/4.21.0/vis.min.css" rel="stylesheet" type="text/css">
<link
href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@300;400;500;700&family=Special+Elite&display=swap"