diff --git a/app.py b/app.py index 1dd1693..bc1f29d 100644 --- a/app.py +++ b/app.py @@ -5,7 +5,6 @@ Flask application entry point for DNSRecon web interface. Provides REST API endpoints and serves the web interface with user session support. """ -import json import traceback from flask import Flask, render_template, request, jsonify, send_file, session from datetime import datetime, timezone, timedelta @@ -13,6 +12,7 @@ import io import os from core.session_manager import session_manager +from flask_socketio import SocketIO from config import config from core.graph_manager import NodeType from utils.helpers import is_valid_target @@ -21,6 +21,7 @@ from decimal import Decimal app = Flask(__name__) +socketio = SocketIO(app) app.config['SECRET_KEY'] = config.flask_secret_key app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=config.flask_permanent_session_lifetime_hours) @@ -35,7 +36,7 @@ def get_user_scanner(): if existing_scanner: return current_flask_session_id, existing_scanner - new_session_id = session_manager.create_session() + new_session_id = session_manager.create_session(socketio) new_scanner = session_manager.get_session(new_session_id) if not new_scanner: @@ -127,37 +128,31 @@ def stop_scan(): return jsonify({'success': False, 'error': f'Internal server error: {str(e)}'}), 500 -@app.route('/api/scan/status', methods=['GET']) +@socketio.on('get_status') def get_scan_status(): """Get current scan status.""" try: user_session_id, scanner = get_user_scanner() if not scanner: - return jsonify({ - 'success': True, - 'status': { - 'status': 'idle', 'target_domain': None, 'current_depth': 0, - 'max_depth': 0, 'progress_percentage': 0.0, - 'user_session_id': user_session_id - } - }) + status = { + 'status': 'idle', 'target_domain': None, 'current_depth': 0, + 'max_depth': 0, 'progress_percentage': 0.0, + 'user_session_id': user_session_id + } + else: + if not scanner.session_id: + scanner.session_id = user_session_id + status = scanner.get_scan_status() + status['user_session_id'] = user_session_id - if not scanner.session_id: - scanner.session_id = user_session_id - - status = scanner.get_scan_status() - status['user_session_id'] = user_session_id - - return jsonify({'success': True, 'status': status}) + socketio.emit('scan_update', status) except Exception as e: traceback.print_exc() - return jsonify({ - 'success': False, 'error': f'Internal server error: {str(e)}', - 'fallback_status': {'status': 'error', 'progress_percentage': 0.0} - }), 500 - + socketio.emit('scan_update', { + 'status': 'error', 'message': 'Failed to get status' + }) @app.route('/api/graph', methods=['GET']) @@ -542,9 +537,4 @@ def internal_error(error): if __name__ == '__main__': config.load_from_env() - app.run( - host=config.flask_host, - port=config.flask_port, - debug=config.flask_debug, - threaded=True - ) \ No newline at end of file + socketio.run(app, host=config.flask_host, port=config.flask_port, debug=config.flask_debug) \ No newline at end of file diff --git a/core/scanner.py b/core/scanner.py index 728f602..01a1889 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -6,7 +6,6 @@ import os import importlib import redis import time -import math import random # Imported for jitter from typing import List, Set, Dict, Any, Tuple, Optional from concurrent.futures import ThreadPoolExecutor @@ -38,13 +37,14 @@ class Scanner: UNIFIED: Combines comprehensive features with improved display formatting. """ - def __init__(self, session_config=None): + def __init__(self, session_config=None, socketio=None): """Initialize scanner with session-specific configuration.""" try: # Use provided session config or create default if session_config is None: from core.session_config import create_session_config session_config = create_session_config() + self.socketio = socketio self.config = session_config self.graph = GraphManager() @@ -143,7 +143,8 @@ class Scanner: 'rate_limiter', 'logger', 'status_logger_thread', - 'status_logger_stop_event' + 'status_logger_stop_event', + 'socketio' ] for attr in unpicklable_attrs: @@ -170,6 +171,7 @@ class Scanner: self.logger = get_forensic_logger() self.status_logger_thread = None self.status_logger_stop_event = threading.Event() + self.socketio = None if not hasattr(self, 'providers') or not self.providers: self._initialize_providers() @@ -1024,6 +1026,8 @@ class Scanner: """ if self.session_id: try: + if self.socketio: + self.socketio.emit('scan_update', self.get_scan_status()) from core.session_manager import session_manager session_manager.update_session_scanner(self.session_id, self) except Exception: @@ -1048,7 +1052,7 @@ class Scanner: 'progress_percentage': self._calculate_progress(), 'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued, 'enabled_providers': [provider.get_name() for provider in self.providers], - 'graph_statistics': self.graph.get_statistics(), + 'graph': self.get_graph_data(), 'task_queue_size': self.task_queue.qsize(), 'currently_processing_count': currently_processing_count, 'currently_processing': currently_processing_list[:5], diff --git a/core/session_manager.py b/core/session_manager.py index a1d916c..14d0c9a 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -64,7 +64,7 @@ class SessionManager: """Generates the Redis key for a session's stop signal.""" return f"dnsrecon:stop:{session_id}" - def create_session(self) -> str: + def create_session(self, socketio=None) -> str: """ FIXED: Create a new user session with thread-safe creation to prevent duplicates. """ @@ -76,7 +76,7 @@ class SessionManager: try: from core.session_config import create_session_config session_config = create_session_config() - scanner_instance = Scanner(session_config=session_config) + scanner_instance = Scanner(session_config=session_config, socketio=socketio) # Set the session ID on the scanner for cross-process stop signal management scanner_instance.session_id = session_id diff --git a/providers/base_provider.py b/providers/base_provider.py index d326def..61a3bbe 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -53,7 +53,7 @@ class BaseProvider(ABC): def __getstate__(self): """Prepare BaseProvider for pickling by excluding unpicklable objects.""" state = self.__dict__.copy() - # Exclude the unpickleable '_local' attribute and stop event + # Exclude the unpickleable '_local' attribute (which holds the session) and stop event unpicklable_attrs = ['_local', '_stop_event'] for attr in unpicklable_attrs: if attr in state: diff --git a/providers/correlation_provider.py b/providers/correlation_provider.py index 35bb3d4..64ed854 100644 --- a/providers/correlation_provider.py +++ b/providers/correlation_provider.py @@ -26,6 +26,7 @@ class CorrelationProvider(BaseProvider): 'cert_common_name', 'cert_validity_period_days', 'cert_issuer_name', + 'cert_serial_number', 'cert_entry_timestamp', 'cert_not_before', 'cert_not_after', diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index c946c74..dea89e2 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -2,38 +2,17 @@ import json import re -import psycopg2 from pathlib import Path from typing import List, Dict, Any, Set, Optional from urllib.parse import quote from datetime import datetime, timezone import requests -from psycopg2 import pool from .base_provider import BaseProvider from core.provider_result import ProviderResult from utils.helpers import _is_valid_domain from core.logger import get_forensic_logger -# --- Global Instance for PostgreSQL Connection Pool --- -# This pool will be created once per worker process and is not part of the -# CrtShProvider instance, thus avoiding pickling errors. -db_pool = None -try: - db_pool = psycopg2.pool.SimpleConnectionPool( - 1, 5, - host='crt.sh', - port=5432, - user='guest', - dbname='certwatch', - sslmode='prefer', - connect_timeout=60 - ) - # Use a generic logger here as this is at the module level - get_forensic_logger().logger.info("crt.sh: Global PostgreSQL connection pool created successfully.") -except Exception as e: - get_forensic_logger().logger.warning(f"crt.sh: Failed to create global DB connection pool: {e}. Will fall back to HTTP API.") - class CrtShProvider(BaseProvider): """ @@ -136,51 +115,42 @@ class CrtShProvider(BaseProvider): result = ProviderResult() - try: - if cache_status == "fresh": - result = self._load_from_cache(cache_file) - self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}") + if cache_status == "fresh": + result = self._load_from_cache(cache_file) + self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}") + + else: # "stale" or "not_found" + # Query the API for the latest certificates + new_raw_certs = self._query_crtsh_api(domain) - else: # "stale" or "not_found" - # Query the API for the latest certificates - new_raw_certs = self._query_crtsh(domain) + if self._stop_event and self._stop_event.is_set(): + return ProviderResult() + + # Combine with old data if cache is stale + if cache_status == "stale": + old_raw_certs = self._load_raw_data_from_cache(cache_file) + combined_certs = old_raw_certs + new_raw_certs - if self._stop_event and self._stop_event.is_set(): - return ProviderResult() - - # Combine with old data if cache is stale - if cache_status == "stale": - old_raw_certs = self._load_raw_data_from_cache(cache_file) - combined_certs = old_raw_certs + new_raw_certs - - # Deduplicate the combined list - seen_ids = set() - unique_certs = [] - for cert in combined_certs: - cert_id = cert.get('id') - if cert_id not in seen_ids: - unique_certs.append(cert) - seen_ids.add(cert_id) - - raw_certificates_to_process = unique_certs - self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}") - else: # "not_found" - raw_certificates_to_process = new_raw_certs + # Deduplicate the combined list + seen_ids = set() + unique_certs = [] + for cert in combined_certs: + cert_id = cert.get('id') + if cert_id not in seen_ids: + unique_certs.append(cert) + seen_ids.add(cert_id) - # FIXED: Process certificates to create proper domain and CA nodes - result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process) - self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)") + raw_certificates_to_process = unique_certs + self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}") + else: # "not_found" + raw_certificates_to_process = new_raw_certs + + # FIXED: Process certificates to create proper domain and CA nodes + result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process) + self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)") - # Save the new result and the raw data to the cache - self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain) - - except (requests.exceptions.RequestException, psycopg2.Error) as e: - self.logger.logger.error(f"Upstream query failed for {domain}: {e}") - if cache_status != "not_found": - result = self._load_from_cache(cache_file) - self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.") - else: - raise e # Re-raise if there's no cache to fall back on + # Save the new result and the raw data to the cache + self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain) return result @@ -277,58 +247,6 @@ class CrtShProvider(BaseProvider): json.dump(cache_data, f, separators=(',', ':'), default=str) except Exception as e: self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}") - - def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]: - """Query crt.sh, trying the database first and falling back to the API.""" - global db_pool - if db_pool: - try: - self.logger.logger.info(f"crt.sh: Attempting DB query for {domain}") - return self._query_crtsh_db(domain) - except psycopg2.Error as e: - self.logger.logger.warning(f"crt.sh: DB query failed for {domain}: {e}. Falling back to HTTP API.") - return self._query_crtsh_api(domain) - else: - self.logger.logger.info(f"crt.sh: No DB connection pool. Using HTTP API for {domain}") - return self._query_crtsh_api(domain) - - def _query_crtsh_db(self, domain: str) -> List[Dict[str, Any]]: - """Query crt.sh database for raw certificate data.""" - global db_pool - conn = db_pool.getconn() - try: - with conn.cursor() as cursor: - query = """ - SELECT - c.id, - x509_serialnumber(c.certificate) as serial_number, - x509_notbefore(c.certificate) as not_before, - x509_notafter(c.certificate) as not_after, - c.issuer_ca_id, - ca.name as issuer_name, - x509_commonname(c.certificate) as common_name, - identities(c.certificate)::text as name_value - FROM certificate c - LEFT JOIN ca ON c.issuer_ca_id = ca.id - WHERE identities(c.certificate) @@ plainto_tsquery(%s) - ORDER BY c.id DESC - LIMIT 5000; - """ - cursor.execute(query, (domain,)) - - results = [] - columns = [desc[0] for desc in cursor.description] - for row in cursor.fetchall(): - row_dict = dict(zip(columns, row)) - if row_dict.get('not_before'): - row_dict['not_before'] = row_dict['not_before'].isoformat() - if row_dict.get('not_after'): - row_dict['not_after'] = row_dict['not_after'].isoformat() - results.append(row_dict) - self.logger.logger.info(f"crt.sh: DB query for {domain} returned {len(results)} records.") - return results - finally: - db_pool.putconn(conn) def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]: """Query crt.sh API for raw certificate data.""" diff --git a/providers/dns_provider.py b/providers/dns_provider.py index 3aef192..12a91ff 100644 --- a/providers/dns_provider.py +++ b/providers/dns_provider.py @@ -27,6 +27,21 @@ class DNSProvider(BaseProvider): self.resolver.timeout = 5 self.resolver.lifetime = 10 + def __getstate__(self): + """Prepare the object for pickling.""" + state = self.__dict__.copy() + # Remove the unpickleable 'resolver' attribute + if 'resolver' in state: + del state['resolver'] + return state + + def __setstate__(self, state): + """Restore the object after unpickling.""" + self.__dict__.update(state) + # Re-initialize the 'resolver' attribute + self.resolver = resolver.Resolver() + self.resolver.timeout = 5 + def get_name(self) -> str: """Return the provider name.""" return "dns" diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index 05dfc6c..26435b8 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -36,6 +36,15 @@ class ShodanProvider(BaseProvider): self.cache_dir = Path('cache') / 'shodan' self.cache_dir.mkdir(parents=True, exist_ok=True) + def __getstate__(self): + """Prepare the object for pickling.""" + state = super().__getstate__() + return state + + def __setstate__(self, state): + """Restore the object after unpickling.""" + super().__setstate__(state) + def _check_api_connection(self) -> bool: """ FIXED: Lazy connection checking - only test when actually needed. diff --git a/requirements.txt b/requirements.txt index 4ec5adb..964baa5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,6 @@ dnspython gunicorn redis python-dotenv -psycopg2-binary \ No newline at end of file +psycopg2-binary +Flask-SocketIO +eventlet \ No newline at end of file diff --git a/static/js/main.js b/static/js/main.js index cc0c148..f8ee132 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -8,8 +8,8 @@ class DNSReconApp { constructor() { console.log('DNSReconApp constructor called'); this.graphManager = null; + this.socket = null; this.scanStatus = 'idle'; - this.pollInterval = null; this.currentSessionId = null; this.elements = {}; @@ -31,13 +31,11 @@ class DNSReconApp { this.initializeElements(); this.setupEventHandlers(); this.initializeGraph(); - this.updateStatus(); + this.initializeSocket(); this.loadProviders(); this.initializeEnhancedModals(); this.addCheckboxStyling(); - this.updateGraph(); - console.log('DNSRecon application initialized successfully'); } catch (error) { console.error('Failed to initialize DNSRecon application:', error); @@ -45,6 +43,25 @@ class DNSReconApp { } }); } + + initializeSocket() { + this.socket = io(); + + this.socket.on('connect', () => { + console.log('Connected to WebSocket server'); + this.updateConnectionStatus('idle'); + this.socket.emit('get_status'); + }); + + this.socket.on('scan_update', (data) => { + if (data.status !== this.scanStatus) { + this.handleStatusChange(data.status, data.task_queue_size); + } + this.scanStatus = data.status; + this.updateStatusDisplay(data); + this.graphManager.updateGraph(data.graph); + }); + } /** * Initialize DOM element references @@ -328,15 +345,8 @@ class DNSReconApp { console.log(`Scan started for ${target} with depth ${maxDepth}`); - // Start polling immediately with faster interval for responsiveness - this.startPolling(1000); - - // Force an immediate status update - console.log('Forcing immediate status update...'); - setTimeout(() => { - this.updateStatus(); - this.updateGraph(); - }, 100); + // Request initial status update via WebSocket + this.socket.emit('get_status'); } else { throw new Error(response.error || 'Failed to start scan'); @@ -368,22 +378,6 @@ class DNSReconApp { if (response.success) { this.showSuccess('Scan stop requested'); - - // Force immediate status update - setTimeout(() => { - this.updateStatus(); - }, 100); - - // Continue polling for a bit to catch the status change - this.startPolling(500); // Fast polling to catch status change - - // Stop fast polling after 10 seconds - setTimeout(() => { - if (this.scanStatus === 'stopped' || this.scanStatus === 'idle') { - this.stopPolling(); - } - }, 10000); - } else { throw new Error(response.error || 'Failed to stop scan'); } @@ -548,68 +542,6 @@ class DNSReconApp { } } - /** - * Start polling for scan updates with configurable interval - */ - startPolling(interval = 2000) { - console.log('=== STARTING POLLING ==='); - - if (this.pollInterval) { - console.log('Clearing existing poll interval'); - clearInterval(this.pollInterval); - } - - this.pollInterval = setInterval(() => { - this.updateStatus(); - this.updateGraph(); - this.loadProviders(); - }, interval); - - console.log(`Polling started with ${interval}ms interval`); - } - - /** - * Stop polling for updates - */ - stopPolling() { - console.log('=== STOPPING POLLING ==='); - if (this.pollInterval) { - clearInterval(this.pollInterval); - this.pollInterval = null; - } - } - - /** - * Status update with better error handling - */ - async updateStatus() { - try { - const response = await this.apiCall('/api/scan/status'); - - - if (response.success && response.status) { - const status = response.status; - - this.updateStatusDisplay(status); - - // Handle status changes - if (status.status !== this.scanStatus) { - console.log(`*** STATUS CHANGED: ${this.scanStatus} -> ${status.status} ***`); - this.handleStatusChange(status.status, status.task_queue_size); - } - - this.scanStatus = status.status; - } else { - console.error('Status update failed:', response); - // Don't show error for status updates to avoid spam - } - - } catch (error) { - console.error('Failed to update status:', error); - this.showConnectionError(); - } - } - /** * Update graph from server */ @@ -737,25 +669,20 @@ class DNSReconApp { case 'running': this.setUIState('scanning', task_queue_size); this.showSuccess('Scan is running'); - // Increase polling frequency for active scans - this.startPolling(1000); // Poll every 1 second for running scans this.updateConnectionStatus('active'); break; case 'completed': this.setUIState('completed', task_queue_size); - this.stopPolling(); this.showSuccess('Scan completed successfully'); this.updateConnectionStatus('completed'); this.loadProviders(); // Force a final graph update console.log('Scan completed - forcing final graph update'); - setTimeout(() => this.updateGraph(), 100); break; case 'failed': this.setUIState('failed', task_queue_size); - this.stopPolling(); this.showError('Scan failed'); this.updateConnectionStatus('error'); this.loadProviders(); @@ -763,7 +690,6 @@ class DNSReconApp { case 'stopped': this.setUIState('stopped', task_queue_size); - this.stopPolling(); this.showSuccess('Scan stopped'); this.updateConnectionStatus('stopped'); this.loadProviders(); @@ -771,7 +697,6 @@ class DNSReconApp { case 'idle': this.setUIState('idle', task_queue_size); - this.stopPolling(); this.updateConnectionStatus('idle'); break; @@ -2033,10 +1958,10 @@ class DNSReconApp { // If the scanner was idle, it's now running. Start polling to see the new node appear. if (this.scanStatus === 'idle') { - this.startPolling(1000); + this.socket.emit('get_status'); } else { // If already scanning, force a quick graph update to see the change sooner. - setTimeout(() => this.updateGraph(), 500); + setTimeout(() => this.socket.emit('get_status'), 500); } } else { diff --git a/templates/index.html b/templates/index.html index 776dc6a..bedeb5f 100644 --- a/templates/index.html +++ b/templates/index.html @@ -7,6 +7,7 @@ DNSRecon - Infrastructure Reconnaissance +