diff --git a/app.py b/app.py index b987244..af22847 100644 --- a/app.py +++ b/app.py @@ -1,20 +1,65 @@ """ Flask application entry point for DNSRecon web interface. -Provides REST API endpoints and serves the web interface. +Provides REST API endpoints and serves the web interface with user session support. +Enhanced with better session debugging and isolation. """ import json import traceback -from flask import Flask, render_template, request, jsonify, send_file -from datetime import datetime, timezone +from flask import Flask, render_template, request, jsonify, send_file, session +from datetime import datetime, timezone, timedelta import io -from core.scanner import scanner +from core.session_manager import session_manager from config import config app = Flask(__name__) app.config['SECRET_KEY'] = 'dnsrecon-dev-key-change-in-production' +app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2) # 2 hour session lifetime + +def get_user_scanner(): + """ + Get or create scanner instance for current user session with enhanced debugging. + + Returns: + Tuple of (session_id, scanner_instance) + """ + # Get current Flask session info for debugging + current_flask_session_id = session.get('dnsrecon_session_id') + client_ip = request.remote_addr + user_agent = request.headers.get('User-Agent', '')[:100] # Truncate for logging + + print("=== SESSION DEBUG ===") + print(f"Client IP: {client_ip}") + print(f"User Agent: {user_agent}") + print(f"Flask Session ID: {current_flask_session_id}") + print(f"Flask Session Keys: {list(session.keys())}") + + # Try to get existing session + if current_flask_session_id: + existing_scanner = session_manager.get_session(current_flask_session_id) + if existing_scanner: + print(f"Using existing session: {current_flask_session_id}") + print(f"Scanner status: {existing_scanner.status}") + return current_flask_session_id, existing_scanner + else: + print(f"Session {current_flask_session_id} not found in session manager") + + # Create new session + print("Creating new session...") + new_session_id = session_manager.create_session() + new_scanner = session_manager.get_session(new_session_id) + + # Store in Flask session + session['dnsrecon_session_id'] = new_session_id + session.permanent = True + + print(f"Created new session: {new_session_id}") + print(f"New scanner status: {new_scanner.status}") + print("=== END SESSION DEBUG ===") + + return new_session_id, new_scanner @app.route('/') @@ -26,13 +71,8 @@ def index(): @app.route('/api/scan/start', methods=['POST']) def start_scan(): """ - Start a new reconnaissance scan. - - Expects JSON payload: - { - "target_domain": "example.com", - "max_depth": 2 - } + Start a new reconnaissance scan for the current user session. + Enhanced with better error handling and debugging. 
""" print("=== API: /api/scan/start called ===") @@ -68,26 +108,62 @@ def start_scan(): 'error': 'Max depth must be an integer between 1 and 5' }), 400 - print("Validation passed, calling scanner.start_scan...") + print("Validation passed, getting user scanner...") + + # Get user-specific scanner with enhanced debugging + user_session_id, scanner = get_user_scanner() + print(f"Using session: {user_session_id}") + print(f"Scanner object ID: {id(scanner)}") + print(f"Scanner status before start: {scanner.status}") + + # Additional safety check - if scanner is somehow in running state, force reset + if scanner.status == 'running': + print(f"WARNING: Scanner in session {user_session_id} was already running - forcing reset") + scanner.stop_scan() + # Give it a moment to stop + import time + time.sleep(1) + + # If still running, force status reset + if scanner.status == 'running': + print("WARNING: Force resetting scanner status from 'running' to 'idle'") + scanner.status = 'idle' # Start scan + print(f"Calling start_scan on scanner {id(scanner)}...") success = scanner.start_scan(target_domain, max_depth) print(f"scanner.start_scan returned: {success}") + print(f"Scanner status after start attempt: {scanner.status}") if success: - session_id = scanner.logger.session_id - print(f"Scan started successfully with session ID: {session_id}") + scan_session_id = scanner.logger.session_id + print(f"Scan started successfully with scan session ID: {scan_session_id}") return jsonify({ 'success': True, 'message': 'Scan started successfully', - 'scan_id': session_id + 'scan_id': scan_session_id, + 'user_session_id': user_session_id, + 'debug_info': { + 'scanner_object_id': id(scanner), + 'scanner_status': scanner.status + } }) else: print("ERROR: Scanner returned False") + + # Provide more detailed error information + error_details = { + 'scanner_status': scanner.status, + 'scanner_object_id': id(scanner), + 'session_id': user_session_id, + 'providers_count': len(scanner.providers) if hasattr(scanner, 'providers') else 0 + } + return jsonify({ 'success': False, - 'error': 'Failed to start scan (scan may already be running)' + 'error': f'Failed to start scan (scanner status: {scanner.status})', + 'debug_info': error_details }), 409 except Exception as e: @@ -98,24 +174,28 @@ def start_scan(): 'error': f'Internal server error: {str(e)}' }), 500 - @app.route('/api/scan/stop', methods=['POST']) def stop_scan(): - """Stop the current scan.""" + """Stop the current scan for the user session.""" print("=== API: /api/scan/stop called ===") try: + # Get user-specific scanner + user_session_id, scanner = get_user_scanner() + print(f"Stopping scan for session: {user_session_id}") + success = scanner.stop_scan() if success: return jsonify({ 'success': True, - 'message': 'Scan stop requested' + 'message': 'Scan stop requested', + 'user_session_id': user_session_id }) else: return jsonify({ 'success': False, - 'error': 'No active scan to stop' + 'error': 'No active scan to stop for this session' }), 400 except Exception as e: @@ -129,9 +209,14 @@ def stop_scan(): @app.route('/api/scan/status', methods=['GET']) def get_scan_status(): - """Get current scan status and progress.""" + """Get current scan status and progress for the user session.""" try: + # Get user-specific scanner + user_session_id, scanner = get_user_scanner() + status = scanner.get_scan_status() + status['user_session_id'] = user_session_id + return jsonify({ 'success': True, 'status': status @@ -148,12 +233,16 @@ def get_scan_status(): 
@app.route('/api/graph', methods=['GET']) def get_graph_data(): - """Get current graph data for visualization.""" + """Get current graph data for visualization for the user session.""" try: + # Get user-specific scanner + user_session_id, scanner = get_user_scanner() + graph_data = scanner.get_graph_data() return jsonify({ 'success': True, - 'graph': graph_data + 'graph': graph_data, + 'user_session_id': user_session_id }) except Exception as e: @@ -167,15 +256,25 @@ def get_graph_data(): @app.route('/api/export', methods=['GET']) def export_results(): - """Export complete scan results as downloadable JSON.""" + """Export complete scan results as downloadable JSON for the user session.""" try: + # Get user-specific scanner + user_session_id, scanner = get_user_scanner() + # Get complete results results = scanner.export_results() + # Add session information to export + results['export_metadata'] = { + 'user_session_id': user_session_id, + 'export_timestamp': datetime.now(timezone.utc).isoformat(), + 'export_type': 'user_session_results' + } + # Create filename with timestamp timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S') target = scanner.current_target or 'unknown' - filename = f"dnsrecon_{target}_{timestamp}.json" + filename = f"dnsrecon_{target}_{timestamp}_{user_session_id[:8]}.json" # Create in-memory file json_data = json.dumps(results, indent=2, ensure_ascii=False) @@ -199,10 +298,13 @@ def export_results(): @app.route('/api/providers', methods=['GET']) def get_providers(): - """Get information about available providers.""" + """Get information about available providers for the user session.""" print("=== API: /api/providers called ===") try: + # Get user-specific scanner + user_session_id, scanner = get_user_scanner() + provider_stats = scanner.get_provider_statistics() # Add configuration information @@ -215,10 +317,11 @@ def get_providers(): 'requires_api_key': provider_name in ['shodan', 'virustotal'] } - print(f"Returning provider info: {list(provider_info.keys())}") + print(f"Returning provider info for session {user_session_id}: {list(provider_info.keys())}") return jsonify({ 'success': True, - 'providers': provider_info + 'providers': provider_info, + 'user_session_id': user_session_id }) except Exception as e: @@ -233,13 +336,7 @@ def get_providers(): @app.route('/api/config/api-keys', methods=['POST']) def set_api_keys(): """ - Set API keys for providers (stored in memory only). - - Expects JSON payload: - { - "shodan": "api_key_here", - "virustotal": "api_key_here" - } + Set API keys for providers for the user session only. 
""" try: data = request.get_json() @@ -250,22 +347,27 @@ def set_api_keys(): 'error': 'No API keys provided' }), 400 + # Get user-specific scanner and config + user_session_id, scanner = get_user_scanner() + session_config = scanner.config + updated_providers = [] for provider, api_key in data.items(): if provider in ['shodan', 'virustotal'] and api_key.strip(): - success = config.set_api_key(provider, api_key.strip()) + success = session_config.set_api_key(provider, api_key.strip()) if success: updated_providers.append(provider) if updated_providers: - # Reinitialize scanner providers + # Reinitialize scanner providers for this session only scanner._initialize_providers() return jsonify({ 'success': True, - 'message': f'API keys updated for: {", ".join(updated_providers)}', - 'updated_providers': updated_providers + 'message': f'API keys updated for session {user_session_id}: {", ".join(updated_providers)}', + 'updated_providers': updated_providers, + 'user_session_id': user_session_id }) else: return jsonify({ @@ -280,27 +382,120 @@ def set_api_keys(): 'success': False, 'error': f'Internal server error: {str(e)}' }), 500 + + except Exception as e: + print(f"ERROR: Exception in set_api_keys endpoint: {e}") + traceback.print_exc() + return jsonify({ + 'success': False, + 'error': f'Internal server error: {str(e)}' + }), 500 + + +@app.route('/api/session/info', methods=['GET']) +def get_session_info(): + """Get information about the current user session.""" + try: + user_session_id, scanner = get_user_scanner() + session_info = session_manager.get_session_info(user_session_id) + + return jsonify({ + 'success': True, + 'session_info': session_info + }) + + except Exception as e: + print(f"ERROR: Exception in get_session_info endpoint: {e}") + traceback.print_exc() + return jsonify({ + 'success': False, + 'error': f'Internal server error: {str(e)}' + }), 500 + + +@app.route('/api/session/terminate', methods=['POST']) +def terminate_session(): + """Terminate the current user session.""" + try: + user_session_id = session.get('dnsrecon_session_id') + + if user_session_id: + success = session_manager.terminate_session(user_session_id) + # Clear Flask session + session.pop('dnsrecon_session_id', None) + + return jsonify({ + 'success': success, + 'message': 'Session terminated' if success else 'Session not found' + }) + else: + return jsonify({ + 'success': False, + 'error': 'No active session to terminate' + }), 400 + + except Exception as e: + print(f"ERROR: Exception in terminate_session endpoint: {e}") + traceback.print_exc() + return jsonify({ + 'success': False, + 'error': f'Internal server error: {str(e)}' + }), 500 + + +@app.route('/api/admin/sessions', methods=['GET']) +def list_sessions(): + """Admin endpoint to list all active sessions.""" + try: + sessions = session_manager.list_active_sessions() + stats = session_manager.get_statistics() + + return jsonify({ + 'success': True, + 'sessions': sessions, + 'statistics': stats + }) + + except Exception as e: + print(f"ERROR: Exception in list_sessions endpoint: {e}") + traceback.print_exc() + return jsonify({ + 'success': False, + 'error': f'Internal server error: {str(e)}' + }), 500 @app.route('/api/health', methods=['GET']) def health_check(): """Health check endpoint with enhanced Phase 2 information.""" - return jsonify({ - 'success': True, - 'status': 'healthy', - 'timestamp': datetime.now(datetime.UTC).isoformat(), - 'version': '1.0.0-phase2', - 'phase': 2, - 'features': { - 'multi_provider': True, - 'concurrent_processing': True, - 
'real_time_updates': True, - 'api_key_management': True, - 'enhanced_visualization': True, - 'retry_logic': True - }, - 'providers_available': len(scanner.providers) if hasattr(scanner, 'providers') else 0 - }) + try: + # Get session stats + session_stats = session_manager.get_statistics() + + return jsonify({ + 'success': True, + 'status': 'healthy', + 'timestamp': datetime.now(timezone.utc).isoformat(), + 'version': '1.0.0-phase2', + 'phase': 2, + 'features': { + 'multi_provider': True, + 'concurrent_processing': True, + 'real_time_updates': True, + 'api_key_management': True, + 'enhanced_visualization': True, + 'retry_logic': True, + 'user_sessions': True, + 'session_isolation': True + }, + 'session_statistics': session_stats + }) + except Exception as e: + print(f"ERROR: Exception in health_check endpoint: {e}") + return jsonify({ + 'success': False, + 'error': f'Health check failed: {str(e)}' + }), 500 @app.errorhandler(404) @@ -324,7 +519,7 @@ def internal_error(error): if __name__ == '__main__': - print("Starting DNSRecon Flask application...") + print("Starting DNSRecon Flask application with user session support...") # Load configuration from environment config.load_from_env() diff --git a/core/__init__.py b/core/__init__.py index 8d3e34d..bacc384 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -5,8 +5,10 @@ Phase 2: Enhanced with concurrent processing and real-time capabilities. """ from .graph_manager import GraphManager, NodeType, RelationshipType -from .scanner import Scanner, ScanStatus, scanner +from .scanner import Scanner, ScanStatus # Remove 'scanner' global instance from .logger import ForensicLogger, get_forensic_logger, new_session +from .session_manager import session_manager # Add session manager +from .session_config import SessionConfig, create_session_config # Add session config __all__ = [ 'GraphManager', @@ -14,10 +16,13 @@ __all__ = [ 'RelationshipType', 'Scanner', 'ScanStatus', - 'scanner', + # 'scanner', # Remove this - no more global scanner 'ForensicLogger', 'get_forensic_logger', - 'new_session' + 'new_session', + 'session_manager', # Add this + 'SessionConfig', # Add this + 'create_session_config' # Add this ] __version__ = "1.0.0-phase2" \ No newline at end of file diff --git a/core/graph_manager.py b/core/graph_manager.py index e838484..b28bc13 100644 --- a/core/graph_manager.py +++ b/core/graph_manager.py @@ -3,13 +3,10 @@ Graph data model for DNSRecon using NetworkX. Manages in-memory graph storage with confidence scoring and forensic metadata. 
""" -import json -import threading from datetime import datetime -from typing import Dict, List, Any, Optional, Tuple, Set +from typing import Dict, List, Any, Optional, Tuple from enum import Enum from datetime import timezone -from collections import defaultdict import networkx as nx @@ -18,8 +15,8 @@ class NodeType(Enum): """Enumeration of supported node types.""" DOMAIN = "domain" IP = "ip" - CERTIFICATE = "certificate" ASN = "asn" + DNS_RECORD = "dns_record" LARGE_ENTITY = "large_entity" @@ -43,6 +40,7 @@ class RelationshipType(Enum): TLSA_RECORD = ("tlsa_record", 0.7) NAPTR_RECORD = ("naptr_record", 0.7) SPF_RECORD = ("spf_record", 0.7) + DNS_RECORD = ("dns_record", 0.8) PASSIVE_DNS = ("passive_dns", 0.6) ASN_MEMBERSHIP = ("asn", 0.7) @@ -115,8 +113,7 @@ class GraphManager: Returns: bool: True if edge was added, False if it already exists """ - #with self.lock: - # Ensure both nodes exist + if not self.graph.has_node(source_id) or not self.graph.has_node(target_id): # If the target node is a subdomain, it should be added. # The scanner will handle this logic. @@ -149,12 +146,10 @@ class GraphManager: def get_node_count(self) -> int: """Get total number of nodes in the graph.""" - #with self.lock: return self.graph.number_of_nodes() def get_edge_count(self) -> int: """Get total number of edges in the graph.""" - #with self.lock: return self.graph.number_of_edges() def get_nodes_by_type(self, node_type: NodeType) -> List[str]: @@ -167,7 +162,6 @@ class GraphManager: Returns: List of node identifiers """ - #with self.lock: return [ node_id for node_id, attributes in self.graph.nodes(data=True) if attributes.get('type') == node_type.value @@ -183,7 +177,6 @@ class GraphManager: Returns: List of neighboring node identifiers """ - #with self.lock: if not self.graph.has_node(node_id): return [] @@ -201,7 +194,6 @@ class GraphManager: Returns: List of tuples (source, target, attributes) """ - #with self.lock: return [ (source, target, attributes) for source, target, attributes in self.graph.edges(data=True) @@ -211,46 +203,12 @@ class GraphManager: def get_graph_data(self) -> Dict[str, Any]: """ Export graph data for visualization. - - Returns: - Dictionary containing nodes and edges for frontend visualization + Uses comprehensive metadata collected during scanning. 
""" - #with self.lock: nodes = [] edges = [] - # Create a dictionary to hold aggregated data for each node - node_details = defaultdict(lambda: defaultdict(list)) - - for source, target, attributes in self.graph.edges(data=True): - provider = attributes.get('source_provider', 'unknown') - raw_data = attributes.get('raw_data', {}) - - if provider == 'dns': - record_type = raw_data.get('query_type', 'UNKNOWN') - value = raw_data.get('value', target) - # DNS data is always about the source node of the query - node_details[source]['dns_records'].append(f"{record_type}: {value}") - - elif provider == 'crtsh': - # Data from crt.sh are domain names found in certificates (SANs) - node_details[source]['related_domains_san'].append(target) - - elif provider == 'shodan': - # Shodan data is about the IP, which can be either the source or target - source_node_type = self.graph.nodes[source].get('type') - target_node_type = self.graph.nodes[target].get('type') - - if source_node_type == 'ip': - node_details[source]['shodan'] = raw_data - elif target_node_type == 'ip': - node_details[target]['shodan'] = raw_data - - elif provider == 'virustotal': - # VirusTotal data is about the source node of the query - node_details[source]['virustotal'] = raw_data - - # Format nodes for visualization + # Create nodes with the comprehensive metadata already collected for node_id, attributes in self.graph.nodes(data=True): node_data = { 'id': node_id, @@ -260,18 +218,15 @@ class GraphManager: 'added_timestamp': attributes.get('added_timestamp') } - # Add the aggregated details to the metadata - if node_id in node_details: - for key, value in node_details[node_id].items(): - # Use a set to avoid adding duplicate entries to lists - if key in node_data['metadata'] and isinstance(node_data['metadata'][key], list): - existing_values = set(node_data['metadata'][key]) - new_values = [v for v in value if v not in existing_values] - node_data['metadata'][key].extend(new_values) - else: - node_data['metadata'][key] = value + # Handle certificate node labeling + if node_id.startswith('cert_'): + # For certificate nodes, create a more informative label + cert_metadata = node_data['metadata'] + issuer = cert_metadata.get('issuer_name', 'Unknown') + valid_status = "✓" if cert_metadata.get('is_currently_valid') else "✗" + node_data['label'] = f"Certificate {valid_status}\n{issuer[:30]}..." 
- # Color coding by type - now returns color objects for enhanced visualization + # Color coding by type type_colors = { 'domain': { 'background': '#00ff41', @@ -285,18 +240,18 @@ class GraphManager: 'highlight': {'background': '#ffbb44', 'border': '#ff9900'}, 'hover': {'background': '#ffaa22', 'border': '#dd8800'} }, - 'certificate': { - 'background': '#c7c7c7', - 'border': '#999999', - 'highlight': {'background': '#e0e0e0', 'border': '#c7c7c7'}, - 'hover': {'background': '#d4d4d4', 'border': '#aaaaaa'} - }, 'asn': { 'background': '#00aaff', 'border': '#0088cc', 'highlight': {'background': '#44ccff', 'border': '#00aaff'}, 'hover': {'background': '#22bbff', 'border': '#0099dd'} }, + 'dns_record': { + 'background': '#9d4edd', + 'border': '#7b2cbf', + 'highlight': {'background': '#c77dff', 'border': '#9d4edd'}, + 'hover': {'background': '#b392f0', 'border': '#8b5cf6'} + }, 'large_entity': { 'background': '#ff6b6b', 'border': '#cc3a3a', @@ -306,15 +261,17 @@ class GraphManager: } node_color_config = type_colors.get(attributes.get('type', 'unknown'), type_colors['domain']) + node_data['color'] = node_color_config - # Pass the has_valid_cert metadata for styling - if 'metadata' in attributes and 'has_valid_cert' in attributes['metadata']: - node_data['has_valid_cert'] = attributes['metadata']['has_valid_cert'] + # Add certificate validity indicator if available + metadata = node_data['metadata'] + if 'certificate_data' in metadata and 'has_valid_cert' in metadata['certificate_data']: + node_data['has_valid_cert'] = metadata['certificate_data']['has_valid_cert'] nodes.append(node_data) - # Format edges for visualization + # Create edges (unchanged from original) for source, target, attributes in self.graph.edges(data=True): edge_data = { 'from': source, @@ -376,7 +333,6 @@ class GraphManager: Returns: Dictionary containing complete graph data with metadata """ - #with self.lock: # Get basic graph data graph_data = self.get_graph_data() @@ -427,7 +383,6 @@ class GraphManager: Returns: Dictionary containing various graph metrics """ - #with self.lock: stats = { 'basic_metrics': { 'total_nodes': self.graph.number_of_nodes(), @@ -462,7 +417,6 @@ class GraphManager: def clear(self) -> None: """Clear all nodes and edges from the graph.""" - #with self.lock: self.graph.clear() self.creation_time = datetime.now(timezone.utc).isoformat() self.last_modified = self.creation_time \ No newline at end of file diff --git a/core/logger.py b/core/logger.py index c9f6964..f11f70c 100644 --- a/core/logger.py +++ b/core/logger.py @@ -3,7 +3,6 @@ Forensic logging system for DNSRecon tool. Provides structured audit trail for all reconnaissance activities. 
""" -import json import logging import threading from datetime import datetime @@ -109,7 +108,6 @@ class ForensicLogger: target_indicator: The indicator being investigated discovery_context: Context of how this indicator was discovered """ - #with self.lock: api_request = APIRequest( timestamp=datetime.now(timezone.utc).isoformat(), provider=provider, @@ -152,7 +150,6 @@ class ForensicLogger: raw_data: Raw data from provider response discovery_method: Method used to discover relationship """ - #with self.lock: relationship = RelationshipDiscovery( timestamp=datetime.now(timezone.utc).isoformat(), source_node=source_node, @@ -178,12 +175,10 @@ class ForensicLogger: self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}") self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}") - #with self.lock: self.session_metadata['target_domains'].add(target_domain) def log_scan_complete(self) -> None: """Log the completion of a reconnaissance scan.""" - #with self.lock: self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat() self.session_metadata['providers_used'] = list(self.session_metadata['providers_used']) self.session_metadata['target_domains'] = list(self.session_metadata['target_domains']) @@ -199,7 +194,6 @@ class ForensicLogger: Returns: Dictionary containing complete session audit trail """ - #with self.lock: return { 'session_metadata': self.session_metadata.copy(), 'api_requests': [asdict(req) for req in self.api_requests], @@ -214,7 +208,6 @@ class ForensicLogger: Returns: Dictionary containing summary statistics """ - #with self.lock: provider_stats = {} for provider in self.session_metadata['providers_used']: provider_requests = [req for req in self.api_requests if req.provider == provider] diff --git a/core/scanner.py b/core/scanner.py index 18f69da..01564fa 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -4,14 +4,14 @@ Coordinates data gathering from multiple providers and builds the infrastructure """ import threading -import time import traceback -from typing import List, Set, Dict, Any, Optional, Tuple +from typing import List, Set, Dict, Any, Tuple from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError from collections import defaultdict from core.graph_manager import GraphManager, NodeType, RelationshipType from core.logger import get_forensic_logger, new_session +from utils.helpers import _is_valid_ip, _is_valid_domain from providers.crtsh_provider import CrtShProvider from providers.dns_provider import DNSProvider from providers.shodan_provider import ShodanProvider @@ -31,21 +31,27 @@ class ScanStatus: class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. - Manages multi-provider data gathering and graph construction with concurrent processing. + Now supports per-session configuration for multi-user isolation. 
""" - def __init__(self): - """Initialize scanner with all available providers and empty graph.""" + def __init__(self, session_config=None): + """Initialize scanner with session-specific configuration.""" print("Initializing Scanner instance...") try: + # Use provided session config or create default + if session_config is None: + from core.session_config import create_session_config + session_config = create_session_config() + + self.config = session_config self.graph = GraphManager() self.providers = [] self.status = ScanStatus.IDLE self.current_target = None self.current_depth = 0 self.max_depth = 2 - self.stop_event = threading.Event() # Use a threading.Event for safer signaling + self.stop_event = threading.Event() self.scan_thread = None # Scanning progress tracking @@ -54,11 +60,11 @@ class Scanner: self.current_indicator = "" # Concurrent processing configuration - self.max_workers = config.max_concurrent_requests - self.executor = None # Keep a reference to the executor + self.max_workers = self.config.max_concurrent_requests + self.executor = None - # Initialize providers - print("Calling _initialize_providers...") + # Initialize providers with session config + print("Calling _initialize_providers with session config...") self._initialize_providers() # Initialize logger @@ -73,10 +79,10 @@ class Scanner: raise def _initialize_providers(self) -> None: - """Initialize all available providers based on configuration.""" + """Initialize all available providers based on session configuration.""" self.providers = [] - print("Initializing providers...") + print("Initializing providers with session config...") # Always add free providers free_providers = [ @@ -85,12 +91,15 @@ class Scanner: ] for provider_name, provider_class in free_providers: - if config.is_provider_enabled(provider_name): + if self.config.is_provider_enabled(provider_name): try: - provider = provider_class() + # Pass session config to provider + provider = provider_class(session_config=self.config) if provider.is_available(): + # Set the stop event for cancellation support + provider.set_stop_event(self.stop_event) self.providers.append(provider) - print(f"✓ {provider_name.title()} provider initialized successfully") + print(f"✓ {provider_name.title()} provider initialized successfully for session") else: print(f"✗ {provider_name.title()} provider is not available") except Exception as e: @@ -104,23 +113,41 @@ class Scanner: ] for provider_name, provider_class in api_providers: - if config.is_provider_enabled(provider_name): + if self.config.is_provider_enabled(provider_name): try: - provider = provider_class() + # Pass session config to provider + provider = provider_class(session_config=self.config) if provider.is_available(): + # Set the stop event for cancellation support + provider.set_stop_event(self.stop_event) self.providers.append(provider) - print(f"✓ {provider_name.title()} provider initialized successfully") + print(f"✓ {provider_name.title()} provider initialized successfully for session") else: print(f"✗ {provider_name.title()} provider is not available (API key required)") except Exception as e: print(f"✗ Failed to initialize {provider_name.title()} provider: {e}") traceback.print_exc() - print(f"Initialized {len(self.providers)} providers") + print(f"Initialized {len(self.providers)} providers for session") + + def update_session_config(self, new_config) -> None: + """ + Update session configuration and reinitialize providers. 
+ + Args: + new_config: New SessionConfig instance + """ + print("Updating session configuration...") + self.config = new_config + self.max_workers = self.config.max_concurrent_requests + self._initialize_providers() + print("Session configuration updated") + def start_scan(self, target_domain: str, max_depth: int = 2) -> bool: """ Start a new reconnaissance scan with concurrent processing. + Enhanced with better debugging and state validation. Args: target_domain: Initial domain to investigate @@ -129,29 +156,36 @@ Returns: bool: True if scan started successfully """ - print(f"Scanner.start_scan called with target='{target_domain}', depth={max_depth}") - + print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") + print(f"Scanner status: {self.status}") + print(f"Target domain: '{target_domain}', Max depth: {max_depth}") + print(f"Available providers: {len(self.providers) if hasattr(self, 'providers') else 0}") + try: if self.status == ScanStatus.RUNNING: - print("Scan already running, rejecting new scan") + print(f"ERROR: Scan already running in scanner {id(self)}, rejecting new scan") + print(f"Current target: {self.current_target}") + print(f"Current depth: {self.current_depth}") return False # Check if we have any providers - if not self.providers: - print("No providers available, cannot start scan") + if not hasattr(self, 'providers') or not self.providers: + print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan") return False + print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}") + # Stop any existing scan thread if self.scan_thread and self.scan_thread.is_alive(): - print("Stopping existing scan thread...") + print(f"Stopping existing scan thread in scanner {id(self)}...") self.stop_event.set() self.scan_thread.join(timeout=5.0) if self.scan_thread.is_alive(): - print("WARNING: Could not stop existing thread") + print(f"WARNING: Could not stop existing thread in scanner {id(self)}") return False # Reset state - print("Resetting scanner state...") + print(f"Resetting scanner {id(self)} state...") self.graph.clear() self.current_target = target_domain.lower().strip() self.max_depth = max_depth @@ -162,11 +196,11 @@ self.current_indicator = self.current_target # Start new forensic session - print("Starting new forensic session...") + print(f"Starting new forensic session for scanner {id(self)}...") self.logger = new_session() # Start scan in separate thread - print("Starting scan thread...") + print(f"Starting scan thread for scanner {id(self)}...") self.scan_thread = threading.Thread( target=self._execute_scan_async, args=(self.current_target, max_depth), @@ -174,14 +208,15 @@ ) self.scan_thread.start() + print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===") return True except Exception as e: - print(f"ERROR: Exception in start_scan: {e}") + print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") traceback.print_exc() return False def _execute_scan_async(self, target_domain: str, max_depth: int) -> None: """ Execute the reconnaissance scan asynchronously with concurrent provider queries.
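Note on the thread launch above: threading.Thread(target=...) simply calls its target, so the scan entry point must remain a plain def. Declaring _execute_scan_async as async def would hand the thread a never-awaited coroutine object and the scan body would silently never run. A minimal sketch of that failure mode (illustrative only, not part of the patch):

    import threading

    async def work():
        print("never runs")

    t = threading.Thread(target=work)
    t.start()  # calls work(), which only builds a coroutine object
    t.join()   # thread exits immediately; Python warns: coroutine 'work' was never awaited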
@@ -210,7 +245,7 @@ processed_domains = set() all_discovered_ips = set() - print(f"Starting BFS exploration...") + print("Starting BFS exploration...") for depth in range(max_depth + 1): if self.stop_event.is_set(): @@ -269,7 +304,7 @@ self.executor.shutdown(wait=False, cancel_futures=True) stats = self.graph.get_statistics() - print(f"Final scan statistics:") + print("Final scan statistics:") print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}") print(f" - Total edges: {stats['basic_metrics']['total_edges']}") print(f" - Domains processed: {len(processed_domains)}") @@ -330,87 +365,243 @@ def _query_providers_for_domain(self, domain: str) -> Tuple[Set[str], Set[str]]: """ - Query all enabled providers for information about a domain. + Query all enabled providers for information about a domain and collect comprehensive metadata. + Creates appropriate node types and relationships based on discovered data. """ print(f"Querying {len(self.providers)} providers for domain: {domain}") discovered_domains = set() discovered_ips = set() - relationships_by_type = defaultdict(list) + all_relationships = [] + + # Comprehensive metadata collection for this domain + domain_metadata = { + 'dns_records': [], + 'related_domains_san': [], + 'shodan': {}, + 'virustotal': {}, + 'certificate_data': {}, + 'passive_dns': [], + } if not self.providers or self.stop_event.is_set(): return discovered_domains, discovered_ips + # Query all providers concurrently with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor: future_to_provider = { provider_executor.submit(self._safe_provider_query_domain, provider, domain): provider for provider in self.providers } + for future in as_completed(future_to_provider): if self.stop_event.is_set(): future.cancel() continue + provider = future_to_provider[future] try: relationships = future.result() print(f"Provider {provider.get_name()} returned {len(relationships)} relationships") + # Process relationships and collect metadata for rel in relationships: - relationships_by_type[rel[2]].append(rel) + source, target, rel_type, confidence, raw_data = rel + + # Add provider info to the relationship + enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name()) + all_relationships.append(enhanced_rel) + + # Collect metadata based on provider and relationship type + self._collect_node_metadata(domain, provider.get_name(), rel_type, target, raw_data, domain_metadata) except (Exception, CancelledError) as e: print(f"Provider {provider.get_name()} failed for {domain}: {e}") + # Add the domain node with comprehensive metadata + self.graph.add_node(domain, NodeType.DOMAIN, metadata=domain_metadata) + + # Group relationships by type for large entity handling + relationships_by_type = defaultdict(list) + for source, target, rel_type, confidence, raw_data, provider_name in all_relationships: + relationships_by_type[rel_type].append((source, target, rel_type, confidence, raw_data, provider_name)) + + # Handle large entities (only for SAN certificates currently) for rel_type, relationships in relationships_by_type.items(): if len(relationships) > self.config.large_entity_threshold and rel_type == RelationshipType.SAN_CERTIFICATE: - self._handle_large_entity(domain, relationships, rel_type, provider.get_name()) + first_provider = relationships[0][5] if relationships else "multiple_providers" + self._handle_large_entity(domain, relationships, rel_type, first_provider) + # Remove these relationships from
further processing + all_relationships = [rel for rel in all_relationships if not (rel[2] == rel_type and len(relationships_by_type[rel_type]) > self.config.large_entity_threshold)] + + # Track DNS records to create (avoid duplicates) + dns_records_to_create = {} + + # Process remaining relationships + for source, target, rel_type, confidence, raw_data, provider_name in all_relationships: + if self.stop_event.is_set(): + break + + # Determine how to handle the target based on relationship type and content + if _is_valid_ip(target): + # Create IP node and relationship + self.graph.add_node(target, NodeType.IP) + + if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data): + print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})") + + # Add to recursion if it's a direct resolution + if rel_type in [RelationshipType.A_RECORD, RelationshipType.AAAA_RECORD]: + discovered_ips.add(target) + + elif target.startswith('AS') and target[2:].isdigit(): + # Create ASN node and relationship + self.graph.add_node(target, NodeType.ASN) + + if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data): + print(f"Added ASN relationship: {source} -> {target} ({rel_type.relationship_name})") + + elif _is_valid_domain(target): + # Create domain node and relationship + self.graph.add_node(target, NodeType.DOMAIN) + + if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data): + print(f"Added domain relationship: {source} -> {target} ({rel_type.relationship_name})") + + # Add to recursion for specific relationship types + recurse_types = [ + RelationshipType.CNAME_RECORD, + RelationshipType.MX_RECORD, + RelationshipType.SAN_CERTIFICATE, + RelationshipType.NS_RECORD, + RelationshipType.PASSIVE_DNS + ] + if rel_type in recurse_types: + discovered_domains.add(target) + else: - for source, target, rel_type, confidence, raw_data in relationships: - # Determine if the target should create a new node - create_node = rel_type in [ - RelationshipType.A_RECORD, - RelationshipType.AAAA_RECORD, - RelationshipType.CNAME_RECORD, - RelationshipType.MX_RECORD, - RelationshipType.NS_RECORD, - RelationshipType.PTR_RECORD, - RelationshipType.SAN_CERTIFICATE - ] + # Handle DNS record content (TXT, SPF, CAA, etc.)
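+ # Note: the DNS record nodes below are deduplicated by hash(record_content). Python's str hash is salted per interpreter run (PYTHONHASHSEED), so these IDs are stable within one scan but not across runs.
+ # For reproducible forensic identifiers, a digest such as hashlib.sha256(record_content.encode()).hexdigest()[:8] could be substituted (a suggested alternative, not what this patch does).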
+ dns_record_types = [ + RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD, + RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD, + RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD, + RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD, + RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD + ] + + if rel_type in dns_record_types: + # Create normalized DNS record identifier + record_type = rel_type.relationship_name.upper().replace('_RECORD', '') + record_content = target.strip() + + # Create a unique identifier for this DNS record + content_hash = hash(record_content) & 0x7FFFFFFF + dns_record_id = f"{record_type}:{content_hash}" + + # Track this DNS record for creation (avoid duplicates) + if dns_record_id not in dns_records_to_create: + dns_records_to_create[dns_record_id] = { + 'content': record_content, + 'type': record_type, + 'domains': set(), + 'raw_data': raw_data, + 'provider_name': provider_name, + 'confidence': confidence + } + + # Add this domain to the DNS record's domain list + dns_records_to_create[dns_record_id]['domains'].add(source) + + print(f"DNS record tracked: {source} -> {record_type} (content length: {len(record_content)})") + else: + # For other non-infrastructure targets, log but don't create nodes + print(f"Non-infrastructure relationship stored as metadata: {source} - {rel_type.relationship_name}: {target[:100]}") - # Determine if the target should be subject to recursion - recurse = rel_type in [ - RelationshipType.A_RECORD, - RelationshipType.AAAA_RECORD, - RelationshipType.CNAME_RECORD, - RelationshipType.MX_RECORD, - RelationshipType.SAN_CERTIFICATE - ] + # Create DNS record nodes and their relationships + for dns_record_id, record_info in dns_records_to_create.items(): + if self.stop_event.is_set(): + break + + record_metadata = { + 'record_type': record_info['type'], + 'content': record_info['content'], + 'content_hash': dns_record_id.split(':')[1], + 'associated_domains': list(record_info['domains']), + 'source_data': record_info['raw_data'] + } + + # Create the DNS record node + self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata) + + # Connect each domain to this DNS record + for domain_name in record_info['domains']: + if self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD, + record_info['confidence'], record_info['provider_name'], + record_info['raw_data']): + print(f"Added DNS record relationship: {domain_name} -> {dns_record_id}") - if create_node: - target_node_type = NodeType.IP if self._is_valid_ip(target) else NodeType.DOMAIN - self.graph.add_node(target, target_node_type) - if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data): - print(f"Added relationship: {source} -> {target} ({rel_type.relationship_name})") - else: - # For records that don't create nodes, we still want to log the relationship - self.logger.log_relationship_discovery( - source_node=source, - target_node=target, - relationship_type=rel_type.relationship_name, - confidence_score=confidence, - provider=provider.name, - raw_data=raw_data, - discovery_method=f"dns_{rel_type.name.lower()}_record" - ) - - if recurse: - if self._is_valid_ip(target): - discovered_ips.add(target) - elif self._is_valid_domain(target): - discovered_domains.add(target) - - print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs") + print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs, 
{len(dns_records_to_create)} DNS records") return discovered_domains, discovered_ips + def _collect_node_metadata(self, node_id: str, provider_name: str, rel_type: RelationshipType, + target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None: + """ + Collect and organize metadata for a node based on provider responses. + """ + if provider_name == 'dns': + record_type = raw_data.get('query_type', 'UNKNOWN') + value = raw_data.get('value', target) + + # Store the full record content, including TXT/SPF/CAA payloads + dns_entry = f"{record_type}: {value}" + + if dns_entry not in metadata['dns_records']: + metadata['dns_records'].append(dns_entry) + + elif provider_name == 'crtsh': + if rel_type == RelationshipType.SAN_CERTIFICATE: + # Handle certificate data storage on domain nodes + domain_certs = raw_data.get('domain_certificates', {}) + + # Store certificate information for this domain + if node_id in domain_certs: + cert_summary = domain_certs[node_id] + + # Update domain metadata with certificate information + metadata['certificate_data'] = cert_summary + metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False) + + # Add related domains from shared certificates + if target not in metadata.get('related_domains_san', []): + if 'related_domains_san' not in metadata: + metadata['related_domains_san'] = [] + metadata['related_domains_san'].append(target) + + # Store shared certificate details for forensic analysis + shared_certs = raw_data.get('shared_certificates', []) + if shared_certs and 'shared_certificate_details' not in metadata: + metadata['shared_certificate_details'] = shared_certs + + elif provider_name == 'shodan': + # Merge Shodan data (avoid overwriting) + for key, value in raw_data.items(): + if key not in metadata['shodan'] or not metadata['shodan'][key]: + metadata['shodan'][key] = value + + elif provider_name == 'virustotal': + # Merge VirusTotal data + for key, value in raw_data.items(): + if key not in metadata['virustotal'] or not metadata['virustotal'][key]: + metadata['virustotal'][key] = value + + # Add passive DNS entries + if rel_type == RelationshipType.PASSIVE_DNS: + passive_entry = f"Passive DNS: {target}" + if passive_entry not in metadata['passive_dns']: + metadata['passive_dns'].append(passive_entry) + def _handle_large_entity(self, source_domain: str, relationships: list, rel_type: RelationshipType, provider_name: str): """ Handles the creation of a large entity node when a threshold is exceeded. @@ -422,12 +613,24 @@ def _query_providers_for_ip(self, ip: str) -> None: """ - Query all enabled providers for information about an IP address. + Query all enabled providers for information about an IP address and collect comprehensive metadata.
""" print(f"Querying {len(self.providers)} providers for IP: {ip}") if not self.providers or self.stop_event.is_set(): return + # Comprehensive metadata collection for this IP + ip_metadata = { + 'dns_records': [], + 'passive_dns': [], + 'shodan': {}, + 'virustotal': {}, + 'asn_data': {}, + 'hostnames': [], + } + + all_relationships = [] # Store relationships with provider info + with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor: future_to_provider = { provider_executor.submit(self._safe_provider_query_ip, provider, ip): provider @@ -441,20 +644,86 @@ class Scanner: try: relationships = future.result() print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for IP {ip}") + for source, target, rel_type, confidence, raw_data in relationships: - if self._is_valid_domain(target): - target_node_type = NodeType.DOMAIN - elif target.startswith('AS'): - target_node_type = NodeType.ASN - else: - target_node_type = NodeType.IP - self.graph.add_node(source, NodeType.IP) - self.graph.add_node(target, target_node_type) - if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data): - print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})") + # Add provider info to the relationship + enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name()) + all_relationships.append(enhanced_rel) + + # Collect metadata for the IP + self._collect_ip_metadata(ip, provider.get_name(), rel_type, target, raw_data, ip_metadata) + except (Exception, CancelledError) as e: print(f"Provider {provider.get_name()} failed for IP {ip}: {e}") + # Update the IP node with comprehensive metadata + self.graph.add_node(ip, NodeType.IP, metadata=ip_metadata) + + # Process relationships with correct provider attribution + for source, target, rel_type, confidence, raw_data, provider_name in all_relationships: + # Determine target node type + if _is_valid_domain(target): + target_node_type = NodeType.DOMAIN + elif target.startswith('AS'): + target_node_type = NodeType.ASN + else: + target_node_type = NodeType.IP + + # Create/update target node + self.graph.add_node(target, target_node_type) + + # Add relationship with correct provider attribution + if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data): + print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name}) from {provider_name}") + + def _collect_ip_metadata(self, ip: str, provider_name: str, rel_type: RelationshipType, + target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None: + """ + Collect and organize metadata for an IP node based on provider responses. 
+ """ + if provider_name == 'dns': + if rel_type == RelationshipType.PTR_RECORD: + reverse_entry = f"PTR: {target}" + if reverse_entry not in metadata['dns_records']: + metadata['dns_records'].append(reverse_entry) + if target not in metadata['hostnames']: + metadata['hostnames'].append(target) + + elif provider_name == 'shodan': + # Merge Shodan data + for key, value in raw_data.items(): + if key not in metadata['shodan'] or not metadata['shodan'][key]: + metadata['shodan'][key] = value + + # Collect hostname information + if 'hostname' in raw_data and raw_data['hostname'] not in metadata['hostnames']: + metadata['hostnames'].append(raw_data['hostname']) + if 'hostnames' in raw_data: + for hostname in raw_data['hostnames']: + if hostname not in metadata['hostnames']: + metadata['hostnames'].append(hostname) + + elif provider_name == 'virustotal': + # Merge VirusTotal data + for key, value in raw_data.items(): + if key not in metadata['virustotal'] or not metadata['virustotal'][key]: + metadata['virustotal'][key] = value + + # Add passive DNS entries + if rel_type == RelationshipType.PASSIVE_DNS: + passive_entry = f"Passive DNS: {target}" + if passive_entry not in metadata['passive_dns']: + metadata['passive_dns'].append(passive_entry) + + # Handle ASN relationships + if rel_type == RelationshipType.ASN_MEMBERSHIP: + metadata['asn_data'] = { + 'asn': target, + 'description': raw_data.get('org', ''), + 'isp': raw_data.get('isp', ''), + 'country': raw_data.get('country', '') + } + def _safe_provider_query_domain(self, provider, domain: str): """Safely query provider for domain with error handling.""" @@ -478,12 +747,33 @@ class Scanner: def stop_scan(self) -> bool: """ - Request scan termination. + Request immediate scan termination with aggressive cancellation. """ try: if self.status == ScanStatus.RUNNING: + print("=== INITIATING IMMEDIATE SCAN TERMINATION ===") + + # Signal all threads to stop self.stop_event.set() - print("Scan stop requested") + + # Close HTTP sessions in all providers to terminate ongoing requests + for provider in self.providers: + try: + if hasattr(provider, 'session'): + provider.session.close() + print(f"Closed HTTP session for provider: {provider.get_name()}") + except Exception as e: + print(f"Error closing session for {provider.get_name()}: {e}") + + # Shutdown executor immediately with cancel_futures=True + if self.executor: + print("Shutting down executor with immediate cancellation...") + self.executor.shutdown(wait=False, cancel_futures=True) + + # Give threads a moment to respond to cancellation, then force status change + threading.Timer(2.0, self._force_stop_completion).start() + + print("Immediate termination requested - ongoing requests will be cancelled") return True print("No active scan to stop") return False @@ -492,6 +782,13 @@ class Scanner: traceback.print_exc() return False + def _force_stop_completion(self): + """Force completion of stop operation after timeout.""" + if self.status == ScanStatus.RUNNING: + print("Forcing scan termination after timeout") + self.status = ScanStatus.STOPPED + self.logger.log_scan_complete() + def get_scan_status(self) -> Dict[str, Any]: """ Get current scan status and progress. @@ -561,16 +858,6 @@ class Scanner: } return export_data - def remove_provider(self, provider_name: str) -> bool: - """ - Remove a provider from the scanner. 
- """ - for i, provider in enumerate(self.providers): - if provider.get_name() == provider_name: - self.providers.pop(i) - return True - return False - def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]: """ Get statistics for all providers. @@ -578,53 +865,4 @@ class Scanner: stats = {} for provider in self.providers: stats[provider.get_name()] = provider.get_statistics() - return stats - - def _is_valid_domain(self, domain: str) -> bool: - """ - Basic domain validation. - """ - if not domain or len(domain) > 253: - return False - parts = domain.split('.') - if len(parts) < 2: - return False - for part in parts: - if not part or len(part) > 63: - return False - if not part.replace('-', '').replace('_', '').isalnum(): - return False - return True - - def _is_valid_ip(self, ip: str) -> bool: - """ - Basic IP address validation. - """ - try: - parts = ip.split('.') - if len(parts) != 4: - return False - for part in parts: - num = int(part) - if not 0 <= num <= 255: - return False - return True - except (ValueError, AttributeError): - return False - - -class ScannerProxy: - def __init__(self): - self._scanner = None - print("ScannerProxy initialized") - - def __getattr__(self, name): - if self._scanner is None: - print("Creating new Scanner instance...") - self._scanner = Scanner() - print("Scanner instance created") - return getattr(self._scanner, name) - - -# Global scanner instance -scanner = ScannerProxy() \ No newline at end of file + return stats \ No newline at end of file diff --git a/core/session_config.py b/core/session_config.py new file mode 100644 index 0000000..dbf698e --- /dev/null +++ b/core/session_config.py @@ -0,0 +1,126 @@ +""" +Per-session configuration management for DNSRecon. +Provides isolated configuration instances for each user session. +""" + +import os +from typing import Dict, Optional + + +class SessionConfig: + """ + Session-specific configuration that inherits from global config + but maintains isolated API keys and provider settings. + """ + + def __init__(self): + """Initialize session config with global defaults.""" + # Copy all attributes from global config + self.api_keys: Dict[str, Optional[str]] = { + 'shodan': None, + 'virustotal': None + } + + # Default settings (copied from global config) + self.default_recursion_depth = 2 + self.default_timeout = 30 + self.max_concurrent_requests = 5 + self.large_entity_threshold = 100 + + # Rate limiting settings (per session) + self.rate_limits = { + 'crtsh': 60, + 'virustotal': 4, + 'shodan': 60, + 'dns': 100 + } + + # Provider settings (per session) + self.enabled_providers = { + 'crtsh': True, + 'dns': True, + 'virustotal': False, + 'shodan': False + } + + # Logging configuration + self.log_level = 'INFO' + self.log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + + # Flask configuration (shared) + self.flask_host = '127.0.0.1' + self.flask_port = 5000 + self.flask_debug = True + + def set_api_key(self, provider: str, api_key: str) -> bool: + """ + Set API key for a provider in this session. + + Args: + provider: Provider name (shodan, virustotal) + api_key: API key string + + Returns: + bool: True if key was set successfully + """ + if provider in self.api_keys: + self.api_keys[provider] = api_key + self.enabled_providers[provider] = True if api_key else False + return True + return False + + def get_api_key(self, provider: str) -> Optional[str]: + """ + Get API key for a provider in this session. 
+ + Args: + provider: Provider name + + Returns: + API key or None if not set + """ + return self.api_keys.get(provider) + + def is_provider_enabled(self, provider: str) -> bool: + """ + Check if a provider is enabled in this session. + + Args: + provider: Provider name + + Returns: + bool: True if provider is enabled + """ + return self.enabled_providers.get(provider, False) + + def get_rate_limit(self, provider: str) -> int: + """ + Get rate limit for a provider in this session. + + Args: + provider: Provider name + + Returns: + Rate limit in requests per minute + """ + return self.rate_limits.get(provider, 60) + + def load_from_env(self): + """Load configuration from environment variables (only if not already set).""" + if os.getenv('VIRUSTOTAL_API_KEY') and not self.api_keys['virustotal']: + self.set_api_key('virustotal', os.getenv('VIRUSTOTAL_API_KEY')) + + if os.getenv('SHODAN_API_KEY') and not self.api_keys['shodan']: + self.set_api_key('shodan', os.getenv('SHODAN_API_KEY')) + + # Override default settings from environment + self.default_recursion_depth = int(os.getenv('DEFAULT_RECURSION_DEPTH', '2')) + self.default_timeout = 30 + self.max_concurrent_requests = 5 + + +def create_session_config() -> SessionConfig: + """Create a new session configuration instance.""" + session_config = SessionConfig() + session_config.load_from_env() + return session_config \ No newline at end of file diff --git a/core/session_manager.py b/core/session_manager.py new file mode 100644 index 0000000..88dd4fa --- /dev/null +++ b/core/session_manager.py @@ -0,0 +1,281 @@ +""" +Session manager for DNSRecon multi-user support. +Manages individual scanner instances per user session with automatic cleanup. +""" + +import threading +import time +import uuid +from typing import Dict, Optional, Any +from datetime import datetime, timezone + +from core.scanner import Scanner + + +class SessionManager: + """ + Manages multiple scanner instances for concurrent user sessions. + Provides session isolation and automatic cleanup of inactive sessions. + """ + + def __init__(self, session_timeout_minutes: int = 60): + """ + Initialize session manager. + + Args: + session_timeout_minutes: Minutes of inactivity before session cleanup + """ + self.sessions: Dict[str, Dict[str, Any]] = {} + self.session_timeout = session_timeout_minutes * 60 # Convert to seconds + self.lock = threading.Lock() + + # Start cleanup thread + self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) + self.cleanup_thread.start() + + print(f"SessionManager initialized with {session_timeout_minutes}min timeout") + + def create_session(self) -> str: + """ + Create a new user session with dedicated scanner instance and configuration. + Enhanced with better debugging and race condition protection. 
+ + Returns: + Unique session ID + """ + session_id = str(uuid.uuid4()) + + print(f"=== CREATING SESSION {session_id} ===") + + try: + # Create session-specific configuration + from core.session_config import create_session_config + session_config = create_session_config() + + print(f"Created session config for {session_id}") + + # Create scanner with session config + from core.scanner import Scanner + scanner_instance = Scanner(session_config=session_config) + + print(f"Created scanner instance {id(scanner_instance)} for session {session_id}") + print(f"Initial scanner status: {scanner_instance.status}") + + with self.lock: + self.sessions[session_id] = { + 'scanner': scanner_instance, + 'config': session_config, + 'created_at': time.time(), + 'last_activity': time.time(), + 'user_agent': '', + 'status': 'active' + } + + print(f"Session {session_id} stored in session manager") + print(f"Total active sessions: {len([s for s in self.sessions.values() if s['status'] == 'active'])}") + print(f"=== SESSION {session_id} CREATED SUCCESSFULLY ===") + + return session_id + + except Exception as e: + print(f"ERROR: Failed to create session {session_id}: {e}") + raise + + def get_session(self, session_id: str) -> Optional[object]: + """ + Get scanner instance for a session with enhanced debugging. + + Args: + session_id: Session identifier + + Returns: + Scanner instance or None if session doesn't exist + """ + if not session_id: + print("get_session called with empty session_id") + return None + + with self.lock: + if session_id not in self.sessions: + print(f"Session {session_id} not found in session manager") + print(f"Available sessions: {list(self.sessions.keys())}") + return None + + session_data = self.sessions[session_id] + + # Check if session is still active + if session_data['status'] != 'active': + print(f"Session {session_id} is not active (status: {session_data['status']})") + return None + + # Update last activity + session_data['last_activity'] = time.time() + scanner = session_data['scanner'] + + print(f"Retrieved scanner {id(scanner)} for session {session_id}") + print(f"Scanner status: {scanner.status}") + + return scanner + + def get_or_create_session(self, session_id: Optional[str] = None) -> tuple[str, Scanner]: + """ + Get existing session or create new one. + + Args: + session_id: Optional existing session ID + + Returns: + Tuple of (session_id, scanner_instance) + """ + if session_id and self.get_session(session_id): + return session_id, self.get_session(session_id) + else: + new_session_id = self.create_session() + return new_session_id, self.get_session(new_session_id) + + def terminate_session(self, session_id: str) -> bool: + """ + Terminate a specific session and cleanup resources. 
+ + Args: + session_id: Session to terminate + + Returns: + True if session was terminated successfully + """ + with self.lock: + if session_id not in self.sessions: + return False + + session_data = self.sessions[session_id] + scanner = session_data['scanner'] + + # Stop any running scan + try: + if scanner.status == 'running': + scanner.stop_scan() + print(f"Stopped scan for session: {session_id}") + except Exception as e: + print(f"Error stopping scan for session {session_id}: {e}") + + # Mark as terminated + session_data['status'] = 'terminated' + session_data['terminated_at'] = time.time() + + # Remove from active sessions after a brief delay to allow cleanup + threading.Timer(5.0, lambda: self._remove_session(session_id)).start() + + print(f"Terminated session: {session_id}") + return True + + def _remove_session(self, session_id: str) -> None: + """Remove session from memory.""" + with self.lock: + if session_id in self.sessions: + del self.sessions[session_id] + print(f"Removed session from memory: {session_id}") + + def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]: + """ + Get session information without updating activity. + + Args: + session_id: Session identifier + + Returns: + Session information dictionary or None + """ + with self.lock: + if session_id not in self.sessions: + return None + + session_data = self.sessions[session_id] + scanner = session_data['scanner'] + + return { + 'session_id': session_id, + 'created_at': datetime.fromtimestamp(session_data['created_at'], timezone.utc).isoformat(), + 'last_activity': datetime.fromtimestamp(session_data['last_activity'], timezone.utc).isoformat(), + 'status': session_data['status'], + 'scan_status': scanner.status, + 'current_target': scanner.current_target, + 'uptime_seconds': time.time() - session_data['created_at'] + } + + def list_active_sessions(self) -> Dict[str, Dict[str, Any]]: + """ + List all active sessions with enhanced debugging info. 
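The five-second threading.Timer above gives in-flight requests a grace period: a terminated session first turns invisible to get_session (its status flips to 'terminated'), then disappears from memory. A lifecycle sketch under that assumption:

    import time
    from core.session_manager import session_manager

    sid = session_manager.create_session()
    assert session_manager.get_session(sid) is not None

    session_manager.terminate_session(sid)
    assert session_manager.get_session(sid) is None       # no longer 'active'
    time.sleep(6)                                         # _remove_session fires after ~5s
    assert session_manager.get_session_info(sid) is None  # gone from memory entirely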
+ + Returns: + Dictionary of session information + """ + active_sessions = {} + + with self.lock: + for session_id, session_data in self.sessions.items(): + if session_data['status'] == 'active': + scanner = session_data['scanner'] + active_sessions[session_id] = { + 'session_id': session_id, + 'created_at': datetime.fromtimestamp(session_data['created_at'], timezone.utc).isoformat(), + 'last_activity': datetime.fromtimestamp(session_data['last_activity'], timezone.utc).isoformat(), + 'status': session_data['status'], + 'scan_status': scanner.status, + 'current_target': scanner.current_target, + 'uptime_seconds': time.time() - session_data['created_at'], + 'scanner_object_id': id(scanner) + } + + return active_sessions + + def _cleanup_loop(self) -> None: + """Background thread to cleanup inactive sessions.""" + while True: + try: + current_time = time.time() + sessions_to_cleanup = [] + + with self.lock: + for session_id, session_data in self.sessions.items(): + if session_data['status'] != 'active': + continue + + inactive_time = current_time - session_data['last_activity'] + + if inactive_time > self.session_timeout: + sessions_to_cleanup.append(session_id) + + # Cleanup outside of lock to avoid deadlock + for session_id in sessions_to_cleanup: + print(f"Cleaning up inactive session: {session_id}") + self.terminate_session(session_id) + + # Sleep for 5 minutes between cleanup cycles + time.sleep(300) + + except Exception as e: + print(f"Error in session cleanup loop: {e}") + time.sleep(60) # Sleep for 1 minute on error + + def get_statistics(self) -> Dict[str, Any]: + """ + Get session manager statistics. + + Returns: + Statistics dictionary + """ + with self.lock: + active_count = sum(1 for s in self.sessions.values() if s['status'] == 'active') + running_scans = sum(1 for s in self.sessions.values() + if s['status'] == 'active' and s['scanner'].status == 'running') + + return { + 'total_sessions': len(self.sessions), + 'active_sessions': active_count, + 'running_scans': running_scans, + 'session_timeout_minutes': self.session_timeout / 60 + } + + +# Global session manager instance +session_manager = SessionManager(session_timeout_minutes=60) \ No newline at end of file diff --git a/providers/base_provider.py b/providers/base_provider.py index e291029..f40d9ea 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -7,10 +7,9 @@ import os import json from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional, Tuple -from datetime import datetime from core.logger import get_forensic_logger -from core.graph_manager import NodeType, RelationshipType +from core.graph_manager import RelationshipType class RateLimiter: @@ -42,36 +41,52 @@ class RateLimiter: class BaseProvider(ABC): """ Abstract base class for all DNSRecon data providers. - Provides common functionality and defines the provider interface. + Now supports session-specific configuration. """ - def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30): + def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None): """ - Initialize base provider. + Initialize base provider with session-specific configuration. 
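The reaping decision in _cleanup_loop boils down to a single comparison; the same check in isolation:

    import time

    def is_expired(last_activity: float, timeout_seconds: float = 60 * 60) -> bool:
        # Mirrors _cleanup_loop: reap once idle time exceeds session_timeout
        # (60 minutes by default, re-checked every 5 minutes).
        return (time.time() - last_activity) > timeout_seconds

    print(is_expired(time.time() - 7200))  # True  -> idle two hours, reap
    print(is_expired(time.time() - 60))    # False -> idle one minute, keep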
Args: name: Provider name for logging - rate_limit: Requests per minute limit + rate_limit: Requests per minute limit (default override) timeout: Request timeout in seconds + session_config: Session-specific configuration """ + # Use session config if provided, otherwise fall back to global config + if session_config is not None: + self.config = session_config + actual_rate_limit = self.config.get_rate_limit(name) + actual_timeout = self.config.default_timeout + else: + # Fallback to global config for backwards compatibility + from config import config as global_config + self.config = global_config + actual_rate_limit = rate_limit + actual_timeout = timeout + self.name = name - self.rate_limiter = RateLimiter(rate_limit) - self.timeout = timeout + self.rate_limiter = RateLimiter(actual_rate_limit) + self.timeout = actual_timeout self._local = threading.local() self.logger = get_forensic_logger() + self._stop_event = None - # Caching configuration - self.cache_dir = '.cache' + # Caching configuration (per session) + self.cache_dir = f'.cache/{id(self.config)}' # Unique cache per session config self.cache_expiry = 12 * 3600 # 12 hours in seconds if not os.path.exists(self.cache_dir): os.makedirs(self.cache_dir) - # Statistics + # Statistics (per provider instance) self.total_requests = 0 self.successful_requests = 0 self.failed_requests = 0 self.total_relationships_found = 0 + print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)") + @property def session(self): if not hasattr(self._local, 'session'): @@ -118,136 +133,174 @@ class BaseProvider(ABC): pass def make_request(self, url: str, method: str = "GET", - params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - target_indicator: str = "", - max_retries: int = 3) -> Optional[requests.Response]: - """ - Make a rate-limited HTTP request with forensic logging and retry logic. + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + target_indicator: str = "", + max_retries: int = 3) -> Optional[requests.Response]: + """ + Make a rate-limited HTTP request with forensic logging and retry logic. + Now supports cancellation via stop_event from scanner. 
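One caveat worth flagging on the caching scheme above: both the per-session directory (keyed on id(self.config)) and the request key rely on values that are not stable across processes. Python's built-in hash() is randomized per interpreter run for strings, so the disk cache effectively resets on every restart. A sketch of the derivation as written, with a restart-safe variant for comparison:

    import hashlib
    import json

    def cache_key(provider_name, method, url, params):
        # Derivation used by make_request below: one JSON file per unique request.
        raw = f"{method}:{url}:{json.dumps(params, sort_keys=True)}"
        return f"{provider_name}_{hash(raw)}.json"

    def stable_cache_key(provider_name, method, url, params):
        # hashlib digests do not change between runs, unlike built-in hash().
        raw = f"{method}:{url}:{json.dumps(params, sort_keys=True)}"
        return f"{provider_name}_{hashlib.sha256(raw.encode()).hexdigest()[:16]}.json"

    print(cache_key('crtsh', 'GET', 'https://crt.sh/', {'q': 'example.com', 'output': 'json'}))
    print(stable_cache_key('crtsh', 'GET', 'https://crt.sh/', {'q': 'example.com', 'output': 'json'}))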
+ """ + # Check for cancellation before starting + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + print(f"Request cancelled before start: {url}") + return None - Args: - url: Request URL - method: HTTP method - params: Query parameters - headers: Additional headers - target_indicator: The indicator being investigated - max_retries: Maximum number of retry attempts + # Create a unique cache key + cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json" + cache_path = os.path.join(self.cache_dir, cache_key) - Returns: - Response object or None if request failed - """ - # Create a unique cache key - cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json" - cache_path = os.path.join(self.cache_dir, cache_key) - - # Check cache - if os.path.exists(cache_path): - cache_age = time.time() - os.path.getmtime(cache_path) - if cache_age < self.cache_expiry: - print(f"Returning cached response for: {url}") - with open(cache_path, 'r') as f: - cached_data = json.load(f) - response = requests.Response() - response.status_code = cached_data['status_code'] - response._content = cached_data['content'].encode('utf-8') - response.headers = cached_data['headers'] - return response - - for attempt in range(max_retries + 1): - # Apply rate limiting - self.rate_limiter.wait_if_needed() - - start_time = time.time() - response = None - error = None - - try: - self.total_requests += 1 - - # Prepare request - request_headers = self.session.headers.copy() - if headers: - request_headers.update(headers) - - print(f"Making {method} request to: {url} (attempt {attempt + 1})") - - # Make request - if method.upper() == "GET": - response = self.session.get( - url, - params=params, - headers=request_headers, - timeout=self.timeout - ) - elif method.upper() == "POST": - response = self.session.post( - url, - json=params, - headers=request_headers, - timeout=self.timeout - ) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - print(f"Response status: {response.status_code}") - response.raise_for_status() - self.successful_requests += 1 - - # Success - log, cache, and return - duration_ms = (time.time() - start_time) * 1000 - self.logger.log_api_request( - provider=self.name, - url=url, - method=method.upper(), - status_code=response.status_code, - response_size=len(response.content), - duration_ms=duration_ms, - error=None, - target_indicator=target_indicator - ) - # Cache the successful response to disk - with open(cache_path, 'w') as f: - json.dump({ - 'status_code': response.status_code, - 'content': response.text, - 'headers': dict(response.headers) - }, f) + # Check cache + if os.path.exists(cache_path): + cache_age = time.time() - os.path.getmtime(cache_path) + if cache_age < self.cache_expiry: + print(f"Returning cached response for: {url}") + with open(cache_path, 'r') as f: + cached_data = json.load(f) + response = requests.Response() + response.status_code = cached_data['status_code'] + response._content = cached_data['content'].encode('utf-8') + response.headers = cached_data['headers'] return response - except requests.exceptions.RequestException as e: - error = str(e) - self.failed_requests += 1 - print(f"Request failed (attempt {attempt + 1}): {error}") - - # Check if we should retry - if attempt < max_retries and self._should_retry(e): - backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s - print(f"Retrying in {backoff_time} seconds...") - time.sleep(backoff_time) 
- continue - else: - break + for attempt in range(max_retries + 1): + # Check for cancellation before each attempt + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + print(f"Request cancelled during attempt {attempt + 1}: {url}") + return None - except Exception as e: - error = f"Unexpected error: {str(e)}" - self.failed_requests += 1 - print(f"Unexpected error: {error}") + # Apply rate limiting (but reduce wait time if cancellation is requested) + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + break + + self.rate_limiter.wait_if_needed() + + # Check again after rate limiting + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + print(f"Request cancelled after rate limiting: {url}") + return None + + start_time = time.time() + response = None + error = None + + try: + self.total_requests += 1 + + # Prepare request + request_headers = self.session.headers.copy() + if headers: + request_headers.update(headers) + + print(f"Making {method} request to: {url} (attempt {attempt + 1})") + + # Use shorter timeout if termination is requested + request_timeout = self.timeout + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + request_timeout = min(5, self.timeout) # Max 5 seconds if termination requested + + # Make request + if method.upper() == "GET": + response = self.session.get( + url, + params=params, + headers=request_headers, + timeout=request_timeout + ) + elif method.upper() == "POST": + response = self.session.post( + url, + json=params, + headers=request_headers, + timeout=request_timeout + ) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + print(f"Response status: {response.status_code}") + response.raise_for_status() + self.successful_requests += 1 + + # Success - log, cache, and return + duration_ms = (time.time() - start_time) * 1000 + self.logger.log_api_request( + provider=self.name, + url=url, + method=method.upper(), + status_code=response.status_code, + response_size=len(response.content), + duration_ms=duration_ms, + error=None, + target_indicator=target_indicator + ) + # Cache the successful response to disk + with open(cache_path, 'w') as f: + json.dump({ + 'status_code': response.status_code, + 'content': response.text, + 'headers': dict(response.headers) + }, f) + return response + + except requests.exceptions.RequestException as e: + error = str(e) + self.failed_requests += 1 + print(f"Request failed (attempt {attempt + 1}): {error}") + + # Check for cancellation before retrying + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + print(f"Request cancelled, not retrying: {url}") + break + + # Check if we should retry + if attempt < max_retries and self._should_retry(e): + backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s + print(f"Retrying in {backoff_time} seconds...") + + # Shorter backoff if termination is requested + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + backoff_time = min(0.5, backoff_time) + + # Sleep with cancellation checking + sleep_start = time.time() + while time.time() - sleep_start < backoff_time: + if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set(): + print(f"Request cancelled during backoff: {url}") + return None + time.sleep(0.1) # Check every 100ms + continue + else: break - # All attempts failed - log and return None - duration_ms = (time.time() - start_time) * 
1000 - self.logger.log_api_request( - provider=self.name, - url=url, - method=method.upper(), - status_code=response.status_code if response else None, - response_size=len(response.content) if response else None, - duration_ms=duration_ms, - error=error, - target_indicator=target_indicator - ) - - return None + except Exception as e: + error = f"Unexpected error: {str(e)}" + self.failed_requests += 1 + print(f"Unexpected error: {error}") + break + + # All attempts failed - log and return None + duration_ms = (time.time() - start_time) * 1000 + self.logger.log_api_request( + provider=self.name, + url=url, + method=method.upper(), + status_code=response.status_code if response else None, + response_size=len(response.content) if response else None, + duration_ms=duration_ms, + error=error, + target_indicator=target_indicator + ) + + return None + + def set_stop_event(self, stop_event: threading.Event) -> None: + """ + Set the stop event for this provider to enable cancellation. + + Args: + stop_event: Threading event to signal cancellation + """ + self._stop_event = stop_event def _should_retry(self, exception: requests.exceptions.RequestException) -> bool: """ @@ -314,90 +367,4 @@ class BaseProvider(ABC): 'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0, 'relationships_found': self.total_relationships_found, 'rate_limit': self.rate_limiter.requests_per_minute - } - - def reset_statistics(self) -> None: - """Reset provider statistics.""" - self.total_requests = 0 - self.successful_requests = 0 - self.failed_requests = 0 - self.total_relationships_found = 0 - - def _extract_domain_from_url(self, url: str) -> Optional[str]: - """ - Extract domain from URL. - - Args: - url: URL string - - Returns: - Domain name or None if extraction fails - """ - try: - # Remove protocol - if '://' in url: - url = url.split('://', 1)[1] - - # Remove path - if '/' in url: - url = url.split('/', 1)[0] - - # Remove port - if ':' in url: - url = url.split(':', 1)[0] - - return url.lower() - - except Exception: - return None - - def _is_valid_domain(self, domain: str) -> bool: - """ - Basic domain validation. - - Args: - domain: Domain string to validate - - Returns: - True if domain appears valid - """ - if not domain or len(domain) > 253: - return False - - # Check for valid characters and structure - parts = domain.split('.') - if len(parts) < 2: - return False - - for part in parts: - if not part or len(part) > 63: - return False - if not part.replace('-', '').replace('_', '').isalnum(): - return False - - return True - - def _is_valid_ip(self, ip: str) -> bool: - """ - Basic IP address validation. - - Args: - ip: IP address string to validate - - Returns: - True if IP appears valid - """ - try: - parts = ip.split('.') - if len(parts) != 4: - return False - - for part in parts: - num = int(part) - if not 0 <= num <= 255: - return False - - return True - - except (ValueError, AttributeError): - return False \ No newline at end of file + } \ No newline at end of file diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index f7224ad..fc977f2 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -1,6 +1,7 @@ """ Certificate Transparency provider using crt.sh. -Discovers domain relationships through certificate SAN analysis. +Discovers domain relationships through certificate SAN analysis with comprehensive certificate tracking. +Stores certificates as metadata on domain nodes rather than creating certificate nodes. 
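Stepping back to make_request: the cancellation checks are thorough, but start_time is bound only inside the loop, so the final logging line can hit an unbound variable if an early break fires before the first attempt runs. A condensed sketch of the same loop shape with that hole closed (do_request and should_stop are stand-ins, not project APIs):

    import time

    def request_with_retries(do_request, should_stop, max_retries=3):
        start_time = time.time()  # bound before the loop: safe for final logging
        for attempt in range(max_retries + 1):
            if should_stop():
                return None
            try:
                return do_request()
            except Exception:  # stands in for requests.RequestException
                backoff = 2 ** attempt  # 1s, 2s, 4s, matching the code above
                deadline = time.time() + backoff
                while time.time() < deadline:  # sliced sleep keeps cancellation snappy
                    if should_stop():
                        return None
                    time.sleep(0.1)
        return None  # duration_ms = (time.time() - start_time) * 1000 is safe here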
""" import json @@ -10,23 +11,26 @@ from urllib.parse import quote from datetime import datetime, timezone from .base_provider import BaseProvider +from utils.helpers import _is_valid_domain from core.graph_manager import RelationshipType class CrtShProvider(BaseProvider): """ Provider for querying crt.sh certificate transparency database. - Discovers domain relationships through certificate Subject Alternative Names (SANs). + Now uses session-specific configuration and caching. """ - def __init__(self): - """Initialize CrtSh provider with appropriate rate limiting.""" + def __init__(self, session_config=None): + """Initialize CrtSh provider with session-specific configuration.""" super().__init__( name="crtsh", - rate_limit=60, # Be respectful to the free service - timeout=30 + rate_limit=60, + timeout=15, + session_config=session_config ) self.base_url = "https://crt.sh/" + self._stop_event = None def get_name(self) -> str: """Return the provider name.""" @@ -40,31 +44,128 @@ class CrtShProvider(BaseProvider): """ return True + def _parse_certificate_date(self, date_string: str) -> datetime: + """ + Parse certificate date from crt.sh format. + + Args: + date_string: Date string from crt.sh API + + Returns: + Parsed datetime object in UTC + """ + if not date_string: + raise ValueError("Empty date string") + + try: + # Handle various possible formats from crt.sh + if date_string.endswith('Z'): + return datetime.fromisoformat(date_string[:-1]).replace(tzinfo=timezone.utc) + elif '+' in date_string or date_string.endswith('UTC'): + # Handle timezone-aware strings + date_string = date_string.replace('UTC', '').strip() + if '+' in date_string: + date_string = date_string.split('+')[0] + return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc) + else: + # Assume UTC if no timezone specified + return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc) + except Exception as e: + # Fallback: try parsing without timezone info and assume UTC + try: + return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc) + except Exception: + raise ValueError(f"Unable to parse date: {date_string}") from e + def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool: - """Check if a certificate is currently valid.""" + """ + Check if a certificate is currently valid based on its expiry date. 
+ + Args: + cert_data: Certificate data from crt.sh + + Returns: + True if certificate is currently valid (not expired) + """ try: not_after_str = cert_data.get('not_after') - if not_after_str: - # Append 'Z' to indicate UTC if it's not present - if not not_after_str.endswith('Z'): - not_after_str += 'Z' - not_after_date = datetime.fromisoformat(not_after_str.replace('Z', '+00:00')) - return not_after_date > datetime.now(timezone.utc) - except Exception: + if not not_after_str: + return False + + not_after_date = self._parse_certificate_date(not_after_str) + not_before_str = cert_data.get('not_before') + + now = datetime.now(timezone.utc) + + # Check if certificate is within valid date range + is_not_expired = not_after_date > now + + if not_before_str: + not_before_date = self._parse_certificate_date(not_before_str) + is_not_before_valid = not_before_date <= now + return is_not_expired and is_not_before_valid + + return is_not_expired + + except Exception as e: + self.logger.logger.debug(f"Certificate validity check failed: {e}") return False - return False + + def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Extract comprehensive metadata from certificate data. + + Args: + cert_data: Raw certificate data from crt.sh + + Returns: + Comprehensive certificate metadata dictionary + """ + metadata = { + 'certificate_id': cert_data.get('id'), + 'serial_number': cert_data.get('serial_number'), + 'issuer_name': cert_data.get('issuer_name'), + 'issuer_ca_id': cert_data.get('issuer_ca_id'), + 'common_name': cert_data.get('common_name'), + 'not_before': cert_data.get('not_before'), + 'not_after': cert_data.get('not_after'), + 'entry_timestamp': cert_data.get('entry_timestamp'), + 'source': 'crt.sh' + } + + # Add computed fields + try: + if metadata['not_before'] and metadata['not_after']: + not_before = self._parse_certificate_date(metadata['not_before']) + not_after = self._parse_certificate_date(metadata['not_after']) + + metadata['validity_period_days'] = (not_after - not_before).days + metadata['is_currently_valid'] = self._is_cert_valid(cert_data) + metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30 + + # Add human-readable dates + metadata['not_before_formatted'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC') + metadata['not_after_formatted'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC') + + except Exception as e: + self.logger.logger.debug(f"Error computing certificate metadata: {e}") + metadata['is_currently_valid'] = False + metadata['expires_soon'] = False + + return metadata def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ Query crt.sh for certificates containing the domain. - - Args: - domain: Domain to investigate - - Returns: - List of relationships discovered from certificate analysis + Creates domain-to-domain relationships and stores certificate data as metadata. + Now supports early termination via stop_event. 
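For reference, _extract_certificate_metadata turns one crt.sh record into a flat dict; the field values here are invented for illustration:

    record = {  # shape follows the crt.sh JSON fields read above
        'id': 1234567890,
        'serial_number': '04ab...',
        'issuer_name': "C=US, O=Let's Encrypt, CN=R11",
        'issuer_ca_id': 295814,
        'common_name': 'example.com',
        'not_before': '2025-01-01T00:00:00',
        'not_after': '2025-04-01T00:00:00',
        'entry_timestamp': '2025-01-01T01:00:00',
    }
    meta = provider._extract_certificate_metadata(record)  # provider: CrtShProvider
    print(meta['validity_period_days'])  # 90
    print(meta['expires_soon'])          # True only once <=30 days remain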
""" - if not self._is_valid_domain(domain): + if not _is_valid_domain(domain): + return [] + + # Check for cancellation before starting + if self._stop_event and self._stop_event.is_set(): + print(f"CrtSh query cancelled before start for domain: {domain}") return [] relationships = [] @@ -72,56 +173,113 @@ class CrtShProvider(BaseProvider): try: # Query crt.sh for certificates url = f"{self.base_url}?q={quote(domain)}&output=json" - response = self.make_request(url, target_indicator=domain) + response = self.make_request(url, target_indicator=domain, max_retries=1) # Reduce retries for faster cancellation if not response or response.status_code != 200: return [] + # Check for cancellation after request + if self._stop_event and self._stop_event.is_set(): + print(f"CrtSh query cancelled after request for domain: {domain}") + return [] + certificates = response.json() if not certificates: return [] - # Process certificates to extract relationships - discovered_subdomains = {} + # Check for cancellation before processing + if self._stop_event and self._stop_event.is_set(): + print(f"CrtSh query cancelled before processing for domain: {domain}") + return [] - for cert_data in certificates: + # Aggregate certificate data by domain + domain_certificates = {} + all_discovered_domains = set() + + # Process certificates and group by domain (with cancellation checks) + for i, cert_data in enumerate(certificates): + # Check for cancellation every 10 certificates + if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): + print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}") + break + + cert_metadata = self._extract_certificate_metadata(cert_data) cert_domains = self._extract_domains_from_certificate(cert_data) - is_valid = self._is_cert_valid(cert_data) + + # Add all domains from this certificate to our tracking + for cert_domain in cert_domains: + if not _is_valid_domain(cert_domain): + continue + + all_discovered_domains.add(cert_domain) + + # Initialize domain certificate list if needed + if cert_domain not in domain_certificates: + domain_certificates[cert_domain] = [] + + # Add this certificate to the domain's certificate list + domain_certificates[cert_domain].append(cert_metadata) + + # Final cancellation check before creating relationships + if self._stop_event and self._stop_event.is_set(): + print(f"CrtSh query cancelled before relationship creation for domain: {domain}") + return [] - for subdomain in cert_domains: - if self._is_valid_domain(subdomain) and subdomain != domain: - if subdomain not in discovered_subdomains: - discovered_subdomains[subdomain] = {'has_valid_cert': False, 'issuers': set()} - - if is_valid: - discovered_subdomains[subdomain]['has_valid_cert'] = True - - issuer = cert_data.get('issuer_name') - if issuer: - discovered_subdomains[subdomain]['issuers'].add(issuer) + # Create relationships from query domain to ALL discovered domains + for discovered_domain in all_discovered_domains: + if discovered_domain == domain: + continue # Skip self-relationships + + # Check for cancellation during relationship creation + if self._stop_event and self._stop_event.is_set(): + print(f"CrtSh relationship creation cancelled for domain: {domain}") + break - # Create relationships from the discovered subdomains - for subdomain, data in discovered_subdomains.items(): - raw_data = { - 'has_valid_cert': data['has_valid_cert'], - 'issuers': list(data['issuers']), - 'source': 'crt.sh' + if not _is_valid_domain(discovered_domain): + continue + + # Get 
certificates for both domains + query_domain_certs = domain_certificates.get(domain, []) + discovered_domain_certs = domain_certificates.get(discovered_domain, []) + + # Find shared certificates (for metadata purposes) + shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs) + + # Calculate confidence based on relationship type and shared certificates + confidence = self._calculate_domain_relationship_confidence( + domain, discovered_domain, shared_certificates, all_discovered_domains + ) + + # Create comprehensive raw data for the relationship + relationship_raw_data = { + 'relationship_type': 'certificate_discovery', + 'shared_certificates': shared_certificates, + 'total_shared_certs': len(shared_certificates), + 'discovery_context': self._determine_relationship_context(discovered_domain, domain), + 'domain_certificates': { + domain: self._summarize_certificates(query_domain_certs), + discovered_domain: self._summarize_certificates(discovered_domain_certs) + } } + + # Create domain -> domain relationship relationships.append(( domain, - subdomain, + discovered_domain, RelationshipType.SAN_CERTIFICATE, - RelationshipType.SAN_CERTIFICATE.default_confidence, - raw_data + confidence, + relationship_raw_data )) + + # Log the relationship discovery self.log_relationship_discovery( source_node=domain, - target_node=subdomain, + target_node=discovered_domain, relationship_type=RelationshipType.SAN_CERTIFICATE, - confidence_score=RelationshipType.SAN_CERTIFICATE.default_confidence, - raw_data=raw_data, - discovery_method="certificate_san_analysis" + confidence_score=confidence, + raw_data=relationship_raw_data, + discovery_method="certificate_transparency_analysis" ) except json.JSONDecodeError as e: @@ -130,6 +288,165 @@ class CrtShProvider(BaseProvider): self.logger.logger.error(f"Error querying crt.sh for {domain}: {e}") return relationships + + def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Find certificates that are shared between two domain certificate lists. + + Args: + certs1: First domain's certificates + certs2: Second domain's certificates + + Returns: + List of shared certificate metadata + """ + shared = [] + + # Create a set of certificate IDs from the first list for quick lookup + cert1_ids = {cert.get('certificate_id') for cert in certs1 if cert.get('certificate_id')} + + # Find certificates in the second list that match + for cert in certs2: + if cert.get('certificate_id') in cert1_ids: + shared.append(cert) + + return shared + + def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Create a summary of certificates for a domain. 
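_find_shared_certificates above is a set intersection keyed on certificate_id; the same idea reduced to its core:

    certs_a = [{'certificate_id': 1}, {'certificate_id': 2}]
    certs_b = [{'certificate_id': 2}, {'certificate_id': 3}]

    ids_a = {c.get('certificate_id') for c in certs_a if c.get('certificate_id')}
    shared = [c for c in certs_b if c.get('certificate_id') in ids_a]
    print(shared)  # [{'certificate_id': 2}]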
+ + Args: + certificates: List of certificate metadata + + Returns: + Summary dictionary with aggregate statistics + """ + if not certificates: + return { + 'total_certificates': 0, + 'valid_certificates': 0, + 'expired_certificates': 0, + 'expires_soon_count': 0, + 'unique_issuers': [], + 'latest_certificate': None, + 'has_valid_cert': False + } + + valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid')) + expired_count = len(certificates) - valid_count + expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon')) + + # Get unique issuers + unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name'))) + + # Find the most recent certificate + latest_cert = None + latest_date = None + + for cert in certificates: + try: + if cert.get('not_before'): + cert_date = self._parse_certificate_date(cert['not_before']) + if latest_date is None or cert_date > latest_date: + latest_date = cert_date + latest_cert = cert + except Exception: + continue + + return { + 'total_certificates': len(certificates), + 'valid_certificates': valid_count, + 'expired_certificates': expired_count, + 'expires_soon_count': expires_soon_count, + 'unique_issuers': unique_issuers, + 'latest_certificate': latest_cert, + 'has_valid_cert': valid_count > 0, + 'certificate_details': certificates # Full details for forensic analysis + } + + def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str, + shared_certificates: List[Dict[str, Any]], + all_discovered_domains: Set[str]) -> float: + """ + Calculate confidence score for domain relationship based on various factors. + + Args: + domain1: Source domain (query domain) + domain2: Target domain (discovered domain) + shared_certificates: List of shared certificate metadata + all_discovered_domains: All domains discovered in this query + + Returns: + Confidence score between 0.0 and 1.0 + """ + base_confidence = RelationshipType.SAN_CERTIFICATE.default_confidence + + # Adjust confidence based on domain relationship context + relationship_context = self._determine_relationship_context(domain2, domain1) + + if relationship_context == 'exact_match': + context_bonus = 0.0 # This shouldn't happen, but just in case + elif relationship_context == 'subdomain': + context_bonus = 0.1 # High confidence for subdomains + elif relationship_context == 'parent_domain': + context_bonus = 0.05 # Medium confidence for parent domains + else: + context_bonus = 0.0 # Related domains get base confidence + + # Adjust confidence based on shared certificates + if shared_certificates: + shared_count = len(shared_certificates) + if shared_count >= 3: + shared_bonus = 0.1 + elif shared_count >= 2: + shared_bonus = 0.05 + else: + shared_bonus = 0.02 + + # Additional bonus for valid shared certificates + valid_shared = sum(1 for cert in shared_certificates if cert.get('is_currently_valid')) + if valid_shared > 0: + validity_bonus = 0.05 + else: + validity_bonus = 0.0 + else: + # Even without shared certificates, domains found in the same query have some relationship + shared_bonus = 0.0 + validity_bonus = 0.0 + + # Adjust confidence based on certificate issuer reputation (if shared certificates exist) + issuer_bonus = 0.0 + if shared_certificates: + for cert in shared_certificates: + issuer = cert.get('issuer_name', '').lower() + if any(trusted_ca in issuer for trusted_ca in ['let\'s encrypt', 'digicert', 'sectigo', 'globalsign']): + issuer_bonus = max(issuer_bonus, 0.03) + break + + # Calculate final 
confidence + final_confidence = base_confidence + context_bonus + shared_bonus + validity_bonus + issuer_bonus + return max(0.1, min(1.0, final_confidence)) # Clamp between 0.1 and 1.0 + + def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str: + """ + Determine the context of the relationship between certificate domain and query domain. + + Args: + cert_domain: Domain found in certificate + query_domain: Original query domain + + Returns: + String describing the relationship context + """ + if cert_domain == query_domain: + return 'exact_match' + elif cert_domain.endswith(f'.{query_domain}'): + return 'subdomain' + elif query_domain.endswith(f'.{cert_domain}'): + return 'parent_domain' + else: + return 'related_domain' def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ @@ -143,7 +460,6 @@ class CrtShProvider(BaseProvider): Empty list (crt.sh doesn't support IP-based certificate queries effectively) """ # crt.sh doesn't effectively support IP-based certificate queries - # This would require parsing certificate details for IP SANs, which is complex return [] def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]: @@ -162,7 +478,7 @@ class CrtShProvider(BaseProvider): common_name = cert_data.get('common_name', '') if common_name: cleaned_cn = self._clean_domain_name(common_name) - if cleaned_cn and self._is_valid_domain(cleaned_cn): + if cleaned_cn and _is_valid_domain(cleaned_cn): domains.add(cleaned_cn) # Extract from name_value field (contains SANs) @@ -171,7 +487,7 @@ class CrtShProvider(BaseProvider): # Split by newlines and clean each domain for line in name_value.split('\n'): cleaned_domain = self._clean_domain_name(line.strip()) - if cleaned_domain and self._is_valid_domain(cleaned_domain): + if cleaned_domain and _is_valid_domain(cleaned_domain): domains.add(cleaned_domain) return domains @@ -215,70 +531,4 @@ class CrtShProvider(BaseProvider): if domain and not domain.startswith(('.', '-')) and not domain.endswith(('.', '-')): return domain - return "" - - def get_certificate_details(self, certificate_id: str) -> Dict[str, Any]: - """ - Get detailed information about a specific certificate. - - Args: - certificate_id: Certificate ID from crt.sh - - Returns: - Dictionary containing certificate details - """ - try: - url = f"{self.base_url}?id={certificate_id}&output=json" - response = self.make_request(url, target_indicator=f"cert_{certificate_id}") - - if response and response.status_code == 200: - return response.json() - - except Exception as e: - self.logger.logger.error(f"Error fetching certificate details for {certificate_id}: {e}") - - return {} - - def search_certificates_by_serial(self, serial_number: str) -> List[Dict[str, Any]]: - """ - Search for certificates by serial number. - - Args: - serial_number: Certificate serial number - - Returns: - List of matching certificates - """ - try: - url = f"{self.base_url}?serial={quote(serial_number)}&output=json" - response = self.make_request(url, target_indicator=f"serial_{serial_number}") - - if response and response.status_code == 200: - return response.json() - - except Exception as e: - self.logger.logger.error(f"Error searching certificates by serial {serial_number}: {e}") - - return [] - - def get_issuer_certificates(self, issuer_name: str) -> List[Dict[str, Any]]: - """ - Get certificates issued by a specific CA. 
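Plugging representative numbers into _calculate_domain_relationship_confidence: assuming a SAN_CERTIFICATE default confidence of 0.8 (the real value lives in core.graph_manager), a subdomain sharing two currently valid Let's Encrypt certificates scores:

    base = 0.8             # assumed RelationshipType.SAN_CERTIFICATE.default_confidence
    context_bonus = 0.1    # 'subdomain'
    shared_bonus = 0.05    # two shared certificates
    validity_bonus = 0.05  # at least one shared certificate still valid
    issuer_bonus = 0.03    # issuer matches a trusted-CA substring

    score = max(0.1, min(1.0, base + context_bonus + shared_bonus + validity_bonus + issuer_bonus))
    print(score)  # 1.0 -- clamped down from 1.03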
- - Args: - issuer_name: Certificate Authority name - - Returns: - List of certificates from the specified issuer - """ - try: - url = f"{self.base_url}?issuer={quote(issuer_name)}&output=json" - response = self.make_request(url, target_indicator=f"issuer_{issuer_name}") - - if response and response.status_code == 200: - return response.json() - - except Exception as e: - self.logger.logger.error(f"Error fetching certificates for issuer {issuer_name}: {e}") - - return [] \ No newline at end of file + return "" \ No newline at end of file diff --git a/providers/dns_provider.py b/providers/dns_provider.py index cee2a84..f3806b3 100644 --- a/providers/dns_provider.py +++ b/providers/dns_provider.py @@ -1,25 +1,26 @@ # dnsrecon/providers/dns_provider.py -import socket import dns.resolver import dns.reversename -from typing import List, Dict, Any, Tuple, Optional +from typing import List, Dict, Any, Tuple from .base_provider import BaseProvider -from core.graph_manager import RelationshipType, NodeType +from utils.helpers import _is_valid_ip, _is_valid_domain +from core.graph_manager import RelationshipType class DNSProvider(BaseProvider): """ Provider for standard DNS resolution and reverse DNS lookups. - Discovers domain-to-IP and IP-to-domain relationships through DNS records. + Now uses session-specific configuration. """ - def __init__(self): - """Initialize DNS provider with appropriate rate limiting.""" + def __init__(self, session_config=None): + """Initialize DNS provider with session-specific configuration.""" super().__init__( name="dns", - rate_limit=100, # DNS queries can be faster - timeout=10 + rate_limit=100, + timeout=10, + session_config=session_config ) # Configure DNS resolver @@ -45,7 +46,7 @@ class DNSProvider(BaseProvider): Returns: List of relationships discovered from DNS analysis """ - if not self._is_valid_domain(domain): + if not _is_valid_domain(domain): return [] relationships = [] @@ -66,7 +67,7 @@ class DNSProvider(BaseProvider): Returns: List of relationships discovered from reverse DNS """ - if not self._is_valid_ip(ip): + if not _is_valid_ip(ip): return [] relationships = [] @@ -81,7 +82,7 @@ class DNSProvider(BaseProvider): for ptr_record in response: hostname = str(ptr_record).rstrip('.') - if self._is_valid_domain(hostname): + if _is_valid_domain(hostname): raw_data = { 'query_type': 'PTR', 'ip_address': ip, diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index 3dd2c09..f41e8f8 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -4,38 +4,37 @@ Discovers IP relationships and infrastructure context through Shodan API. """ import json -from typing import List, Dict, Any, Tuple, Optional -from urllib.parse import quote +from typing import List, Dict, Any, Tuple from .base_provider import BaseProvider +from utils.helpers import _is_valid_ip, _is_valid_domain from core.graph_manager import RelationshipType -from config import config class ShodanProvider(BaseProvider): """ Provider for querying Shodan API for IP address and hostname information. - Requires valid API key and respects Shodan's rate limits. + Now uses session-specific API keys. 
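The practical effect of session-scoped keys: two users get independent provider availability. A sketch, assuming set_api_key from core/session_config.py:

    from core.session_config import create_session_config
    from providers.shodan_provider import ShodanProvider

    cfg_a, cfg_b = create_session_config(), create_session_config()
    cfg_a.set_api_key('shodan', 'key-for-user-a')  # user B configures no key

    print(ShodanProvider(session_config=cfg_a).is_available())  # True
    print(ShodanProvider(session_config=cfg_b).is_available())  # False -> skipped for B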
""" - def __init__(self): - """Initialize Shodan provider with appropriate rate limiting.""" + def __init__(self, session_config=None): + """Initialize Shodan provider with session-specific configuration.""" super().__init__( name="shodan", - rate_limit=60, # Shodan API has various rate limits depending on plan - timeout=30 + rate_limit=60, + timeout=30, + session_config=session_config ) self.base_url = "https://api.shodan.io" - self.api_key = config.get_api_key('shodan') + self.api_key = self.config.get_api_key('shodan') + + def is_available(self) -> bool: + """Check if Shodan provider is available (has valid API key in this session).""" + return self.api_key is not None and len(self.api_key.strip()) > 0 def get_name(self) -> str: """Return the provider name.""" return "shodan" - - def is_available(self) -> bool: - """ - Check if Shodan provider is available (has valid API key). - """ - return self.api_key is not None and len(self.api_key.strip()) > 0 + def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ @@ -48,7 +47,7 @@ class ShodanProvider(BaseProvider): Returns: List of relationships discovered from Shodan data """ - if not self._is_valid_domain(domain) or not self.is_available(): + if not _is_valid_domain(domain) or not self.is_available(): return [] relationships = [] @@ -109,7 +108,7 @@ class ShodanProvider(BaseProvider): # Also create relationships to other hostnames on the same IP for hostname in hostnames: - if hostname != domain and self._is_valid_domain(hostname): + if hostname != domain and _is_valid_domain(hostname): hostname_raw_data = { 'shared_ip': ip_address, 'all_hostnames': hostnames, @@ -150,7 +149,7 @@ class ShodanProvider(BaseProvider): Returns: List of relationships discovered from Shodan IP data """ - if not self._is_valid_ip(ip) or not self.is_available(): + if not _is_valid_ip(ip) or not self.is_available(): return [] relationships = [] @@ -170,7 +169,7 @@ class ShodanProvider(BaseProvider): # Extract hostname relationships hostnames = data.get('hostnames', []) for hostname in hostnames: - if self._is_valid_domain(hostname): + if _is_valid_domain(hostname): raw_data = { 'ip_address': ip, 'hostname': hostname, @@ -280,7 +279,7 @@ class ShodanProvider(BaseProvider): Returns: List of service information dictionaries """ - if not self._is_valid_ip(ip) or not self.is_available(): + if not _is_valid_ip(ip) or not self.is_available(): return [] try: diff --git a/providers/virustotal_provider.py b/providers/virustotal_provider.py index 9949810..0d75f0d 100644 --- a/providers/virustotal_provider.py +++ b/providers/virustotal_provider.py @@ -4,38 +4,37 @@ Discovers domain relationships through passive DNS and URL analysis. """ import json -from typing import List, Dict, Any, Tuple, Optional +from typing import List, Dict, Any, Tuple from .base_provider import BaseProvider +from utils.helpers import _is_valid_ip, _is_valid_domain from core.graph_manager import RelationshipType -from config import config class VirusTotalProvider(BaseProvider): """ Provider for querying VirusTotal API for passive DNS and domain reputation data. - Requires valid API key and strictly respects free tier rate limits. + Now uses session-specific API keys and rate limits. 
""" - def __init__(self): - """Initialize VirusTotal provider with strict rate limiting for free tier.""" + def __init__(self, session_config=None): + """Initialize VirusTotal provider with session-specific configuration.""" super().__init__( name="virustotal", rate_limit=4, # Free tier: 4 requests per minute - timeout=30 + timeout=30, + session_config=session_config ) self.base_url = "https://www.virustotal.com/vtapi/v2" - self.api_key = config.get_api_key('virustotal') + self.api_key = self.config.get_api_key('virustotal') + + def is_available(self) -> bool: + """Check if VirusTotal provider is available (has valid API key in this session).""" + return self.api_key is not None and len(self.api_key.strip()) > 0 def get_name(self) -> str: """Return the provider name.""" return "virustotal" - def is_available(self) -> bool: - """ - Check if VirusTotal provider is available (has valid API key). - """ - return self.api_key is not None and len(self.api_key.strip()) > 0 - def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ Query VirusTotal for domain information including passive DNS. @@ -46,7 +45,7 @@ class VirusTotalProvider(BaseProvider): Returns: List of relationships discovered from VirusTotal data """ - if not self._is_valid_domain(domain) or not self.is_available(): + if not _is_valid_domain(domain) or not self.is_available(): return [] relationships = [] @@ -71,7 +70,7 @@ class VirusTotalProvider(BaseProvider): Returns: List of relationships discovered from VirusTotal IP data """ - if not self._is_valid_ip(ip) or not self.is_available(): + if not _is_valid_ip(ip) or not self.is_available(): return [] relationships = [] @@ -114,7 +113,7 @@ class VirusTotalProvider(BaseProvider): ip_address = resolution.get('ip_address') last_resolved = resolution.get('last_resolved') - if ip_address and self._is_valid_ip(ip_address): + if ip_address and _is_valid_ip(ip_address): raw_data = { 'domain': domain, 'ip_address': ip_address, @@ -142,7 +141,7 @@ class VirusTotalProvider(BaseProvider): # Extract subdomains subdomains = data.get('subdomains', []) for subdomain in subdomains: - if subdomain != domain and self._is_valid_domain(subdomain): + if subdomain != domain and _is_valid_domain(subdomain): raw_data = { 'parent_domain': domain, 'subdomain': subdomain, @@ -200,7 +199,7 @@ class VirusTotalProvider(BaseProvider): hostname = resolution.get('hostname') last_resolved = resolution.get('last_resolved') - if hostname and self._is_valid_domain(hostname): + if hostname and _is_valid_domain(hostname): raw_data = { 'ip_address': ip, 'hostname': hostname, @@ -254,7 +253,7 @@ class VirusTotalProvider(BaseProvider): Returns: Dictionary containing reputation data """ - if not self._is_valid_domain(domain) or not self.is_available(): + if not _is_valid_domain(domain) or not self.is_available(): return {} try: @@ -293,7 +292,7 @@ class VirusTotalProvider(BaseProvider): Returns: Dictionary containing reputation data """ - if not self._is_valid_ip(ip) or not self.is_available(): + if not _is_valid_ip(ip) or not self.is_available(): return {} try: diff --git a/static/css/main.css b/static/css/main.css index e89a5d2..c0ba131 100644 --- a/static/css/main.css +++ b/static/css/main.css @@ -318,17 +318,13 @@ input[type="text"]:focus, select:focus { } .graph-container { - height: 500px; + height: 800px; position: relative; background-color: #1a1a1a; border-top: 1px solid #444; transition: height 0.3s ease; } -.graph-container.expanded { - height: 700px; -} - 
.graph-controls { position: absolute; top: 10px; @@ -535,29 +531,6 @@ input[type="text"]:focus, select:focus { box-shadow: 0 4px 6px rgba(0,0,0,0.3); } -.node-info-title { - color: #00ff41; - font-weight: bold; - margin-bottom: 0.5rem; - border-bottom: 1px solid #444; - padding-bottom: 0.25rem; -} - -.node-info-detail { - margin-bottom: 0.25rem; - display: flex; - justify-content: space-between; -} - -.node-info-label { - color: #999; -} - -.node-info-value { - color: #c7c7c7; - font-weight: 500; -} - /* Footer */ .footer { background-color: #0a0a0a; diff --git a/static/js/graph.js b/static/js/graph.js index a438d37..9d6af11 100644 --- a/static/js/graph.js +++ b/static/js/graph.js @@ -233,7 +233,6 @@ class GraphManager { const nodeId = params.node; const node = this.nodes.get(nodeId); if (node) { - this.showNodeInfoPopup(params.pointer.DOM, node); this.highlightConnectedNodes(nodeId, true); } }); @@ -243,19 +242,6 @@ class GraphManager { this.clearHoverHighlights(); }); - // Edge hover events - this.network.on('hoverEdge', (params) => { - const edgeId = params.edge; - const edge = this.edges.get(edgeId); - if (edge) { - this.showEdgeInfo(params.pointer.DOM, edge); - } - }); - - this.network.on('blurEdge', () => { - this.hideNodeInfoPopup(); - }); - // Double-click to focus on node this.network.on('doubleClick', (params) => { if (params.nodes.length > 0) { @@ -347,7 +333,6 @@ class GraphManager { const processedNode = { id: node.id, label: this.formatNodeLabel(node.id, node.type), - title: this.createNodeTooltip(node), color: this.getNodeColor(node.type), size: this.getNodeSize(node.type), borderColor: this.getNodeBorderColor(node.type), @@ -373,11 +358,14 @@ class GraphManager { } // Style based on certificate validity - if (node.has_valid_cert === true) { - processedNode.borderColor = '#00ff41'; // Green for valid cert - } else if (node.has_valid_cert === false) { - processedNode.borderColor = '#ff9900'; // Amber for expired/no cert - processedNode.borderDashes = [5, 5]; + if (node.type === 'domain') { + if (node.metadata && node.metadata.has_valid_cert === true) { + processedNode.color = '#00ff41'; // Bright green for valid cert + processedNode.borderColor = '#00aa2e'; + } else if (node.metadata && node.metadata.has_valid_cert === false) { + processedNode.color = '#888888'; // Muted grey color + processedNode.borderColor = '#666666'; // Darker grey border + } } return processedNode; @@ -457,9 +445,9 @@ class GraphManager { const colors = { 'domain': '#00ff41', // Green 'ip': '#ff9900', // Amber - 'certificate': '#c7c7c7', // Gray 'asn': '#00aaff', // Blue - 'large_entity': '#ff6b6b' // Red for large entities + 'large_entity': '#ff6b6b', // Red for large entities + 'dns_record': '#999999' }; return colors[nodeType] || '#ffffff'; } @@ -474,8 +462,8 @@ class GraphManager { const borderColors = { 'domain': '#00aa2e', 'ip': '#cc7700', - 'certificate': '#999999', - 'asn': '#0088cc' + 'asn': '#0088cc', + 'dns_record': '#999999' }; return borderColors[nodeType] || '#666666'; } @@ -489,8 +477,8 @@ class GraphManager { const sizes = { 'domain': 12, 'ip': 14, - 'certificate': 10, - 'asn': 16 + 'asn': 16, + 'dns_record': 8 }; return sizes[nodeType] || 12; } @@ -504,8 +492,8 @@ class GraphManager { const shapes = { 'domain': 'dot', 'ip': 'square', - 'certificate': 'diamond', - 'asn': 'triangle' + 'asn': 'triangle', + 'dns_record': 'hexagon' }; return shapes[nodeType] || 'dot'; } @@ -541,26 +529,7 @@ class GraphManager { } /** - * Create node tooltip - * @param {Object} node - Node data - * @returns 
{string} HTML tooltip content - */ - createNodeTooltip(node) { - let tooltip = `