try to implement websockets

overcuriousity 2025-09-20 14:17:17 +02:00
parent 3ee23c9d05
commit 75a595c9cb
11 changed files with 116 additions and 251 deletions

app.py (48 changes)
View File

@@ -5,7 +5,6 @@ Flask application entry point for DNSRecon web interface.
Provides REST API endpoints and serves the web interface with user session support.
"""
import json
-import traceback
from flask import Flask, render_template, request, jsonify, send_file, session
from datetime import datetime, timezone, timedelta
@@ -13,6 +12,7 @@ import io
import os
from core.session_manager import session_manager
+from flask_socketio import SocketIO
from config import config
from core.graph_manager import NodeType
from utils.helpers import is_valid_target
@@ -21,6 +21,7 @@ from decimal import Decimal
app = Flask(__name__)
+socketio = SocketIO(app)
app.config['SECRET_KEY'] = config.flask_secret_key
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours)
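For orientation, a minimal sketch of the Flask-SocketIO wiring this hunk introduces. The async-mode remark is an assumption based on requirements.txt adding eventlet, and the connect handler is hypothetical:

from flask import Flask
from flask_socketio import SocketIO

app = Flask(__name__)
app.config['SECRET_KEY'] = 'change-me'  # stand-in; the real app reads config.flask_secret_key
socketio = SocketIO(app)  # with eventlet installed, Flask-SocketIO selects it automatically

@socketio.on('connect')
def handle_connect():
    # Hypothetical handler: greet the newly connected client with an idle status.
    socketio.emit('scan_update', {'status': 'idle'})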
@@ -35,7 +36,7 @@ def get_user_scanner():
    if existing_scanner:
        return current_flask_session_id, existing_scanner
-    new_session_id = session_manager.create_session()
+    new_session_id = session_manager.create_session(socketio)
    new_scanner = session_manager.get_session(new_session_id)
    if not new_scanner:
@@ -127,37 +128,31 @@ def stop_scan():
        return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500
-@app.route('/api/scan/status', methods=['GET'])
+@socketio.on('get_status')
def get_scan_status():
    """Get current scan status."""
    try:
        user_session_id, scanner = get_user_scanner()
        if not scanner:
-            return jsonify({
-                'success': True,
-                'status': {
-                    'status': 'idle', 'target_domain': None, 'current_depth': 0,
-                    'max_depth': 0, 'progress_percentage': 0.0,
-                    'user_session_id': user_session_id
-                }
-            })
+            status = {
+                'status': 'idle', 'target_domain': None, 'current_depth': 0,
+                'max_depth': 0, 'progress_percentage': 0.0,
+                'user_session_id': user_session_id
+            }
+        else:
+            if not scanner.session_id:
+                scanner.session_id = user_session_id
+            status = scanner.get_scan_status()
+            status['user_session_id'] = user_session_id
-        if not scanner.session_id:
-            scanner.session_id = user_session_id
-        status = scanner.get_scan_status()
-        status['user_session_id'] = user_session_id
-        return jsonify({'success': True, 'status': status})
+        socketio.emit('scan_update', status)
    except Exception as e:
-        traceback.print_exc()
-        return jsonify({
-            'success': False, 'error': f'Internal server error: {str(e)}',
-            'fallback_status': {'status': 'error', 'progress_percentage': 0.0}
-        }), 500
+        socketio.emit('scan_update', {
+            'status': 'error', 'message': 'Failed to get status'
+        })
@app.route('/api/graph', methods=['GET'])
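One subtlety in the rewritten handler: socketio.emit() on the server object broadcasts to every connected client, so any user's get_status pushes a scan_update to all sessions. A hedged sketch of the per-client alternative, using Flask-SocketIO's context-aware emit:

from flask import request
from flask_socketio import emit

@socketio.on('get_status')
def get_scan_status_reply():
    status = {'status': 'idle', 'progress_percentage': 0.0}  # placeholder payload
    emit('scan_update', status)  # replies only to the client that asked
    # socketio.emit('scan_update', status)                  # what the diff does: broadcast to all
    # socketio.emit('scan_update', status, to=request.sid)  # explicit single-client form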
@@ -542,9 +537,4 @@ def internal_error(error):
if __name__ == '__main__':
    config.load_from_env()
-    app.run(
-        host=config.flask_host,
-        port=config.flask_port,
-        debug=config.flask_debug,
-        threaded=True
-    )
+    socketio.run(app, host=config.flask_host, port=config.flask_port, debug=config.flask_debug)
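socketio.run() replaces app.run() because the WebSocket transport needs an async-capable server; with eventlet installed it starts an eventlet WSGI server rather than Flask's threaded development server. A hedged production sketch, since gunicorn is already a dependency; the single worker is needed because Socket.IO sessions are stateful:

# gunicorn --worker-class eventlet -w 1 app:app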

View File

@@ -6,7 +6,6 @@ import os
import importlib
import redis
import time
-import math
import random # Imported for jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor
@@ -38,13 +37,14 @@ class Scanner:
    UNIFIED: Combines comprehensive features with improved display formatting.
    """
-    def __init__(self, session_config=None):
+    def __init__(self, session_config=None, socketio=None):
        """Initialize scanner with session-specific configuration."""
        try:
            # Use provided session config or create default
            if session_config is None:
                from core.session_config import create_session_config
                session_config = create_session_config()
+            self.socketio = socketio
            self.config = session_config
            self.graph = GraphManager()
@@ -143,7 +143,8 @@ class Scanner:
            'rate_limiter',
            'logger',
            'status_logger_thread',
-            'status_logger_stop_event'
+            'status_logger_stop_event',
+            'socketio'
        ]
        for attr in unpicklable_attrs:
@@ -170,6 +171,7 @@ class Scanner:
        self.logger = get_forensic_logger()
        self.status_logger_thread = None
        self.status_logger_stop_event = threading.Event()
+        self.socketio = None
        if not hasattr(self, 'providers') or not self.providers:
            self._initialize_providers()
@@ -1024,6 +1026,8 @@ class Scanner:
        """
        if self.session_id:
            try:
+                if self.socketio:
+                    self.socketio.emit('scan_update', self.get_scan_status())
                from core.session_manager import session_manager
                session_manager.update_session_scanner(self.session_id, self)
            except Exception:
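The emit above runs on the scanner's own worker thread. That can work, but under the eventlet async mode the safer pattern is a background task spawned by Flask-SocketIO itself, with its cooperative sleep. A sketch under that assumption; _status_pusher and its arguments are hypothetical names:

def _status_pusher(socketio, scanner, stop_event):
    # Push status snapshots until the scan is told to stop.
    while not stop_event.is_set():
        socketio.emit('scan_update', scanner.get_scan_status())
        socketio.sleep(1)  # cooperative sleep; time.sleep() would block eventlet's loop

# launched with: socketio.start_background_task(_status_pusher, socketio, scanner, stop_event)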
@@ -1048,7 +1052,7 @@ class Scanner:
            'progress_percentage': self._calculate_progress(),
            'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
            'enabled_providers': [provider.get_name() for provider in self.providers],
-            'graph_statistics': self.graph.get_statistics(),
+            'graph': self.get_graph_data(),
            'task_queue_size': self.task_queue.qsize(),
            'currently_processing_count': currently_processing_count,
            'currently_processing': currently_processing_list[:5],
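Swapping graph_statistics for the full get_graph_data() payload means every scan_update now carries the complete node and edge set. A hedged sketch of one way to cap that traffic; the _last_emitted_graph attribute is an assumption, not in the diff:

status = self.get_scan_status()
graph = status['graph']
if graph == getattr(self, '_last_emitted_graph', None):
    del status['graph']  # unchanged graph: omit it and let the client keep its current view
else:
    self._last_emitted_graph = graph
self.socketio.emit('scan_update', status)

The client's scan_update handler would then need a matching if (data.graph) guard before calling updateGraph.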

View File

@@ -64,7 +64,7 @@ class SessionManager:
        """Generates the Redis key for a session's stop signal."""
        return f"dnsrecon:stop:{session_id}"
-    def create_session(self) -> str:
+    def create_session(self, socketio=None) -> str:
        """
        FIXED: Create a new user session with thread-safe creation to prevent duplicates.
        """
@@ -76,7 +76,7 @@ class SessionManager:
        try:
            from core.session_config import create_session_config
            session_config = create_session_config()
-            scanner_instance = Scanner(session_config=session_config)
+            scanner_instance = Scanner(session_config=session_config, socketio=socketio)
            # Set the session ID on the scanner for cross-process stop signal management
            scanner_instance.session_id = session_id
View File

@@ -53,7 +53,7 @@ class BaseProvider(ABC):
    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
-        # Exclude the unpickleable '_local' attribute and stop event
+        # Exclude the unpickleable '_local' attribute (which holds the session) and stop event
        unpicklable_attrs = ['_local', '_stop_event']
        for attr in unpicklable_attrs:
            if attr in state:
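The same exclude-then-rebuild pattern recurs across Scanner and the providers in this commit. A condensed sketch of the contract, with an illustrative class name:

import threading

class PicklableWorker:
    UNPICKLABLE = ('_local', '_stop_event')

    def __getstate__(self):
        state = self.__dict__.copy()
        for attr in self.UNPICKLABLE:
            state.pop(attr, None)  # drop live handles before the object goes into Redis
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._local = threading.local()       # rebuilt empty; the old session is gone by design
        self._stop_event = threading.Event()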

View File

@@ -26,6 +26,7 @@ class CorrelationProvider(BaseProvider):
            'cert_common_name',
            'cert_validity_period_days',
            'cert_issuer_name',
+            'cert_serial_number',
            'cert_entry_timestamp',
            'cert_not_before',
            'cert_not_after',

View File

@@ -2,38 +2,17 @@
import json
import re
-import psycopg2
from pathlib import Path
from typing import List, Dict, Any, Set, Optional
from urllib.parse import quote
from datetime import datetime, timezone
import requests
-from psycopg2 import pool
from .base_provider import BaseProvider
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_domain
-from core.logger import get_forensic_logger
-# --- Global Instance for PostgreSQL Connection Pool ---
-# This pool will be created once per worker process and is not part of the
-# CrtShProvider instance, thus avoiding pickling errors.
-db_pool = None
-try:
-    db_pool = psycopg2.pool.SimpleConnectionPool(
-        1, 5,
-        host='crt.sh',
-        port=5432,
-        user='guest',
-        dbname='certwatch',
-        sslmode='prefer',
-        connect_timeout=60
-    )
-    # Use a generic logger here as this is at the module level
-    get_forensic_logger().logger.info("crt.sh: Global PostgreSQL connection pool created successfully.")
-except Exception as e:
-    get_forensic_logger().logger.warning(f"crt.sh: Failed to create global DB connection pool: {e}. Will fall back to HTTP API.")
class CrtShProvider(BaseProvider):
"""
@@ -136,51 +115,42 @@ class CrtShProvider(BaseProvider):
        result = ProviderResult()
-        try:
-            if cache_status == "fresh":
-                result = self._load_from_cache(cache_file)
-                self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
+        if cache_status == "fresh":
+            result = self._load_from_cache(cache_file)
+            self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
-            else: # "stale" or "not_found"
-                # Query the API for the latest certificates
-                new_raw_certs = self._query_crtsh(domain)
+        else: # "stale" or "not_found"
+            # Query the API for the latest certificates
+            new_raw_certs = self._query_crtsh_api(domain)
-                if self._stop_event and self._stop_event.is_set():
-                    return ProviderResult()
+            if self._stop_event and self._stop_event.is_set():
+                return ProviderResult()
-                # Combine with old data if cache is stale
-                if cache_status == "stale":
-                    old_raw_certs = self._load_raw_data_from_cache(cache_file)
-                    combined_certs = old_raw_certs + new_raw_certs
+            # Combine with old data if cache is stale
+            if cache_status == "stale":
+                old_raw_certs = self._load_raw_data_from_cache(cache_file)
+                combined_certs = old_raw_certs + new_raw_certs
-                    # Deduplicate the combined list
-                    seen_ids = set()
-                    unique_certs = []
-                    for cert in combined_certs:
-                        cert_id = cert.get('id')
-                        if cert_id not in seen_ids:
-                            unique_certs.append(cert)
-                            seen_ids.add(cert_id)
+                # Deduplicate the combined list
+                seen_ids = set()
+                unique_certs = []
+                for cert in combined_certs:
+                    cert_id = cert.get('id')
+                    if cert_id not in seen_ids:
+                        unique_certs.append(cert)
+                        seen_ids.add(cert_id)
-                    raw_certificates_to_process = unique_certs
-                    self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
-                else: # "not_found"
-                    raw_certificates_to_process = new_raw_certs
+                raw_certificates_to_process = unique_certs
+                self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
+            else: # "not_found"
+                raw_certificates_to_process = new_raw_certs
-                # FIXED: Process certificates to create proper domain and CA nodes
-                result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
-                self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
+            # FIXED: Process certificates to create proper domain and CA nodes
+            result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
+            self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
-                # Save the new result and the raw data to the cache
-                self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)
-        except (requests.exceptions.RequestException, psycopg2.Error) as e:
-            self.logger.logger.error(f"Upstream query failed for {domain}: {e}")
-            if cache_status != "not_found":
-                result = self._load_from_cache(cache_file)
-                self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.")
-            else:
-                raise e # Re-raise if there's no cache to fall back on
+            # Save the new result and the raw data to the cache
+            self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)
        return result
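The removed except block also carried the stale-cache fallback, and no replacement handler is visible in this hunk. If that resilience is still wanted, a hedged sketch scoped to HTTP errors only, now that psycopg2 is gone:

try:
    new_raw_certs = self._query_crtsh_api(domain)
except requests.exceptions.RequestException as e:
    self.logger.logger.error(f"Upstream query failed for {domain}: {e}")
    if cache_status != "not_found":
        return self._load_from_cache(cache_file)  # fall back to the stale cache
    raise  # nothing cached to fall back on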
@@ -278,58 +248,6 @@ class CrtShProvider(BaseProvider):
        except Exception as e:
            self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}")
-    def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
-        """Query crt.sh, trying the database first and falling back to the API."""
-        global db_pool
-        if db_pool:
-            try:
-                self.logger.logger.info(f"crt.sh: Attempting DB query for {domain}")
-                return self._query_crtsh_db(domain)
-            except psycopg2.Error as e:
-                self.logger.logger.warning(f"crt.sh: DB query failed for {domain}: {e}. Falling back to HTTP API.")
-                return self._query_crtsh_api(domain)
-        else:
-            self.logger.logger.info(f"crt.sh: No DB connection pool. Using HTTP API for {domain}")
-            return self._query_crtsh_api(domain)
-    def _query_crtsh_db(self, domain: str) -> List[Dict[str, Any]]:
-        """Query crt.sh database for raw certificate data."""
-        global db_pool
-        conn = db_pool.getconn()
-        try:
-            with conn.cursor() as cursor:
-                query = """
-                    SELECT
-                        c.id,
-                        x509_serialnumber(c.certificate) as serial_number,
-                        x509_notbefore(c.certificate) as not_before,
-                        x509_notafter(c.certificate) as not_after,
-                        c.issuer_ca_id,
-                        ca.name as issuer_name,
-                        x509_commonname(c.certificate) as common_name,
-                        identities(c.certificate)::text as name_value
-                    FROM certificate c
-                    LEFT JOIN ca ON c.issuer_ca_id = ca.id
-                    WHERE identities(c.certificate) @@ plainto_tsquery(%s)
-                    ORDER BY c.id DESC
-                    LIMIT 5000;
-                """
-                cursor.execute(query, (domain,))
-                results = []
-                columns = [desc[0] for desc in cursor.description]
-                for row in cursor.fetchall():
-                    row_dict = dict(zip(columns, row))
-                    if row_dict.get('not_before'):
-                        row_dict['not_before'] = row_dict['not_before'].isoformat()
-                    if row_dict.get('not_after'):
-                        row_dict['not_after'] = row_dict['not_after'].isoformat()
-                    results.append(row_dict)
-                self.logger.logger.info(f"crt.sh: DB query for {domain} returned {len(results)} records.")
-                return results
-        finally:
-            db_pool.putconn(conn)
    def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]:
        """Query crt.sh API for raw certificate data."""
        url = f"{self.base_url}?q={quote(domain)}&output=json"

View File

@@ -27,6 +27,21 @@ class DNSProvider(BaseProvider):
        self.resolver.timeout = 5
        self.resolver.lifetime = 10
+    def __getstate__(self):
+        """Prepare the object for pickling."""
+        state = self.__dict__.copy()
+        # Remove the unpickleable 'resolver' attribute
+        if 'resolver' in state:
+            del state['resolver']
+        return state
+    def __setstate__(self, state):
+        """Restore the object after unpickling."""
+        self.__dict__.update(state)
+        # Re-initialize the 'resolver' attribute
+        self.resolver = resolver.Resolver()
+        self.resolver.timeout = 5
    def get_name(self) -> str:
        """Return the provider name."""
        return "dns"

View File

@@ -36,6 +36,15 @@ class ShodanProvider(BaseProvider):
        self.cache_dir = Path('cache') / 'shodan'
        self.cache_dir.mkdir(parents=True, exist_ok=True)
+    def __getstate__(self):
+        """Prepare the object for pickling."""
+        state = super().__getstate__()
+        return state
+    def __setstate__(self, state):
+        """Restore the object after unpickling."""
+        super().__setstate__(state)
    def _check_api_connection(self) -> bool:
        """
        FIXED: Lazy connection checking - only test when actually needed.

View File

@@ -9,3 +9,5 @@ gunicorn
redis
python-dotenv
psycopg2-binary
+Flask-SocketIO
+eventlet
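Pairing eventlet with the standard-library threading the scanner relies on (ThreadPoolExecutor, threading.Event) usually calls for monkey-patching before anything else is imported; whether this commit does that elsewhere is not visible here. The conventional sketch:

import eventlet
eventlet.monkey_patch()  # must run first, before Flask or the scanner modules are imported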

View File

@@ -8,8 +8,8 @@ class DNSReconApp {
    constructor() {
        console.log('DNSReconApp constructor called');
        this.graphManager = null;
+        this.socket = null;
        this.scanStatus = 'idle';
-        this.pollInterval = null;
        this.currentSessionId = null;
        this.elements = {};
@@ -31,13 +31,11 @@ class DNSReconApp {
            this.initializeElements();
            this.setupEventHandlers();
            this.initializeGraph();
-            this.updateStatus();
+            this.initializeSocket();
            this.loadProviders();
            this.initializeEnhancedModals();
            this.addCheckboxStyling();
-            this.updateGraph();
            console.log('DNSRecon application initialized successfully');
        } catch (error) {
            console.error('Failed to initialize DNSRecon application:', error);
@@ -46,6 +44,25 @@ class DNSReconApp {
        });
    }
+    initializeSocket() {
+        this.socket = io();
+        this.socket.on('connect', () => {
+            console.log('Connected to WebSocket server');
+            this.updateConnectionStatus('idle');
+            this.socket.emit('get_status');
+        });
+        this.socket.on('scan_update', (data) => {
+            if (data.status !== this.scanStatus) {
+                this.handleStatusChange(data.status, data.task_queue_size);
+            }
+            this.scanStatus = data.status;
+            this.updateStatusDisplay(data);
+            this.graphManager.updateGraph(data.graph);
+        });
+    }
/**
* Initialize DOM element references
*/
@@ -328,15 +345,8 @@ class DNSReconApp {
                console.log(`Scan started for ${target} with depth ${maxDepth}`);
-                // Start polling immediately with faster interval for responsiveness
-                this.startPolling(1000);
-                // Force an immediate status update
-                console.log('Forcing immediate status update...');
-                setTimeout(() => {
-                    this.updateStatus();
-                    this.updateGraph();
-                }, 100);
+                // Request initial status update via WebSocket
+                this.socket.emit('get_status');
            } else {
                throw new Error(response.error || 'Failed to start scan');
@@ -368,22 +378,6 @@ class DNSReconApp {
            if (response.success) {
                this.showSuccess('Scan stop requested');
-                // Force immediate status update
-                setTimeout(() => {
-                    this.updateStatus();
-                }, 100);
-                // Continue polling for a bit to catch the status change
-                this.startPolling(500); // Fast polling to catch status change
-                // Stop fast polling after 10 seconds
-                setTimeout(() => {
-                    if (this.scanStatus === 'stopped' || this.scanStatus === 'idle') {
-                        this.stopPolling();
-                    }
-                }, 10000);
            } else {
                throw new Error(response.error || 'Failed to stop scan');
            }
@@ -548,68 +542,6 @@ class DNSReconApp {
        }
    }
-    /**
-     * Start polling for scan updates with configurable interval
-     */
-    startPolling(interval = 2000) {
-        console.log('=== STARTING POLLING ===');
-        if (this.pollInterval) {
-            console.log('Clearing existing poll interval');
-            clearInterval(this.pollInterval);
-        }
-        this.pollInterval = setInterval(() => {
-            this.updateStatus();
-            this.updateGraph();
-            this.loadProviders();
-        }, interval);
-        console.log(`Polling started with ${interval}ms interval`);
-    }
-    /**
-     * Stop polling for updates
-     */
-    stopPolling() {
-        console.log('=== STOPPING POLLING ===');
-        if (this.pollInterval) {
-            clearInterval(this.pollInterval);
-            this.pollInterval = null;
-        }
-    }
-    /**
-     * Status update with better error handling
-     */
-    async updateStatus() {
-        try {
-            const response = await this.apiCall('/api/scan/status');
-            if (response.success && response.status) {
-                const status = response.status;
-                this.updateStatusDisplay(status);
-                // Handle status changes
-                if (status.status !== this.scanStatus) {
-                    console.log(`*** STATUS CHANGED: ${this.scanStatus} -> ${status.status} ***`);
-                    this.handleStatusChange(status.status, status.task_queue_size);
-                }
-                this.scanStatus = status.status;
-            } else {
-                console.error('Status update failed:', response);
-                // Don't show error for status updates to avoid spam
-            }
-        } catch (error) {
-            console.error('Failed to update status:', error);
-            this.showConnectionError();
-        }
-    }
    /**
     * Update graph from server
     */
@@ -737,25 +669,20 @@ class DNSReconApp {
            case 'running':
                this.setUIState('scanning', task_queue_size);
                this.showSuccess('Scan is running');
-                // Increase polling frequency for active scans
-                this.startPolling(1000); // Poll every 1 second for running scans
                this.updateConnectionStatus('active');
                break;
            case 'completed':
                this.setUIState('completed', task_queue_size);
-                this.stopPolling();
                this.showSuccess('Scan completed successfully');
                this.updateConnectionStatus('completed');
                this.loadProviders();
-                // Force a final graph update
-                console.log('Scan completed - forcing final graph update');
-                setTimeout(() => this.updateGraph(), 100);
                break;
            case 'failed':
                this.setUIState('failed', task_queue_size);
-                this.stopPolling();
                this.showError('Scan failed');
                this.updateConnectionStatus('error');
                this.loadProviders();
@@ -763,7 +690,6 @@ class DNSReconApp {
            case 'stopped':
                this.setUIState('stopped', task_queue_size);
-                this.stopPolling();
                this.showSuccess('Scan stopped');
                this.updateConnectionStatus('stopped');
                this.loadProviders();
@@ -771,7 +697,6 @@ class DNSReconApp {
            case 'idle':
                this.setUIState('idle', task_queue_size);
-                this.stopPolling();
                this.updateConnectionStatus('idle');
                break;
@@ -2033,10 +1958,10 @@ class DNSReconApp {
                // If the scanner was idle, it's now running. Start polling to see the new node appear.
                if (this.scanStatus === 'idle') {
-                    this.startPolling(1000);
+                    this.socket.emit('get_status');
                } else {
                    // If already scanning, force a quick graph update to see the change sooner.
-                    setTimeout(() => this.updateGraph(), 500);
+                    setTimeout(() => this.socket.emit('get_status'), 500);
                }
            } else {

View File

@@ -7,6 +7,7 @@
    <title>DNSRecon - Infrastructure Reconnaissance</title>
    <link rel="stylesheet" href="{{ url_for('static', filename='css/main.css') }}">
    <script src="https://cdnjs.cloudflare.com/ajax/libs/vis/4.21.0/vis.min.js"></script>
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.7.2/socket.io.js"></script>
    <link href="https://cdnjs.cloudflare.com/ajax/libs/vis/4.21.0/vis.min.css" rel="stylesheet" type="text/css">
    <link
        href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@300;400;500;700&family=Special+Elite&display=swap"