This commit is contained in:
overcuriousity 2025-09-11 14:01:15 +02:00
parent 2d485c5703
commit d3e1fcf35f
18 changed files with 1806 additions and 843 deletions

281
app.py
View File

@ -1,20 +1,65 @@
"""
Flask application entry point for DNSRecon web interface.
Provides REST API endpoints and serves the web interface.
Provides REST API endpoints and serves the web interface with user session support.
Enhanced with better session debugging and isolation.
"""
import json
import traceback
from flask import Flask, render_template, request, jsonify, send_file
from datetime import datetime, timezone
from flask import Flask, render_template, request, jsonify, send_file, session
from datetime import datetime, timezone, timedelta
import io
from core.scanner import scanner
from core.session_manager import session_manager
from config import config
app = Flask(__name__)
app.config['SECRET_KEY'] = 'dnsrecon-dev-key-change-in-production'
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2) # 2 hour session lifetime
def get_user_scanner():
    """
    Get or create the scanner instance bound to the current user session.

    Looks up the scanner keyed by the 'dnsrecon_session_id' stored in the
    Flask session; if none exists (or it expired in the session manager),
    a fresh session/scanner pair is created and the new id is persisted
    back into the Flask session.

    Returns:
        Tuple of (session_id, scanner_instance)

    Raises:
        RuntimeError: if a newly created session cannot be retrieved from
            the session manager (e.g. it was evicted immediately).
    """
    # Gather request context for session-routing debug output.
    current_flask_session_id = session.get('dnsrecon_session_id')
    client_ip = request.remote_addr
    user_agent = request.headers.get('User-Agent', '')[:100]  # Truncate for logging

    print("=== SESSION DEBUG ===")
    print(f"Client IP: {client_ip}")
    print(f"User Agent: {user_agent}")
    print(f"Flask Session ID: {current_flask_session_id}")
    print(f"Flask Session Keys: {list(session.keys())}")

    # Reuse the existing scanner when the stored id is still known
    # to the session manager.
    if current_flask_session_id:
        existing_scanner = session_manager.get_session(current_flask_session_id)
        if existing_scanner:
            print(f"Using existing session: {current_flask_session_id}")
            print(f"Scanner status: {existing_scanner.status}")
            return current_flask_session_id, existing_scanner
        print(f"Session {current_flask_session_id} not found in session manager")

    # Create a new backend session and persist its id client-side.
    print("Creating new session...")
    new_session_id = session_manager.create_session()
    new_scanner = session_manager.get_session(new_session_id)
    if new_scanner is None:
        # Guard against a race where the freshly created session was
        # already cleaned up; without this the .status access below
        # would raise a confusing AttributeError on None.
        raise RuntimeError(
            f"Session manager failed to return newly created session {new_session_id}"
        )

    # Store in Flask session so subsequent requests route to this scanner.
    session['dnsrecon_session_id'] = new_session_id
    session.permanent = True

    print(f"Created new session: {new_session_id}")
    print(f"New scanner status: {new_scanner.status}")
    print("=== END SESSION DEBUG ===")

    return new_session_id, new_scanner
@app.route('/')
@ -26,13 +71,8 @@ def index():
@app.route('/api/scan/start', methods=['POST'])
def start_scan():
"""
Start a new reconnaissance scan.
Expects JSON payload:
{
"target_domain": "example.com",
"max_depth": 2
}
Start a new reconnaissance scan for the current user session.
Enhanced with better error handling and debugging.
"""
print("=== API: /api/scan/start called ===")
@ -68,26 +108,62 @@ def start_scan():
'error': 'Max depth must be an integer between 1 and 5'
}), 400
print("Validation passed, calling scanner.start_scan...")
print("Validation passed, getting user scanner...")
# Get user-specific scanner with enhanced debugging
user_session_id, scanner = get_user_scanner()
print(f"Using session: {user_session_id}")
print(f"Scanner object ID: {id(scanner)}")
print(f"Scanner status before start: {scanner.status}")
# Additional safety check - if scanner is somehow in running state, force reset
if scanner.status == 'running':
print(f"WARNING: Scanner in session {user_session_id} was already running - forcing reset")
scanner.stop_scan()
# Give it a moment to stop
import time
time.sleep(1)
# If still running, force status reset
if scanner.status == 'running':
print("WARNING: Force resetting scanner status from 'running' to 'idle'")
scanner.status = 'idle'
# Start scan
print(f"Calling start_scan on scanner {id(scanner)}...")
success = scanner.start_scan(target_domain, max_depth)
print(f"scanner.start_scan returned: {success}")
print(f"Scanner status after start attempt: {scanner.status}")
if success:
session_id = scanner.logger.session_id
print(f"Scan started successfully with session ID: {session_id}")
scan_session_id = scanner.logger.session_id
print(f"Scan started successfully with scan session ID: {scan_session_id}")
return jsonify({
'success': True,
'message': 'Scan started successfully',
'scan_id': session_id
'scan_id': scan_session_id,
'user_session_id': user_session_id,
'debug_info': {
'scanner_object_id': id(scanner),
'scanner_status': scanner.status
}
})
else:
print("ERROR: Scanner returned False")
# Provide more detailed error information
error_details = {
'scanner_status': scanner.status,
'scanner_object_id': id(scanner),
'session_id': user_session_id,
'providers_count': len(scanner.providers) if hasattr(scanner, 'providers') else 0
}
return jsonify({
'success': False,
'error': 'Failed to start scan (scan may already be running)'
'error': f'Failed to start scan (scanner status: {scanner.status})',
'debug_info': error_details
}), 409
except Exception as e:
@ -98,24 +174,28 @@ def start_scan():
'error': f'Internal server error: {str(e)}'
}), 500
@app.route('/api/scan/stop', methods=['POST'])
def stop_scan():
"""Stop the current scan."""
"""Stop the current scan for the user session."""
print("=== API: /api/scan/stop called ===")
try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner()
print(f"Stopping scan for session: {user_session_id}")
success = scanner.stop_scan()
if success:
return jsonify({
'success': True,
'message': 'Scan stop requested'
'message': 'Scan stop requested',
'user_session_id': user_session_id
})
else:
return jsonify({
'success': False,
'error': 'No active scan to stop'
'error': 'No active scan to stop for this session'
}), 400
except Exception as e:
@ -129,9 +209,14 @@ def stop_scan():
@app.route('/api/scan/status', methods=['GET'])
def get_scan_status():
"""Get current scan status and progress."""
"""Get current scan status and progress for the user session."""
try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner()
status = scanner.get_scan_status()
status['user_session_id'] = user_session_id
return jsonify({
'success': True,
'status': status
@ -148,12 +233,16 @@ def get_scan_status():
@app.route('/api/graph', methods=['GET'])
def get_graph_data():
"""Get current graph data for visualization."""
"""Get current graph data for visualization for the user session."""
try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner()
graph_data = scanner.get_graph_data()
return jsonify({
'success': True,
'graph': graph_data
'graph': graph_data,
'user_session_id': user_session_id
})
except Exception as e:
@ -167,15 +256,25 @@ def get_graph_data():
@app.route('/api/export', methods=['GET'])
def export_results():
"""Export complete scan results as downloadable JSON."""
"""Export complete scan results as downloadable JSON for the user session."""
try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner()
# Get complete results
results = scanner.export_results()
# Add session information to export
results['export_metadata'] = {
'user_session_id': user_session_id,
'export_timestamp': datetime.now(timezone.utc).isoformat(),
'export_type': 'user_session_results'
}
# Create filename with timestamp
timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
target = scanner.current_target or 'unknown'
filename = f"dnsrecon_{target}_{timestamp}.json"
filename = f"dnsrecon_{target}_{timestamp}_{user_session_id[:8]}.json"
# Create in-memory file
json_data = json.dumps(results, indent=2, ensure_ascii=False)
@ -199,10 +298,13 @@ def export_results():
@app.route('/api/providers', methods=['GET'])
def get_providers():
"""Get information about available providers."""
"""Get information about available providers for the user session."""
print("=== API: /api/providers called ===")
try:
# Get user-specific scanner
user_session_id, scanner = get_user_scanner()
provider_stats = scanner.get_provider_statistics()
# Add configuration information
@ -215,10 +317,11 @@ def get_providers():
'requires_api_key': provider_name in ['shodan', 'virustotal']
}
print(f"Returning provider info: {list(provider_info.keys())}")
print(f"Returning provider info for session {user_session_id}: {list(provider_info.keys())}")
return jsonify({
'success': True,
'providers': provider_info
'providers': provider_info,
'user_session_id': user_session_id
})
except Exception as e:
@ -233,13 +336,7 @@ def get_providers():
@app.route('/api/config/api-keys', methods=['POST'])
def set_api_keys():
"""
Set API keys for providers (stored in memory only).
Expects JSON payload:
{
"shodan": "api_key_here",
"virustotal": "api_key_here"
}
Set API keys for providers for the user session only.
"""
try:
data = request.get_json()
@ -250,22 +347,27 @@ def set_api_keys():
'error': 'No API keys provided'
}), 400
# Get user-specific scanner and config
user_session_id, scanner = get_user_scanner()
session_config = scanner.config
updated_providers = []
for provider, api_key in data.items():
if provider in ['shodan', 'virustotal'] and api_key.strip():
success = config.set_api_key(provider, api_key.strip())
success = session_config.set_api_key(provider, api_key.strip())
if success:
updated_providers.append(provider)
if updated_providers:
# Reinitialize scanner providers
# Reinitialize scanner providers for this session only
scanner._initialize_providers()
return jsonify({
'success': True,
'message': f'API keys updated for: {", ".join(updated_providers)}',
'updated_providers': updated_providers
'message': f'API keys updated for session {user_session_id}: {", ".join(updated_providers)}',
'updated_providers': updated_providers,
'user_session_id': user_session_id
})
else:
return jsonify({
@ -281,14 +383,99 @@ def set_api_keys():
'error': f'Internal server error: {str(e)}'
}), 500
except Exception as e:
print(f"ERROR: Exception in set_api_keys endpoint: {e}")
traceback.print_exc()
return jsonify({
'success': False,
'error': f'Internal server error: {str(e)}'
}), 500
@app.route('/api/session/info', methods=['GET'])
def get_session_info():
    """Get information about the current user session."""
    try:
        user_session_id, scanner = get_user_scanner()
        info = session_manager.get_session_info(user_session_id)
        return jsonify({
            'success': True,
            'session_info': info
        })
    except Exception as e:
        print(f"ERROR: Exception in get_session_info endpoint: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': f'Internal server error: {str(e)}'
        }), 500
@app.route('/api/session/terminate', methods=['POST'])
def terminate_session():
    """Terminate the current user session."""
    try:
        sid = session.get('dnsrecon_session_id')
        if not sid:
            # Nothing stored client-side, so there is nothing to tear down.
            return jsonify({
                'success': False,
                'error': 'No active session to terminate'
            }), 400

        terminated = session_manager.terminate_session(sid)
        # Drop the cookie-side reference regardless of backend result.
        session.pop('dnsrecon_session_id', None)
        message = 'Session terminated' if terminated else 'Session not found'
        return jsonify({
            'success': terminated,
            'message': message
        })
    except Exception as e:
        print(f"ERROR: Exception in terminate_session endpoint: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': f'Internal server error: {str(e)}'
        }), 500
@app.route('/api/admin/sessions', methods=['GET'])
def list_sessions():
    """Admin endpoint to list all active sessions."""
    try:
        payload = {
            'success': True,
            'sessions': session_manager.list_active_sessions(),
            'statistics': session_manager.get_statistics()
        }
        return jsonify(payload)
    except Exception as e:
        print(f"ERROR: Exception in list_sessions endpoint: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': f'Internal server error: {str(e)}'
        }), 500
@app.route('/api/health', methods=['GET'])
def health_check():
"""Health check endpoint with enhanced Phase 2 information."""
try:
# Get session stats
session_stats = session_manager.get_statistics()
return jsonify({
'success': True,
'status': 'healthy',
'timestamp': datetime.now(datetime.UTC).isoformat(),
'timestamp': datetime.now(timezone.utc).isoformat(),
'version': '1.0.0-phase2',
'phase': 2,
'features': {
@ -297,10 +484,18 @@ def health_check():
'real_time_updates': True,
'api_key_management': True,
'enhanced_visualization': True,
'retry_logic': True
'retry_logic': True,
'user_sessions': True,
'session_isolation': True
},
'providers_available': len(scanner.providers) if hasattr(scanner, 'providers') else 0
'session_statistics': session_stats
})
except Exception as e:
print(f"ERROR: Exception in health_check endpoint: {e}")
return jsonify({
'success': False,
'error': f'Health check failed: {str(e)}'
}), 500
@app.errorhandler(404)
@ -324,7 +519,7 @@ def internal_error(error):
if __name__ == '__main__':
print("Starting DNSRecon Flask application...")
print("Starting DNSRecon Flask application with user session support...")
# Load configuration from environment
config.load_from_env()

View File

@ -5,8 +5,10 @@ Phase 2: Enhanced with concurrent processing and real-time capabilities.
"""
from .graph_manager import GraphManager, NodeType, RelationshipType
from .scanner import Scanner, ScanStatus, scanner
from .scanner import Scanner, ScanStatus # Remove 'scanner' global instance
from .logger import ForensicLogger, get_forensic_logger, new_session
from .session_manager import session_manager # Add session manager
from .session_config import SessionConfig, create_session_config # Add session config
__all__ = [
'GraphManager',
@ -14,10 +16,13 @@ __all__ = [
'RelationshipType',
'Scanner',
'ScanStatus',
'scanner',
# 'scanner', # Remove this - no more global scanner
'ForensicLogger',
'get_forensic_logger',
'new_session'
'new_session',
'session_manager', # Add this
'SessionConfig', # Add this
'create_session_config' # Add this
]
__version__ = "1.0.0-phase2"

View File

@ -3,13 +3,10 @@ Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
"""
import json
import threading
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Set
from typing import Dict, List, Any, Optional, Tuple
from enum import Enum
from datetime import timezone
from collections import defaultdict
import networkx as nx
@ -18,8 +15,8 @@ class NodeType(Enum):
"""Enumeration of supported node types."""
DOMAIN = "domain"
IP = "ip"
CERTIFICATE = "certificate"
ASN = "asn"
DNS_RECORD = "dns_record"
LARGE_ENTITY = "large_entity"
@ -43,6 +40,7 @@ class RelationshipType(Enum):
TLSA_RECORD = ("tlsa_record", 0.7)
NAPTR_RECORD = ("naptr_record", 0.7)
SPF_RECORD = ("spf_record", 0.7)
DNS_RECORD = ("dns_record", 0.8)
PASSIVE_DNS = ("passive_dns", 0.6)
ASN_MEMBERSHIP = ("asn", 0.7)
@ -115,8 +113,7 @@ class GraphManager:
Returns:
bool: True if edge was added, False if it already exists
"""
#with self.lock:
# Ensure both nodes exist
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
# If the target node is a subdomain, it should be added.
# The scanner will handle this logic.
@ -149,12 +146,10 @@ class GraphManager:
def get_node_count(self) -> int:
"""Get total number of nodes in the graph."""
#with self.lock:
return self.graph.number_of_nodes()
def get_edge_count(self) -> int:
"""Get total number of edges in the graph."""
#with self.lock:
return self.graph.number_of_edges()
def get_nodes_by_type(self, node_type: NodeType) -> List[str]:
@ -167,7 +162,6 @@ class GraphManager:
Returns:
List of node identifiers
"""
#with self.lock:
return [
node_id for node_id, attributes in self.graph.nodes(data=True)
if attributes.get('type') == node_type.value
@ -183,7 +177,6 @@ class GraphManager:
Returns:
List of neighboring node identifiers
"""
#with self.lock:
if not self.graph.has_node(node_id):
return []
@ -201,7 +194,6 @@ class GraphManager:
Returns:
List of tuples (source, target, attributes)
"""
#with self.lock:
return [
(source, target, attributes)
for source, target, attributes in self.graph.edges(data=True)
@ -211,46 +203,12 @@ class GraphManager:
def get_graph_data(self) -> Dict[str, Any]:
"""
Export graph data for visualization.
Returns:
Dictionary containing nodes and edges for frontend visualization
Uses comprehensive metadata collected during scanning.
"""
#with self.lock:
nodes = []
edges = []
# Create a dictionary to hold aggregated data for each node
node_details = defaultdict(lambda: defaultdict(list))
for source, target, attributes in self.graph.edges(data=True):
provider = attributes.get('source_provider', 'unknown')
raw_data = attributes.get('raw_data', {})
if provider == 'dns':
record_type = raw_data.get('query_type', 'UNKNOWN')
value = raw_data.get('value', target)
# DNS data is always about the source node of the query
node_details[source]['dns_records'].append(f"{record_type}: {value}")
elif provider == 'crtsh':
# Data from crt.sh are domain names found in certificates (SANs)
node_details[source]['related_domains_san'].append(target)
elif provider == 'shodan':
# Shodan data is about the IP, which can be either the source or target
source_node_type = self.graph.nodes[source].get('type')
target_node_type = self.graph.nodes[target].get('type')
if source_node_type == 'ip':
node_details[source]['shodan'] = raw_data
elif target_node_type == 'ip':
node_details[target]['shodan'] = raw_data
elif provider == 'virustotal':
# VirusTotal data is about the source node of the query
node_details[source]['virustotal'] = raw_data
# Format nodes for visualization
# Create nodes with the comprehensive metadata already collected
for node_id, attributes in self.graph.nodes(data=True):
node_data = {
'id': node_id,
@ -260,18 +218,15 @@ class GraphManager:
'added_timestamp': attributes.get('added_timestamp')
}
# Add the aggregated details to the metadata
if node_id in node_details:
for key, value in node_details[node_id].items():
# Use a set to avoid adding duplicate entries to lists
if key in node_data['metadata'] and isinstance(node_data['metadata'][key], list):
existing_values = set(node_data['metadata'][key])
new_values = [v for v in value if v not in existing_values]
node_data['metadata'][key].extend(new_values)
else:
node_data['metadata'][key] = value
# Handle certificate node labeling
if node_id.startswith('cert_'):
# For certificate nodes, create a more informative label
cert_metadata = node_data['metadata']
issuer = cert_metadata.get('issuer_name', 'Unknown')
valid_status = "" if cert_metadata.get('is_currently_valid') else ""
node_data['label'] = f"Certificate {valid_status}\n{issuer[:30]}..."
# Color coding by type - now returns color objects for enhanced visualization
# Color coding by type
type_colors = {
'domain': {
'background': '#00ff41',
@ -285,18 +240,18 @@ class GraphManager:
'highlight': {'background': '#ffbb44', 'border': '#ff9900'},
'hover': {'background': '#ffaa22', 'border': '#dd8800'}
},
'certificate': {
'background': '#c7c7c7',
'border': '#999999',
'highlight': {'background': '#e0e0e0', 'border': '#c7c7c7'},
'hover': {'background': '#d4d4d4', 'border': '#aaaaaa'}
},
'asn': {
'background': '#00aaff',
'border': '#0088cc',
'highlight': {'background': '#44ccff', 'border': '#00aaff'},
'hover': {'background': '#22bbff', 'border': '#0099dd'}
},
'dns_record': {
'background': '#9d4edd',
'border': '#7b2cbf',
'highlight': {'background': '#c77dff', 'border': '#9d4edd'},
'hover': {'background': '#b392f0', 'border': '#8b5cf6'}
},
'large_entity': {
'background': '#ff6b6b',
'border': '#cc3a3a',
@ -306,15 +261,17 @@ class GraphManager:
}
node_color_config = type_colors.get(attributes.get('type', 'unknown'), type_colors['domain'])
node_data['color'] = node_color_config
# Pass the has_valid_cert metadata for styling
if 'metadata' in attributes and 'has_valid_cert' in attributes['metadata']:
node_data['has_valid_cert'] = attributes['metadata']['has_valid_cert']
# Add certificate validity indicator if available
metadata = node_data['metadata']
if 'certificate_data' in metadata and 'has_valid_cert' in metadata['certificate_data']:
node_data['has_valid_cert'] = metadata['certificate_data']['has_valid_cert']
nodes.append(node_data)
# Format edges for visualization
# Create edges (unchanged from original)
for source, target, attributes in self.graph.edges(data=True):
edge_data = {
'from': source,
@ -376,7 +333,6 @@ class GraphManager:
Returns:
Dictionary containing complete graph data with metadata
"""
#with self.lock:
# Get basic graph data
graph_data = self.get_graph_data()
@ -427,7 +383,6 @@ class GraphManager:
Returns:
Dictionary containing various graph metrics
"""
#with self.lock:
stats = {
'basic_metrics': {
'total_nodes': self.graph.number_of_nodes(),
@ -462,7 +417,6 @@ class GraphManager:
def clear(self) -> None:
"""Clear all nodes and edges from the graph."""
#with self.lock:
self.graph.clear()
self.creation_time = datetime.now(timezone.utc).isoformat()
self.last_modified = self.creation_time

View File

@ -3,7 +3,6 @@ Forensic logging system for DNSRecon tool.
Provides structured audit trail for all reconnaissance activities.
"""
import json
import logging
import threading
from datetime import datetime
@ -109,7 +108,6 @@ class ForensicLogger:
target_indicator: The indicator being investigated
discovery_context: Context of how this indicator was discovered
"""
#with self.lock:
api_request = APIRequest(
timestamp=datetime.now(timezone.utc).isoformat(),
provider=provider,
@ -152,7 +150,6 @@ class ForensicLogger:
raw_data: Raw data from provider response
discovery_method: Method used to discover relationship
"""
#with self.lock:
relationship = RelationshipDiscovery(
timestamp=datetime.now(timezone.utc).isoformat(),
source_node=source_node,
@ -178,12 +175,10 @@ class ForensicLogger:
self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
#with self.lock:
self.session_metadata['target_domains'].add(target_domain)
def log_scan_complete(self) -> None:
"""Log the completion of a reconnaissance scan."""
#with self.lock:
self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
@ -199,7 +194,6 @@ class ForensicLogger:
Returns:
Dictionary containing complete session audit trail
"""
#with self.lock:
return {
'session_metadata': self.session_metadata.copy(),
'api_requests': [asdict(req) for req in self.api_requests],
@ -214,7 +208,6 @@ class ForensicLogger:
Returns:
Dictionary containing summary statistics
"""
#with self.lock:
provider_stats = {}
for provider in self.session_metadata['providers_used']:
provider_requests = [req for req in self.api_requests if req.provider == provider]

View File

@ -4,14 +4,14 @@ Coordinates data gathering from multiple providers and builds the infrastructure
"""
import threading
import time
import traceback
from typing import List, Set, Dict, Any, Optional, Tuple
from typing import List, Set, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
from collections import defaultdict
from core.graph_manager import GraphManager, NodeType, RelationshipType
from core.logger import get_forensic_logger, new_session
from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.crtsh_provider import CrtShProvider
from providers.dns_provider import DNSProvider
from providers.shodan_provider import ShodanProvider
@ -31,21 +31,27 @@ class ScanStatus:
class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
Manages multi-provider data gathering and graph construction with concurrent processing.
Now supports per-session configuration for multi-user isolation.
"""
def __init__(self):
"""Initialize scanner with all available providers and empty graph."""
def __init__(self, session_config=None):
"""Initialize scanner with session-specific configuration."""
print("Initializing Scanner instance...")
try:
# Use provided session config or create default
if session_config is None:
from core.session_config import create_session_config
session_config = create_session_config()
self.config = session_config
self.graph = GraphManager()
self.providers = []
self.status = ScanStatus.IDLE
self.current_target = None
self.current_depth = 0
self.max_depth = 2
self.stop_event = threading.Event() # Use a threading.Event for safer signaling
self.stop_event = threading.Event()
self.scan_thread = None
# Scanning progress tracking
@ -54,11 +60,11 @@ class Scanner:
self.current_indicator = ""
# Concurrent processing configuration
self.max_workers = config.max_concurrent_requests
self.executor = None # Keep a reference to the executor
self.max_workers = self.config.max_concurrent_requests
self.executor = None
# Initialize providers
print("Calling _initialize_providers...")
# Initialize providers with session config
print("Calling _initialize_providers with session config...")
self._initialize_providers()
# Initialize logger
@ -73,10 +79,10 @@ class Scanner:
raise
def _initialize_providers(self) -> None:
"""Initialize all available providers based on configuration."""
"""Initialize all available providers based on session configuration."""
self.providers = []
print("Initializing providers...")
print("Initializing providers with session config...")
# Always add free providers
free_providers = [
@ -85,12 +91,15 @@ class Scanner:
]
for provider_name, provider_class in free_providers:
if config.is_provider_enabled(provider_name):
if self.config.is_provider_enabled(provider_name):
try:
provider = provider_class()
# Pass session config to provider
provider = provider_class(session_config=self.config)
if provider.is_available():
# Set the stop event for cancellation support
provider.set_stop_event(self.stop_event)
self.providers.append(provider)
print(f"{provider_name.title()} provider initialized successfully")
print(f"{provider_name.title()} provider initialized successfully for session")
else:
print(f"{provider_name.title()} provider is not available")
except Exception as e:
@ -104,23 +113,41 @@ class Scanner:
]
for provider_name, provider_class in api_providers:
if config.is_provider_enabled(provider_name):
if self.config.is_provider_enabled(provider_name):
try:
provider = provider_class()
# Pass session config to provider
provider = provider_class(session_config=self.config)
if provider.is_available():
# Set the stop event for cancellation support
provider.set_stop_event(self.stop_event)
self.providers.append(provider)
print(f"{provider_name.title()} provider initialized successfully")
print(f"{provider_name.title()} provider initialized successfully for session")
else:
print(f"{provider_name.title()} provider is not available (API key required)")
except Exception as e:
print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
traceback.print_exc()
print(f"Initialized {len(self.providers)} providers")
print(f"Initialized {len(self.providers)} providers for session")
def update_session_config(self, new_config) -> None:
"""
Update session configuration and reinitialize providers.
Args:
new_config: New SessionConfig instance
"""
print("Updating session configuration...")
self.config = new_config
self.max_workers = self.config.max_concurrent_requests
self._initialize_providers()
print("Session configuration updated")
def start_scan(self, target_domain: str, max_depth: int = 2) -> bool:
"""
Start a new reconnaissance scan with concurrent processing.
Enhanced with better debugging and state validation.
Args:
target_domain: Initial domain to investigate
@ -129,29 +156,36 @@ class Scanner:
Returns:
bool: True if scan started successfully
"""
print(f"Scanner.start_scan called with target='{target_domain}', depth={max_depth}")
print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
print(f"Scanner status: {self.status}")
print(f"Target domain: '{target_domain}', Max depth: {max_depth}")
print(f"Available providers: {len(self.providers) if hasattr(self, 'providers') else 0}")
try:
if self.status == ScanStatus.RUNNING:
print("Scan already running, rejecting new scan")
print(f"ERROR: Scan already running in scanner {id(self)}, rejecting new scan")
print(f"Current target: {self.current_target}")
print(f"Current depth: {self.current_depth}")
return False
# Check if we have any providers
if not self.providers:
print("No providers available, cannot start scan")
if not hasattr(self, 'providers') or not self.providers:
print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
return False
print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")
# Stop any existing scan thread
if self.scan_thread and self.scan_thread.is_alive():
print("Stopping existing scan thread...")
print(f"Stopping existing scan thread in scanner {id(self)}...")
self.stop_event.set()
self.scan_thread.join(timeout=5.0)
if self.scan_thread.is_alive():
print("WARNING: Could not stop existing thread")
print(f"WARNING: Could not stop existing thread in scanner {id(self)}")
return False
# Reset state
print("Resetting scanner state...")
print(f"Resetting scanner {id(self)} state...")
self.graph.clear()
self.current_target = target_domain.lower().strip()
self.max_depth = max_depth
@ -162,11 +196,11 @@ class Scanner:
self.current_indicator = self.current_target
# Start new forensic session
print("Starting new forensic session...")
print(f"Starting new forensic session for scanner {id(self)}...")
self.logger = new_session()
# Start scan in separate thread
print("Starting scan thread...")
print(f"Starting scan thread for scanner {id(self)}...")
self.scan_thread = threading.Thread(
target=self._execute_scan_async,
args=(self.current_target, max_depth),
@ -174,14 +208,15 @@ class Scanner:
)
self.scan_thread.start()
print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===")
return True
except Exception as e:
print(f"ERROR: Exception in start_scan: {e}")
print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
traceback.print_exc()
return False
def _execute_scan_async(self, target_domain: str, max_depth: int) -> None:
async def _execute_scan_async(self, target_domain: str, max_depth: int) -> None:
"""
Execute the reconnaissance scan asynchronously with concurrent provider queries.
@ -210,7 +245,7 @@ class Scanner:
processed_domains = set()
all_discovered_ips = set()
print(f"Starting BFS exploration...")
print("Starting BFS exploration...")
for depth in range(max_depth + 1):
if self.stop_event.is_set():
@ -269,7 +304,7 @@ class Scanner:
self.executor.shutdown(wait=False, cancel_futures=True)
stats = self.graph.get_statistics()
print(f"Final scan statistics:")
print("Final scan statistics:")
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
print(f" - Domains processed: {len(processed_domains)}")
@ -330,87 +365,243 @@ class Scanner:
def _query_providers_for_domain(self, domain: str) -> Tuple[Set[str], Set[str]]:
"""
Query all enabled providers for information about a domain.
Query all enabled providers for information about a domain and collect comprehensive metadata.
Creates appropriate node types and relationships based on discovered data.
"""
print(f"Querying {len(self.providers)} providers for domain: {domain}")
discovered_domains = set()
discovered_ips = set()
relationships_by_type = defaultdict(list)
all_relationships = []
# Comprehensive metadata collection for this domain
domain_metadata = {
'dns_records': [],
'related_domains_san': [],
'shodan': {},
'virustotal': {},
'certificate_data': {},
'passive_dns': [],
}
if not self.providers or self.stop_event.is_set():
return discovered_domains, discovered_ips
# Query all providers concurrently
with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
provider_executor.submit(self._safe_provider_query_domain, provider, domain): provider
for provider in self.providers
}
for future in as_completed(future_to_provider):
if self.stop_event.is_set():
future.cancel()
continue
provider = future_to_provider[future]
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships")
# Process relationships and collect metadata
for rel in relationships:
relationships_by_type[rel[2]].append(rel)
source, target, rel_type, confidence, raw_data = rel
# Add provider info to the relationship
enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
all_relationships.append(enhanced_rel)
# Collect metadata based on provider and relationship type
self._collect_node_metadata(domain, provider.get_name(), rel_type, target, raw_data, domain_metadata)
except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for {domain}: {e}")
# Add the domain node with comprehensive metadata
self.graph.add_node(domain, NodeType.DOMAIN, metadata=domain_metadata)
# Group relationships by type for large entity handling
relationships_by_type = defaultdict(list)
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
relationships_by_type[rel_type].append((source, target, rel_type, confidence, raw_data, provider_name))
# Handle large entities (only for SAN certificates currently)
for rel_type, relationships in relationships_by_type.items():
if len(relationships) > config.large_entity_threshold and rel_type == RelationshipType.SAN_CERTIFICATE:
self._handle_large_entity(domain, relationships, rel_type, provider.get_name())
else:
for source, target, rel_type, confidence, raw_data in relationships:
# Determine if the target should create a new node
create_node = rel_type in [
RelationshipType.A_RECORD,
RelationshipType.AAAA_RECORD,
RelationshipType.CNAME_RECORD,
RelationshipType.MX_RECORD,
RelationshipType.NS_RECORD,
RelationshipType.PTR_RECORD,
RelationshipType.SAN_CERTIFICATE
]
first_provider = relationships[0][5] if relationships else "multiple_providers"
self._handle_large_entity(domain, relationships, rel_type, first_provider)
# Remove these relationships from further processing
all_relationships = [rel for rel in all_relationships if not (rel[2] == rel_type and len(relationships_by_type[rel_type]) > config.large_entity_threshold)]
# Determine if the target should be subject to recursion
recurse = rel_type in [
RelationshipType.A_RECORD,
RelationshipType.AAAA_RECORD,
RelationshipType.CNAME_RECORD,
RelationshipType.MX_RECORD,
RelationshipType.SAN_CERTIFICATE
]
# Track DNS records to create (avoid duplicates)
dns_records_to_create = {}
if create_node:
target_node_type = NodeType.IP if self._is_valid_ip(target) else NodeType.DOMAIN
self.graph.add_node(target, target_node_type)
if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data):
print(f"Added relationship: {source} -> {target} ({rel_type.relationship_name})")
else:
# For records that don't create nodes, we still want to log the relationship
self.logger.log_relationship_discovery(
source_node=source,
target_node=target,
relationship_type=rel_type.relationship_name,
confidence_score=confidence,
provider=provider.name,
raw_data=raw_data,
discovery_method=f"dns_{rel_type.name.lower()}_record"
)
# Process remaining relationships
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
if self.stop_event.is_set():
break
if recurse:
if self._is_valid_ip(target):
# Determine how to handle the target based on relationship type and content
if _is_valid_ip(target):
# Create IP node and relationship
self.graph.add_node(target, NodeType.IP)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")
# Add to recursion if it's a direct resolution
if rel_type in [RelationshipType.A_RECORD, RelationshipType.AAAA_RECORD]:
discovered_ips.add(target)
elif self._is_valid_domain(target):
elif target.startswith('AS') and target[2:].isdigit():
# Create ASN node and relationship
self.graph.add_node(target, NodeType.ASN)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added ASN relationship: {source} -> {target} ({rel_type.relationship_name})")
elif _is_valid_domain(target):
# Create domain node and relationship
self.graph.add_node(target, NodeType.DOMAIN)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added domain relationship: {source} -> {target} ({rel_type.relationship_name})")
# Add to recursion for specific relationship types
recurse_types = [
RelationshipType.CNAME_RECORD,
RelationshipType.MX_RECORD,
RelationshipType.SAN_CERTIFICATE,
RelationshipType.NS_RECORD,
RelationshipType.PASSIVE_DNS
]
if rel_type in recurse_types:
discovered_domains.add(target)
print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs")
else:
# Handle DNS record content (TXT, SPF, CAA, etc.)
dns_record_types = [
RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD,
RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD,
RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD,
RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD,
RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD
]
if rel_type in dns_record_types:
# Create normalized DNS record identifier
record_type = rel_type.relationship_name.upper().replace('_RECORD', '')
record_content = target.strip()
# Create a unique identifier for this DNS record
content_hash = hash(record_content) & 0x7FFFFFFF
dns_record_id = f"{record_type}:{content_hash}"
# Track this DNS record for creation (avoid duplicates)
if dns_record_id not in dns_records_to_create:
dns_records_to_create[dns_record_id] = {
'content': record_content,
'type': record_type,
'domains': set(),
'raw_data': raw_data,
'provider_name': provider_name,
'confidence': confidence
}
# Add this domain to the DNS record's domain list
dns_records_to_create[dns_record_id]['domains'].add(source)
print(f"DNS record tracked: {source} -> {record_type} (content length: {len(record_content)})")
else:
# For other non-infrastructure targets, log but don't create nodes
print(f"Non-infrastructure relationship stored as metadata: {source} - {rel_type.relationship_name}: {target[:100]}")
# Create DNS record nodes and their relationships
for dns_record_id, record_info in dns_records_to_create.items():
if self.stop_event.is_set():
break
record_metadata = {
'record_type': record_info['type'],
'content': record_info['content'],
'content_hash': dns_record_id.split(':')[1],
'associated_domains': list(record_info['domains']),
'source_data': record_info['raw_data']
}
# Create the DNS record node
self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata)
# Connect each domain to this DNS record
for domain_name in record_info['domains']:
if self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD,
record_info['confidence'], record_info['provider_name'],
record_info['raw_data']):
print(f"Added DNS record relationship: {domain_name} -> {dns_record_id}")
print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs, {len(dns_records_to_create)} DNS records")
return discovered_domains, discovered_ips
def _collect_node_metadata(self, node_id: str, provider_name: str, rel_type: RelationshipType,
target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
"""
Collect and organize metadata for a node based on provider responses.
"""
if provider_name == 'dns':
record_type = raw_data.get('query_type', 'UNKNOWN')
value = raw_data.get('value', target)
# For non-infrastructure DNS records, store the full content
if record_type in ['TXT', 'SPF', 'CAA']:
dns_entry = f"{record_type}: {value}"
else:
dns_entry = f"{record_type}: {value}"
if dns_entry not in metadata['dns_records']:
metadata['dns_records'].append(dns_entry)
elif provider_name == 'crtsh':
if rel_type == RelationshipType.SAN_CERTIFICATE:
# Handle certificate data storage on domain nodes
domain_certs = raw_data.get('domain_certificates', {})
# Store certificate information for this domain
if node_id in domain_certs:
cert_summary = domain_certs[node_id]
# Update domain metadata with certificate information
metadata['certificate_data'] = cert_summary
metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False)
# Add related domains from shared certificates
if target not in metadata.get('related_domains_san', []):
if 'related_domains_san' not in metadata:
metadata['related_domains_san'] = []
metadata['related_domains_san'].append(target)
# Store shared certificate details for forensic analysis
shared_certs = raw_data.get('shared_certificates', [])
if shared_certs and 'shared_certificate_details' not in metadata:
metadata['shared_certificate_details'] = shared_certs
elif provider_name == 'shodan':
# Merge Shodan data (avoid overwriting)
for key, value in raw_data.items():
if key not in metadata['shodan'] or not metadata['shodan'][key]:
metadata['shodan'][key] = value
elif provider_name == 'virustotal':
# Merge VirusTotal data
for key, value in raw_data.items():
if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
metadata['virustotal'][key] = value
# Add passive DNS entries
if rel_type == RelationshipType.PASSIVE_DNS:
passive_entry = f"Passive DNS: {target}"
if passive_entry not in metadata['passive_dns']:
metadata['passive_dns'].append(passive_entry)
def _handle_large_entity(self, source_domain: str, relationships: list, rel_type: RelationshipType, provider_name: str):
"""
Handles the creation of a large entity node when a threshold is exceeded.
@ -422,12 +613,24 @@ class Scanner:
def _query_providers_for_ip(self, ip: str) -> None:
"""
Query all enabled providers for information about an IP address.
Query all enabled providers for information about an IP address and collect comprehensive metadata.
"""
print(f"Querying {len(self.providers)} providers for IP: {ip}")
if not self.providers or self.stop_event.is_set():
return
# Comprehensive metadata collection for this IP
ip_metadata = {
'dns_records': [],
'passive_dns': [],
'shodan': {},
'virustotal': {},
'asn_data': {},
'hostnames': [],
}
all_relationships = [] # Store relationships with provider info
with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
provider_executor.submit(self._safe_provider_query_ip, provider, ip): provider
@ -441,19 +644,85 @@ class Scanner:
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for IP {ip}")
for source, target, rel_type, confidence, raw_data in relationships:
if self._is_valid_domain(target):
# Add provider info to the relationship
enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
all_relationships.append(enhanced_rel)
# Collect metadata for the IP
self._collect_ip_metadata(ip, provider.get_name(), rel_type, target, raw_data, ip_metadata)
except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for IP {ip}: {e}")
# Update the IP node with comprehensive metadata
self.graph.add_node(ip, NodeType.IP, metadata=ip_metadata)
# Process relationships with correct provider attribution
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
# Determine target node type
if _is_valid_domain(target):
target_node_type = NodeType.DOMAIN
elif target.startswith('AS'):
target_node_type = NodeType.ASN
else:
target_node_type = NodeType.IP
self.graph.add_node(source, NodeType.IP)
# Create/update target node
self.graph.add_node(target, target_node_type)
if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")
except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for IP {ip}: {e}")
# Add relationship with correct provider attribution
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name}) from {provider_name}")
def _collect_ip_metadata(self, ip: str, provider_name: str, rel_type: RelationshipType,
                         target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
    """
    Collect and organize metadata for an IP node based on provider responses.

    Mutates *metadata* in place; expects the keys initialized by
    _query_providers_for_ip ('dns_records', 'hostnames', 'shodan',
    'virustotal', 'passive_dns', 'asn_data').
    """
    if provider_name == 'dns':
        # Reverse lookups: record the PTR entry and remember the hostname.
        if rel_type == RelationshipType.PTR_RECORD:
            reverse_entry = f"PTR: {target}"
            if reverse_entry not in metadata['dns_records']:
                metadata['dns_records'].append(reverse_entry)
            if target not in metadata['hostnames']:
                metadata['hostnames'].append(target)
    elif provider_name == 'shodan':
        # Merge Shodan data without overwriting values already present.
        for key, value in raw_data.items():
            if key not in metadata['shodan'] or not metadata['shodan'][key]:
                metadata['shodan'][key] = value
        # Collect hostname information (singular and plural payload keys).
        if 'hostname' in raw_data and raw_data['hostname'] not in metadata['hostnames']:
            metadata['hostnames'].append(raw_data['hostname'])
        if 'hostnames' in raw_data:
            for hostname in raw_data['hostnames']:
                if hostname not in metadata['hostnames']:
                    metadata['hostnames'].append(hostname)
    elif provider_name == 'virustotal':
        # Merge VirusTotal data without overwriting values already present.
        for key, value in raw_data.items():
            if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
                metadata['virustotal'][key] = value
        # Add passive DNS entries (deduplicated).
        if rel_type == RelationshipType.PASSIVE_DNS:
            passive_entry = f"Passive DNS: {target}"
            if passive_entry not in metadata['passive_dns']:
                metadata['passive_dns'].append(passive_entry)
    # Handle ASN relationships.
    # NOTE(review): placed outside the provider chain so any provider can
    # supply ASN membership - confirm against the original indentation.
    if rel_type == RelationshipType.ASN_MEMBERSHIP:
        # Last writer wins: a later ASN relationship replaces earlier data.
        metadata['asn_data'] = {
            'asn': target,
            'description': raw_data.get('org', ''),
            'isp': raw_data.get('isp', ''),
            'country': raw_data.get('country', '')
        }
def _safe_provider_query_domain(self, provider, domain: str):
@ -478,12 +747,33 @@ class Scanner:
def stop_scan(self) -> bool:
"""
Request scan termination.
Request immediate scan termination with aggressive cancellation.
"""
try:
if self.status == ScanStatus.RUNNING:
print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
# Signal all threads to stop
self.stop_event.set()
print("Scan stop requested")
# Close HTTP sessions in all providers to terminate ongoing requests
for provider in self.providers:
try:
if hasattr(provider, 'session'):
provider.session.close()
print(f"Closed HTTP session for provider: {provider.get_name()}")
except Exception as e:
print(f"Error closing session for {provider.get_name()}: {e}")
# Shutdown executor immediately with cancel_futures=True
if self.executor:
print("Shutting down executor with immediate cancellation...")
self.executor.shutdown(wait=False, cancel_futures=True)
# Give threads a moment to respond to cancellation, then force status change
threading.Timer(2.0, self._force_stop_completion).start()
print("Immediate termination requested - ongoing requests will be cancelled")
return True
print("No active scan to stop")
return False
@ -492,6 +782,13 @@ class Scanner:
traceback.print_exc()
return False
def _force_stop_completion(self):
    """Force completion of stop operation after timeout."""
    # Invoked via threading.Timer from stop_scan(); if worker threads have
    # not wound down within the grace period, force the terminal state so
    # the UI is not stuck on RUNNING.
    if self.status == ScanStatus.RUNNING:
        print("Forcing scan termination after timeout")
        self.status = ScanStatus.STOPPED
        self.logger.log_scan_complete()
def get_scan_status(self) -> Dict[str, Any]:
"""
Get current scan status and progress.
@ -561,16 +858,6 @@ class Scanner:
}
return export_data
def remove_provider(self, provider_name: str) -> bool:
    """
    Remove a provider from the scanner.

    Removes the first registered provider whose name matches.

    Args:
        provider_name: Name reported by the provider's get_name().

    Returns:
        True if a matching provider was removed, False otherwise.
    """
    match_index = next(
        (idx for idx, prov in enumerate(self.providers)
         if prov.get_name() == provider_name),
        None,
    )
    if match_index is None:
        return False
    self.providers.pop(match_index)
    return True
def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
"""
Get statistics for all providers.
@ -579,52 +866,3 @@ class Scanner:
for provider in self.providers:
stats[provider.get_name()] = provider.get_statistics()
return stats
def _is_valid_domain(self, domain: str) -> bool:
"""
Basic domain validation.
"""
if not domain or len(domain) > 253:
return False
parts = domain.split('.')
if len(parts) < 2:
return False
for part in parts:
if not part or len(part) > 63:
return False
if not part.replace('-', '').replace('_', '').isalnum():
return False
return True
def _is_valid_ip(self, ip: str) -> bool:
"""
Basic IP address validation.
"""
try:
parts = ip.split('.')
if len(parts) != 4:
return False
for part in parts:
num = int(part)
if not 0 <= num <= 255:
return False
return True
except (ValueError, AttributeError):
return False
class ScannerProxy:
    """Lazily instantiates a Scanner and forwards all attribute access to it."""

    def __init__(self):
        # The real Scanner is only built on first attribute access.
        self._scanner = None
        print("ScannerProxy initialized")

    def __getattr__(self, name):
        # Fast path: delegate directly once the backing scanner exists.
        if self._scanner is not None:
            return getattr(self._scanner, name)
        print("Creating new Scanner instance...")
        self._scanner = Scanner()
        print("Scanner instance created")
        return getattr(self._scanner, name)


# Global scanner instance
scanner = ScannerProxy()

126
core/session_config.py Normal file
View File

@ -0,0 +1,126 @@
"""
Per-session configuration management for DNSRecon.
Provides isolated configuration instances for each user session.
"""
import os
from typing import Dict, Optional
class SessionConfig:
    """
    Session-specific configuration that inherits from global config
    but maintains isolated API keys and provider settings.
    """

    def __init__(self):
        """Initialize session config with global defaults."""
        # Per-session API keys; never shared across sessions.
        self.api_keys: Dict[str, Optional[str]] = {
            'shodan': None,
            'virustotal': None
        }

        # Default settings (copied from global config)
        self.default_recursion_depth = 2
        self.default_timeout = 30
        self.max_concurrent_requests = 5
        self.large_entity_threshold = 100

        # Rate limiting settings in requests/minute (per session)
        self.rate_limits = {
            'crtsh': 60,
            'virustotal': 4,
            'shodan': 60,
            'dns': 100
        }

        # Provider settings (per session); keyed providers start disabled
        # until an API key is supplied.
        self.enabled_providers = {
            'crtsh': True,
            'dns': True,
            'virustotal': False,
            'shodan': False
        }

        # Logging configuration
        self.log_level = 'INFO'
        self.log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

        # Flask configuration (shared)
        self.flask_host = '127.0.0.1'
        self.flask_port = 5000
        self.flask_debug = True

    def set_api_key(self, provider: str, api_key: str) -> bool:
        """
        Set API key for a provider in this session.

        Enablement follows the key: a truthy key enables the provider, an
        empty/None key disables it.

        Args:
            provider: Provider name (shodan, virustotal)
            api_key: API key string

        Returns:
            bool: True if the provider is known and the key was stored
        """
        if provider not in self.api_keys:
            return False
        self.api_keys[provider] = api_key
        self.enabled_providers[provider] = bool(api_key)
        return True

    def get_api_key(self, provider: str) -> Optional[str]:
        """
        Get API key for a provider in this session.

        Args:
            provider: Provider name

        Returns:
            API key or None if not set
        """
        return self.api_keys.get(provider)

    def is_provider_enabled(self, provider: str) -> bool:
        """
        Check if a provider is enabled in this session.

        Args:
            provider: Provider name

        Returns:
            bool: True if provider is enabled
        """
        return self.enabled_providers.get(provider, False)

    def get_rate_limit(self, provider: str) -> int:
        """
        Get rate limit for a provider in this session.

        Args:
            provider: Provider name

        Returns:
            Rate limit in requests per minute (defaults to 60 for unknown
            providers)
        """
        return self.rate_limits.get(provider, 60)

    def load_from_env(self):
        """Load configuration from environment variables (only if not already set)."""
        vt_key = os.getenv('VIRUSTOTAL_API_KEY')
        if vt_key and not self.api_keys['virustotal']:
            self.set_api_key('virustotal', vt_key)

        shodan_key = os.getenv('SHODAN_API_KEY')
        if shodan_key and not self.api_keys['shodan']:
            self.set_api_key('shodan', shodan_key)

        # Override default settings from environment, falling back to the
        # built-in defaults. Fixed: DEFAULT_TIMEOUT and
        # MAX_CONCURRENT_REQUESTS were previously hard-coded here despite
        # the "override from environment" intent.
        self.default_recursion_depth = int(os.getenv('DEFAULT_RECURSION_DEPTH', '2'))
        self.default_timeout = int(os.getenv('DEFAULT_TIMEOUT', '30'))
        self.max_concurrent_requests = int(os.getenv('MAX_CONCURRENT_REQUESTS', '5'))
def create_session_config() -> SessionConfig:
    """Build a fresh SessionConfig seeded from environment variables."""
    config_instance = SessionConfig()
    config_instance.load_from_env()
    return config_instance

281
core/session_manager.py Normal file
View File

@ -0,0 +1,281 @@
"""
Session manager for DNSRecon multi-user support.
Manages individual scanner instances per user session with automatic cleanup.
"""
import threading
import time
import uuid
from typing import Dict, Optional, Any
from datetime import datetime, timezone
from core.scanner import Scanner
class SessionManager:
    """
    Manages multiple scanner instances for concurrent user sessions.
    Provides session isolation and automatic cleanup of inactive sessions.
    """

    def __init__(self, session_timeout_minutes: int = 60):
        """
        Initialize session manager.

        Args:
            session_timeout_minutes: Minutes of inactivity before session cleanup
        """
        # session_id -> {'scanner', 'config', 'created_at', 'last_activity',
        #                'user_agent', 'status'}
        self.sessions: Dict[str, Dict[str, Any]] = {}
        self.session_timeout = session_timeout_minutes * 60  # Convert to seconds
        self.lock = threading.Lock()

        # Background daemon thread that periodically reaps idle sessions.
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

        print(f"SessionManager initialized with {session_timeout_minutes}min timeout")

    def create_session(self) -> str:
        """
        Create a new user session with dedicated scanner instance and configuration.

        Returns:
            Unique session ID

        Raises:
            Exception: Propagates any failure from scanner/config construction.
        """
        session_id = str(uuid.uuid4())
        print(f"=== CREATING SESSION {session_id} ===")

        try:
            # Local imports avoid a circular dependency with core.scanner.
            from core.session_config import create_session_config
            session_config = create_session_config()
            print(f"Created session config for {session_id}")

            from core.scanner import Scanner
            scanner_instance = Scanner(session_config=session_config)
            print(f"Created scanner instance {id(scanner_instance)} for session {session_id}")
            print(f"Initial scanner status: {scanner_instance.status}")

            with self.lock:
                self.sessions[session_id] = {
                    'scanner': scanner_instance,
                    'config': session_config,
                    'created_at': time.time(),
                    'last_activity': time.time(),
                    'user_agent': '',
                    'status': 'active'
                }
                print(f"Session {session_id} stored in session manager")
                print(f"Total active sessions: {len([s for s in self.sessions.values() if s['status'] == 'active'])}")

            print(f"=== SESSION {session_id} CREATED SUCCESSFULLY ===")
            return session_id

        except Exception as e:
            print(f"ERROR: Failed to create session {session_id}: {e}")
            raise

    def get_session(self, session_id: str) -> Optional[Scanner]:
        """
        Get scanner instance for a session, refreshing its activity timestamp.

        Args:
            session_id: Session identifier

        Returns:
            Scanner instance or None if session doesn't exist or is inactive
        """
        if not session_id:
            print("get_session called with empty session_id")
            return None

        with self.lock:
            if session_id not in self.sessions:
                print(f"Session {session_id} not found in session manager")
                print(f"Available sessions: {list(self.sessions.keys())}")
                return None

            session_data = self.sessions[session_id]

            # Terminated/expired sessions are not handed back to callers.
            if session_data['status'] != 'active':
                print(f"Session {session_id} is not active (status: {session_data['status']})")
                return None

            # Touch the session so the cleanup loop keeps it alive.
            session_data['last_activity'] = time.time()
            scanner = session_data['scanner']
            print(f"Retrieved scanner {id(scanner)} for session {session_id}")
            print(f"Scanner status: {scanner.status}")
            return scanner

    def get_or_create_session(self, session_id: Optional[str] = None) -> tuple[str, Scanner]:
        """
        Get existing session or create new one.

        Args:
            session_id: Optional existing session ID

        Returns:
            Tuple of (session_id, scanner_instance)
        """
        # Fixed: previously get_session() was called twice on the hit path,
        # doing a redundant lookup and bumping last_activity twice.
        if session_id:
            existing_scanner = self.get_session(session_id)
            if existing_scanner is not None:
                return session_id, existing_scanner
        new_session_id = self.create_session()
        return new_session_id, self.get_session(new_session_id)

    def terminate_session(self, session_id: str) -> bool:
        """
        Terminate a specific session and cleanup resources.

        Args:
            session_id: Session to terminate

        Returns:
            True if session was terminated successfully
        """
        with self.lock:
            if session_id not in self.sessions:
                return False

            session_data = self.sessions[session_id]
            scanner = session_data['scanner']

            # Stop any running scan before discarding the session.
            # NOTE(review): compares scanner.status to the string 'running';
            # if Scanner.status is an enum this never matches - confirm.
            try:
                if scanner.status == 'running':
                    scanner.stop_scan()
                    print(f"Stopped scan for session: {session_id}")
            except Exception as e:
                print(f"Error stopping scan for session {session_id}: {e}")

            # Mark as terminated
            session_data['status'] = 'terminated'
            session_data['terminated_at'] = time.time()

            # Remove from active sessions after a brief delay to allow cleanup
            threading.Timer(5.0, lambda: self._remove_session(session_id)).start()

            print(f"Terminated session: {session_id}")
            return True

    def _remove_session(self, session_id: str) -> None:
        """Remove session from memory."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]
                print(f"Removed session from memory: {session_id}")

    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session information without updating activity.

        Args:
            session_id: Session identifier

        Returns:
            Session information dictionary or None
        """
        with self.lock:
            if session_id not in self.sessions:
                return None

            session_data = self.sessions[session_id]
            scanner = session_data['scanner']

            return {
                'session_id': session_id,
                'created_at': datetime.fromtimestamp(session_data['created_at'], timezone.utc).isoformat(),
                'last_activity': datetime.fromtimestamp(session_data['last_activity'], timezone.utc).isoformat(),
                'status': session_data['status'],
                'scan_status': scanner.status,
                'current_target': scanner.current_target,
                'uptime_seconds': time.time() - session_data['created_at']
            }

    def list_active_sessions(self) -> Dict[str, Dict[str, Any]]:
        """
        List all active sessions with debugging info.

        Returns:
            Dictionary of session information keyed by session ID
        """
        active_sessions = {}

        with self.lock:
            for session_id, session_data in self.sessions.items():
                if session_data['status'] == 'active':
                    scanner = session_data['scanner']
                    active_sessions[session_id] = {
                        'session_id': session_id,
                        'created_at': datetime.fromtimestamp(session_data['created_at'], timezone.utc).isoformat(),
                        'last_activity': datetime.fromtimestamp(session_data['last_activity'], timezone.utc).isoformat(),
                        'status': session_data['status'],
                        'scan_status': scanner.status,
                        'current_target': scanner.current_target,
                        'uptime_seconds': time.time() - session_data['created_at'],
                        'scanner_object_id': id(scanner)
                    }

        return active_sessions

    def _cleanup_loop(self) -> None:
        """Background thread to cleanup inactive sessions."""
        while True:
            try:
                current_time = time.time()
                sessions_to_cleanup = []

                with self.lock:
                    for session_id, session_data in self.sessions.items():
                        if session_data['status'] != 'active':
                            continue
                        inactive_time = current_time - session_data['last_activity']
                        if inactive_time > self.session_timeout:
                            sessions_to_cleanup.append(session_id)

                # Cleanup outside of lock to avoid deadlock
                for session_id in sessions_to_cleanup:
                    print(f"Cleaning up inactive session: {session_id}")
                    self.terminate_session(session_id)

                # Sleep for 5 minutes between cleanup cycles
                time.sleep(300)

            except Exception as e:
                print(f"Error in session cleanup loop: {e}")
                time.sleep(60)  # Sleep for 1 minute on error

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get session manager statistics.

        Returns:
            Statistics dictionary with session/scan counts and the timeout.
        """
        with self.lock:
            active_count = sum(1 for s in self.sessions.values() if s['status'] == 'active')
            # NOTE(review): string comparison against 'running' - confirm
            # Scanner.status is a string here, not an enum.
            running_scans = sum(1 for s in self.sessions.values()
                                if s['status'] == 'active' and s['scanner'].status == 'running')

            return {
                'total_sessions': len(self.sessions),
                'active_sessions': active_count,
                'running_scans': running_scans,
                'session_timeout_minutes': self.session_timeout / 60
            }


# Global session manager instance
session_manager = SessionManager(session_timeout_minutes=60)

View File

@ -7,10 +7,9 @@ import os
import json
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from core.logger import get_forensic_logger
from core.graph_manager import NodeType, RelationshipType
from core.graph_manager import RelationshipType
class RateLimiter:
@ -42,36 +41,52 @@ class RateLimiter:
class BaseProvider(ABC):
"""
Abstract base class for all DNSRecon data providers.
Provides common functionality and defines the provider interface.
Now supports session-specific configuration.
"""
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30):
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
"""
Initialize base provider.
Initialize base provider with session-specific configuration.
Args:
name: Provider name for logging
rate_limit: Requests per minute limit
rate_limit: Requests per minute limit (default override)
timeout: Request timeout in seconds
session_config: Session-specific configuration
"""
# Use session config if provided, otherwise fall back to global config
if session_config is not None:
self.config = session_config
actual_rate_limit = self.config.get_rate_limit(name)
actual_timeout = self.config.default_timeout
else:
# Fallback to global config for backwards compatibility
from config import config as global_config
self.config = global_config
actual_rate_limit = rate_limit
actual_timeout = timeout
self.name = name
self.rate_limiter = RateLimiter(rate_limit)
self.timeout = timeout
self.rate_limiter = RateLimiter(actual_rate_limit)
self.timeout = actual_timeout
self._local = threading.local()
self.logger = get_forensic_logger()
self._stop_event = None
# Caching configuration
self.cache_dir = '.cache'
# Caching configuration (per session)
self.cache_dir = f'.cache/{id(self.config)}' # Unique cache per session config
self.cache_expiry = 12 * 3600 # 12 hours in seconds
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
# Statistics
# Statistics (per provider instance)
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.total_relationships_found = 0
print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")
@property
def session(self):
if not hasattr(self._local, 'session'):
@ -124,18 +139,13 @@ class BaseProvider(ABC):
max_retries: int = 3) -> Optional[requests.Response]:
"""
Make a rate-limited HTTP request with forensic logging and retry logic.
Args:
url: Request URL
method: HTTP method
params: Query parameters
headers: Additional headers
target_indicator: The indicator being investigated
max_retries: Maximum number of retry attempts
Returns:
Response object or None if request failed
Now supports cancellation via stop_event from scanner.
"""
# Check for cancellation before starting
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled before start: {url}")
return None
# Create a unique cache key
cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
cache_path = os.path.join(self.cache_dir, cache_key)
@ -154,9 +164,22 @@ class BaseProvider(ABC):
return response
for attempt in range(max_retries + 1):
# Apply rate limiting
# Check for cancellation before each attempt
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled during attempt {attempt + 1}: {url}")
return None
# Apply rate limiting (but reduce wait time if cancellation is requested)
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
break
self.rate_limiter.wait_if_needed()
# Check again after rate limiting
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled after rate limiting: {url}")
return None
start_time = time.time()
response = None
error = None
@ -171,20 +194,25 @@ class BaseProvider(ABC):
print(f"Making {method} request to: {url} (attempt {attempt + 1})")
# Use shorter timeout if termination is requested
request_timeout = self.timeout
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
request_timeout = min(5, self.timeout) # Max 5 seconds if termination requested
# Make request
if method.upper() == "GET":
response = self.session.get(
url,
params=params,
headers=request_headers,
timeout=self.timeout
timeout=request_timeout
)
elif method.upper() == "POST":
response = self.session.post(
url,
json=params,
headers=request_headers,
timeout=self.timeout
timeout=request_timeout
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
@ -219,11 +247,27 @@ class BaseProvider(ABC):
self.failed_requests += 1
print(f"Request failed (attempt {attempt + 1}): {error}")
# Check for cancellation before retrying
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled, not retrying: {url}")
break
# Check if we should retry
if attempt < max_retries and self._should_retry(e):
backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s
print(f"Retrying in {backoff_time} seconds...")
time.sleep(backoff_time)
# Shorter backoff if termination is requested
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
backoff_time = min(0.5, backoff_time)
# Sleep with cancellation checking
sleep_start = time.time()
while time.time() - sleep_start < backoff_time:
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled during backoff: {url}")
return None
time.sleep(0.1) # Check every 100ms
continue
else:
break
@ -249,6 +293,15 @@ class BaseProvider(ABC):
return None
def set_stop_event(self, stop_event: threading.Event) -> None:
    """
    Set the stop event for this provider to enable cancellation.

    make_request() polls this event before each attempt, after rate
    limiting and during retry backoff sleeps, so setting the event aborts
    in-flight provider work promptly.

    Args:
        stop_event: Threading event to signal cancellation
    """
    self._stop_event = stop_event
def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
"""
Determine if a request should be retried based on the exception.
@ -315,89 +368,3 @@ class BaseProvider(ABC):
'relationships_found': self.total_relationships_found,
'rate_limit': self.rate_limiter.requests_per_minute
}
def reset_statistics(self) -> None:
    """Zero out all per-provider request and relationship counters."""
    for counter in ('total_requests', 'successful_requests',
                    'failed_requests', 'total_relationships_found'):
        setattr(self, counter, 0)
def _extract_domain_from_url(self, url: str) -> Optional[str]:
"""
Extract domain from URL.
Args:
url: URL string
Returns:
Domain name or None if extraction fails
"""
try:
# Remove protocol
if '://' in url:
url = url.split('://', 1)[1]
# Remove path
if '/' in url:
url = url.split('/', 1)[0]
# Remove port
if ':' in url:
url = url.split(':', 1)[0]
return url.lower()
except Exception:
return None
def _is_valid_domain(self, domain: str) -> bool:
"""
Basic domain validation.
Args:
domain: Domain string to validate
Returns:
True if domain appears valid
"""
if not domain or len(domain) > 253:
return False
# Check for valid characters and structure
parts = domain.split('.')
if len(parts) < 2:
return False
for part in parts:
if not part or len(part) > 63:
return False
if not part.replace('-', '').replace('_', '').isalnum():
return False
return True
def _is_valid_ip(self, ip: str) -> bool:
"""
Basic IP address validation.
Args:
ip: IP address string to validate
Returns:
True if IP appears valid
"""
try:
parts = ip.split('.')
if len(parts) != 4:
return False
for part in parts:
num = int(part)
if not 0 <= num <= 255:
return False
return True
except (ValueError, AttributeError):
return False

View File

@ -1,6 +1,7 @@
"""
Certificate Transparency provider using crt.sh.
Discovers domain relationships through certificate SAN analysis.
Discovers domain relationships through certificate SAN analysis with comprehensive certificate tracking.
Stores certificates as metadata on domain nodes rather than creating certificate nodes.
"""
import json
@ -10,23 +11,26 @@ from urllib.parse import quote
from datetime import datetime, timezone
from .base_provider import BaseProvider
from utils.helpers import _is_valid_domain
from core.graph_manager import RelationshipType
class CrtShProvider(BaseProvider):
"""
Provider for querying crt.sh certificate transparency database.
Discovers domain relationships through certificate Subject Alternative Names (SANs).
Now uses session-specific configuration and caching.
"""
def __init__(self):
"""Initialize CrtSh provider with appropriate rate limiting."""
def __init__(self, session_config=None):
"""Initialize CrtSh provider with session-specific configuration."""
super().__init__(
name="crtsh",
rate_limit=60, # Be respectful to the free service
timeout=30
rate_limit=60,
timeout=15,
session_config=session_config
)
self.base_url = "https://crt.sh/"
self._stop_event = None
def get_name(self) -> str:
"""Return the provider name."""
@ -40,31 +44,128 @@ class CrtShProvider(BaseProvider):
"""
return True
def _parse_certificate_date(self, date_string: str) -> datetime:
"""
Parse certificate date from crt.sh format.
Args:
date_string: Date string from crt.sh API
Returns:
Parsed datetime object in UTC
"""
if not date_string:
raise ValueError("Empty date string")
try:
# Handle various possible formats from crt.sh
if date_string.endswith('Z'):
return datetime.fromisoformat(date_string[:-1]).replace(tzinfo=timezone.utc)
elif '+' in date_string or date_string.endswith('UTC'):
# Handle timezone-aware strings
date_string = date_string.replace('UTC', '').strip()
if '+' in date_string:
date_string = date_string.split('+')[0]
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
else:
# Assume UTC if no timezone specified
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
except Exception as e:
# Fallback: try parsing without timezone info and assume UTC
try:
return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc)
except Exception:
raise ValueError(f"Unable to parse date: {date_string}") from e
def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool:
"""Check if a certificate is currently valid."""
"""
Check if a certificate is currently valid based on its expiry date.
Args:
cert_data: Certificate data from crt.sh
Returns:
True if certificate is currently valid (not expired)
"""
try:
not_after_str = cert_data.get('not_after')
if not_after_str:
# Append 'Z' to indicate UTC if it's not present
if not not_after_str.endswith('Z'):
not_after_str += 'Z'
not_after_date = datetime.fromisoformat(not_after_str.replace('Z', '+00:00'))
return not_after_date > datetime.now(timezone.utc)
except Exception:
if not not_after_str:
return False
not_after_date = self._parse_certificate_date(not_after_str)
not_before_str = cert_data.get('not_before')
now = datetime.now(timezone.utc)
# Check if certificate is within valid date range
is_not_expired = not_after_date > now
if not_before_str:
not_before_date = self._parse_certificate_date(not_before_str)
is_not_before_valid = not_before_date <= now
return is_not_expired and is_not_before_valid
return is_not_expired
except Exception as e:
self.logger.logger.debug(f"Certificate validity check failed: {e}")
return False
def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract comprehensive metadata from certificate data.
Args:
cert_data: Raw certificate data from crt.sh
Returns:
Comprehensive certificate metadata dictionary
"""
metadata = {
'certificate_id': cert_data.get('id'),
'serial_number': cert_data.get('serial_number'),
'issuer_name': cert_data.get('issuer_name'),
'issuer_ca_id': cert_data.get('issuer_ca_id'),
'common_name': cert_data.get('common_name'),
'not_before': cert_data.get('not_before'),
'not_after': cert_data.get('not_after'),
'entry_timestamp': cert_data.get('entry_timestamp'),
'source': 'crt.sh'
}
# Add computed fields
try:
if metadata['not_before'] and metadata['not_after']:
not_before = self._parse_certificate_date(metadata['not_before'])
not_after = self._parse_certificate_date(metadata['not_after'])
metadata['validity_period_days'] = (not_after - not_before).days
metadata['is_currently_valid'] = self._is_cert_valid(cert_data)
metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30
# Add human-readable dates
metadata['not_before_formatted'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC')
metadata['not_after_formatted'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC')
except Exception as e:
self.logger.logger.debug(f"Error computing certificate metadata: {e}")
metadata['is_currently_valid'] = False
metadata['expires_soon'] = False
return metadata
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query crt.sh for certificates containing the domain.
Args:
domain: Domain to investigate
Returns:
List of relationships discovered from certificate analysis
Creates domain-to-domain relationships and stores certificate data as metadata.
Now supports early termination via stop_event.
"""
if not self._is_valid_domain(domain):
if not _is_valid_domain(domain):
return []
# Check for cancellation before starting
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before start for domain: {domain}")
return []
relationships = []
@ -72,56 +173,113 @@ class CrtShProvider(BaseProvider):
try:
# Query crt.sh for certificates
url = f"{self.base_url}?q={quote(domain)}&output=json"
response = self.make_request(url, target_indicator=domain)
response = self.make_request(url, target_indicator=domain, max_retries=1) # Reduce retries for faster cancellation
if not response or response.status_code != 200:
return []
# Check for cancellation after request
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled after request for domain: {domain}")
return []
certificates = response.json()
if not certificates:
return []
# Process certificates to extract relationships
discovered_subdomains = {}
# Check for cancellation before processing
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before processing for domain: {domain}")
return []
for cert_data in certificates:
# Aggregate certificate data by domain
domain_certificates = {}
all_discovered_domains = set()
# Process certificates and group by domain (with cancellation checks)
for i, cert_data in enumerate(certificates):
# Check for cancellation every 10 certificates
if i % 10 == 0 and self._stop_event and self._stop_event.is_set():
print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}")
break
cert_metadata = self._extract_certificate_metadata(cert_data)
cert_domains = self._extract_domains_from_certificate(cert_data)
is_valid = self._is_cert_valid(cert_data)
for subdomain in cert_domains:
if self._is_valid_domain(subdomain) and subdomain != domain:
if subdomain not in discovered_subdomains:
discovered_subdomains[subdomain] = {'has_valid_cert': False, 'issuers': set()}
# Add all domains from this certificate to our tracking
for cert_domain in cert_domains:
if not _is_valid_domain(cert_domain):
continue
if is_valid:
discovered_subdomains[subdomain]['has_valid_cert'] = True
all_discovered_domains.add(cert_domain)
issuer = cert_data.get('issuer_name')
if issuer:
discovered_subdomains[subdomain]['issuers'].add(issuer)
# Initialize domain certificate list if needed
if cert_domain not in domain_certificates:
domain_certificates[cert_domain] = []
# Create relationships from the discovered subdomains
for subdomain, data in discovered_subdomains.items():
raw_data = {
'has_valid_cert': data['has_valid_cert'],
'issuers': list(data['issuers']),
'source': 'crt.sh'
# Add this certificate to the domain's certificate list
domain_certificates[cert_domain].append(cert_metadata)
# Final cancellation check before creating relationships
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before relationship creation for domain: {domain}")
return []
# Create relationships from query domain to ALL discovered domains
for discovered_domain in all_discovered_domains:
if discovered_domain == domain:
continue # Skip self-relationships
# Check for cancellation during relationship creation
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh relationship creation cancelled for domain: {domain}")
break
if not _is_valid_domain(discovered_domain):
continue
# Get certificates for both domains
query_domain_certs = domain_certificates.get(domain, [])
discovered_domain_certs = domain_certificates.get(discovered_domain, [])
# Find shared certificates (for metadata purposes)
shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs)
# Calculate confidence based on relationship type and shared certificates
confidence = self._calculate_domain_relationship_confidence(
domain, discovered_domain, shared_certificates, all_discovered_domains
)
# Create comprehensive raw data for the relationship
relationship_raw_data = {
'relationship_type': 'certificate_discovery',
'shared_certificates': shared_certificates,
'total_shared_certs': len(shared_certificates),
'discovery_context': self._determine_relationship_context(discovered_domain, domain),
'domain_certificates': {
domain: self._summarize_certificates(query_domain_certs),
discovered_domain: self._summarize_certificates(discovered_domain_certs)
}
}
# Create domain -> domain relationship
relationships.append((
domain,
subdomain,
discovered_domain,
RelationshipType.SAN_CERTIFICATE,
RelationshipType.SAN_CERTIFICATE.default_confidence,
raw_data
confidence,
relationship_raw_data
))
# Log the relationship discovery
self.log_relationship_discovery(
source_node=domain,
target_node=subdomain,
target_node=discovered_domain,
relationship_type=RelationshipType.SAN_CERTIFICATE,
confidence_score=RelationshipType.SAN_CERTIFICATE.default_confidence,
raw_data=raw_data,
discovery_method="certificate_san_analysis"
confidence_score=confidence,
raw_data=relationship_raw_data,
discovery_method="certificate_transparency_analysis"
)
except json.JSONDecodeError as e:
@ -131,6 +289,165 @@ class CrtShProvider(BaseProvider):
return relationships
def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Find certificates that are shared between two domain certificate lists.
Args:
certs1: First domain's certificates
certs2: Second domain's certificates
Returns:
List of shared certificate metadata
"""
shared = []
# Create a set of certificate IDs from the first list for quick lookup
cert1_ids = {cert.get('certificate_id') for cert in certs1 if cert.get('certificate_id')}
# Find certificates in the second list that match
for cert in certs2:
if cert.get('certificate_id') in cert1_ids:
shared.append(cert)
return shared
def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Create a summary of certificates for a domain.
Args:
certificates: List of certificate metadata
Returns:
Summary dictionary with aggregate statistics
"""
if not certificates:
return {
'total_certificates': 0,
'valid_certificates': 0,
'expired_certificates': 0,
'expires_soon_count': 0,
'unique_issuers': [],
'latest_certificate': None,
'has_valid_cert': False
}
valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid'))
expired_count = len(certificates) - valid_count
expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon'))
# Get unique issuers
unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name')))
# Find the most recent certificate
latest_cert = None
latest_date = None
for cert in certificates:
try:
if cert.get('not_before'):
cert_date = self._parse_certificate_date(cert['not_before'])
if latest_date is None or cert_date > latest_date:
latest_date = cert_date
latest_cert = cert
except Exception:
continue
return {
'total_certificates': len(certificates),
'valid_certificates': valid_count,
'expired_certificates': expired_count,
'expires_soon_count': expires_soon_count,
'unique_issuers': unique_issuers,
'latest_certificate': latest_cert,
'has_valid_cert': valid_count > 0,
'certificate_details': certificates # Full details for forensic analysis
}
def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str,
                                              shared_certificates: List[Dict[str, Any]],
                                              all_discovered_domains: Set[str]) -> float:
    """
    Calculate confidence score for domain relationship based on various factors.

    Starts from the SAN_CERTIFICATE default confidence and adds independent
    bonuses for: the structural relation between the domains (subdomain >
    parent domain > unrelated), the number of shared certificates, whether
    any shared certificate is currently valid, and whether a shared
    certificate was issued by a well-known CA.

    Args:
        domain1: Source domain (query domain)
        domain2: Target domain (discovered domain)
        shared_certificates: List of shared certificate metadata
        all_discovered_domains: All domains discovered in this query.
            NOTE(review): currently unused in the computation; kept for
            interface stability — confirm before removing.

    Returns:
        Confidence score, clamped to the range [0.1, 1.0]
    """
    base_confidence = RelationshipType.SAN_CERTIFICATE.default_confidence
    # Adjust confidence based on domain relationship context
    relationship_context = self._determine_relationship_context(domain2, domain1)
    if relationship_context == 'exact_match':
        context_bonus = 0.0  # This shouldn't happen, but just in case
    elif relationship_context == 'subdomain':
        context_bonus = 0.1  # High confidence for subdomains
    elif relationship_context == 'parent_domain':
        context_bonus = 0.05  # Medium confidence for parent domains
    else:
        context_bonus = 0.0  # Related domains get base confidence
    # Adjust confidence based on shared certificates (tiered by count)
    if shared_certificates:
        shared_count = len(shared_certificates)
        if shared_count >= 3:
            shared_bonus = 0.1
        elif shared_count >= 2:
            shared_bonus = 0.05
        else:
            shared_bonus = 0.02
        # Additional bonus for valid shared certificates
        valid_shared = sum(1 for cert in shared_certificates if cert.get('is_currently_valid'))
        if valid_shared > 0:
            validity_bonus = 0.05
        else:
            validity_bonus = 0.0
    else:
        # Even without shared certificates, domains found in the same query have some relationship
        shared_bonus = 0.0
        validity_bonus = 0.0
    # Adjust confidence based on certificate issuer reputation (if shared certificates exist);
    # a single recognized CA is enough, so stop at the first match.
    issuer_bonus = 0.0
    if shared_certificates:
        for cert in shared_certificates:
            issuer = cert.get('issuer_name', '').lower()
            if any(trusted_ca in issuer for trusted_ca in ['let\'s encrypt', 'digicert', 'sectigo', 'globalsign']):
                issuer_bonus = max(issuer_bonus, 0.03)
                break
    # Calculate final confidence
    final_confidence = base_confidence + context_bonus + shared_bonus + validity_bonus + issuer_bonus
    return max(0.1, min(1.0, final_confidence))  # Clamp between 0.1 and 1.0
def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str:
"""
Determine the context of the relationship between certificate domain and query domain.
Args:
cert_domain: Domain found in certificate
query_domain: Original query domain
Returns:
String describing the relationship context
"""
if cert_domain == query_domain:
return 'exact_match'
elif cert_domain.endswith(f'.{query_domain}'):
return 'subdomain'
elif query_domain.endswith(f'.{cert_domain}'):
return 'parent_domain'
else:
return 'related_domain'
def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query crt.sh for certificates containing the IP address.
@ -143,7 +460,6 @@ class CrtShProvider(BaseProvider):
Empty list (crt.sh doesn't support IP-based certificate queries effectively)
"""
# crt.sh doesn't effectively support IP-based certificate queries
# This would require parsing certificate details for IP SANs, which is complex
return []
def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]:
@ -162,7 +478,7 @@ class CrtShProvider(BaseProvider):
common_name = cert_data.get('common_name', '')
if common_name:
cleaned_cn = self._clean_domain_name(common_name)
if cleaned_cn and self._is_valid_domain(cleaned_cn):
if cleaned_cn and _is_valid_domain(cleaned_cn):
domains.add(cleaned_cn)
# Extract from name_value field (contains SANs)
@ -171,7 +487,7 @@ class CrtShProvider(BaseProvider):
# Split by newlines and clean each domain
for line in name_value.split('\n'):
cleaned_domain = self._clean_domain_name(line.strip())
if cleaned_domain and self._is_valid_domain(cleaned_domain):
if cleaned_domain and _is_valid_domain(cleaned_domain):
domains.add(cleaned_domain)
return domains
@ -216,69 +532,3 @@ class CrtShProvider(BaseProvider):
return domain
return ""
def get_certificate_details(self, certificate_id: str) -> Dict[str, Any]:
    """
    Fetch the full crt.sh record for a single certificate ID.

    Best-effort: request errors are logged and an empty dict is returned;
    non-200 responses also yield an empty dict.

    Args:
        certificate_id: Certificate ID from crt.sh

    Returns:
        Dictionary containing certificate details, or {} on failure
    """
    try:
        response = self.make_request(
            f"{self.base_url}?id={certificate_id}&output=json",
            target_indicator=f"cert_{certificate_id}"
        )
        if response and response.status_code == 200:
            return response.json()
    except Exception as e:
        self.logger.logger.error(f"Error fetching certificate details for {certificate_id}: {e}")
    return {}
def search_certificates_by_serial(self, serial_number: str) -> List[Dict[str, Any]]:
    """
    Search crt.sh for certificates matching a serial number.

    Best-effort: request errors are logged and an empty list is returned;
    non-200 responses also yield an empty list.

    Args:
        serial_number: Certificate serial number (URL-encoded before querying)

    Returns:
        List of matching certificates, or [] on failure
    """
    results: List[Dict[str, Any]] = []
    try:
        url = f"{self.base_url}?serial={quote(serial_number)}&output=json"
        response = self.make_request(url, target_indicator=f"serial_{serial_number}")
        if response and response.status_code == 200:
            results = response.json()
    except Exception as e:
        self.logger.logger.error(f"Error searching certificates by serial {serial_number}: {e}")
    return results
def get_issuer_certificates(self, issuer_name: str) -> List[Dict[str, Any]]:
    """
    Get certificates issued by a specific CA.

    Best-effort lookup: request errors are logged and an empty list is
    returned instead of raising; non-200 responses also yield [].

    Args:
        issuer_name: Certificate Authority name (URL-encoded before querying)

    Returns:
        List of certificates from the specified issuer, or [] on failure
    """
    try:
        # quote() protects against spaces/special characters in CA names.
        url = f"{self.base_url}?issuer={quote(issuer_name)}&output=json"
        response = self.make_request(url, target_indicator=f"issuer_{issuer_name}")
        if response and response.status_code == 200:
            return response.json()
    except Exception as e:
        self.logger.logger.error(f"Error fetching certificates for issuer {issuer_name}: {e}")
    # Non-200 responses and exceptions both fall through to an empty result.
    return []

View File

@ -1,25 +1,26 @@
# dnsrecon/providers/dns_provider.py
import socket
import dns.resolver
import dns.reversename
from typing import List, Dict, Any, Tuple, Optional
from typing import List, Dict, Any, Tuple
from .base_provider import BaseProvider
from core.graph_manager import RelationshipType, NodeType
from utils.helpers import _is_valid_ip, _is_valid_domain
from core.graph_manager import RelationshipType
class DNSProvider(BaseProvider):
"""
Provider for standard DNS resolution and reverse DNS lookups.
Discovers domain-to-IP and IP-to-domain relationships through DNS records.
Now uses session-specific configuration.
"""
def __init__(self):
"""Initialize DNS provider with appropriate rate limiting."""
def __init__(self, session_config=None):
"""Initialize DNS provider with session-specific configuration."""
super().__init__(
name="dns",
rate_limit=100, # DNS queries can be faster
timeout=10
rate_limit=100,
timeout=10,
session_config=session_config
)
# Configure DNS resolver
@ -45,7 +46,7 @@ class DNSProvider(BaseProvider):
Returns:
List of relationships discovered from DNS analysis
"""
if not self._is_valid_domain(domain):
if not _is_valid_domain(domain):
return []
relationships = []
@ -66,7 +67,7 @@ class DNSProvider(BaseProvider):
Returns:
List of relationships discovered from reverse DNS
"""
if not self._is_valid_ip(ip):
if not _is_valid_ip(ip):
return []
relationships = []
@ -81,7 +82,7 @@ class DNSProvider(BaseProvider):
for ptr_record in response:
hostname = str(ptr_record).rstrip('.')
if self._is_valid_domain(hostname):
if _is_valid_domain(hostname):
raw_data = {
'query_type': 'PTR',
'ip_address': ip,

View File

@ -4,38 +4,37 @@ Discovers IP relationships and infrastructure context through Shodan API.
"""
import json
from typing import List, Dict, Any, Tuple, Optional
from urllib.parse import quote
from typing import List, Dict, Any, Tuple
from .base_provider import BaseProvider
from utils.helpers import _is_valid_ip, _is_valid_domain
from core.graph_manager import RelationshipType
from config import config
class ShodanProvider(BaseProvider):
"""
Provider for querying Shodan API for IP address and hostname information.
Requires valid API key and respects Shodan's rate limits.
Now uses session-specific API keys.
"""
def __init__(self):
"""Initialize Shodan provider with appropriate rate limiting."""
def __init__(self, session_config=None):
"""Initialize Shodan provider with session-specific configuration."""
super().__init__(
name="shodan",
rate_limit=60, # Shodan API has various rate limits depending on plan
timeout=30
rate_limit=60,
timeout=30,
session_config=session_config
)
self.base_url = "https://api.shodan.io"
self.api_key = config.get_api_key('shodan')
self.api_key = self.config.get_api_key('shodan')
def is_available(self) -> bool:
    """Check if Shodan provider is available (has valid API key in this session)."""
    key = self.api_key
    if key is None:
        return False
    return len(key.strip()) > 0
def get_name(self) -> str:
    """Return the canonical provider name ("shodan")."""
    return "shodan"
def is_available(self) -> bool:
"""
Check if Shodan provider is available (has valid API key).
"""
return self.api_key is not None and len(self.api_key.strip()) > 0
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
@ -48,7 +47,7 @@ class ShodanProvider(BaseProvider):
Returns:
List of relationships discovered from Shodan data
"""
if not self._is_valid_domain(domain) or not self.is_available():
if not _is_valid_domain(domain) or not self.is_available():
return []
relationships = []
@ -109,7 +108,7 @@ class ShodanProvider(BaseProvider):
# Also create relationships to other hostnames on the same IP
for hostname in hostnames:
if hostname != domain and self._is_valid_domain(hostname):
if hostname != domain and _is_valid_domain(hostname):
hostname_raw_data = {
'shared_ip': ip_address,
'all_hostnames': hostnames,
@ -150,7 +149,7 @@ class ShodanProvider(BaseProvider):
Returns:
List of relationships discovered from Shodan IP data
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return []
relationships = []
@ -170,7 +169,7 @@ class ShodanProvider(BaseProvider):
# Extract hostname relationships
hostnames = data.get('hostnames', [])
for hostname in hostnames:
if self._is_valid_domain(hostname):
if _is_valid_domain(hostname):
raw_data = {
'ip_address': ip,
'hostname': hostname,
@ -280,7 +279,7 @@ class ShodanProvider(BaseProvider):
Returns:
List of service information dictionaries
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return []
try:

View File

@ -4,38 +4,37 @@ Discovers domain relationships through passive DNS and URL analysis.
"""
import json
from typing import List, Dict, Any, Tuple, Optional
from typing import List, Dict, Any, Tuple
from .base_provider import BaseProvider
from utils.helpers import _is_valid_ip, _is_valid_domain
from core.graph_manager import RelationshipType
from config import config
class VirusTotalProvider(BaseProvider):
"""
Provider for querying VirusTotal API for passive DNS and domain reputation data.
Requires valid API key and strictly respects free tier rate limits.
Now uses session-specific API keys and rate limits.
"""
def __init__(self):
"""Initialize VirusTotal provider with strict rate limiting for free tier."""
def __init__(self, session_config=None):
"""Initialize VirusTotal provider with session-specific configuration."""
super().__init__(
name="virustotal",
rate_limit=4, # Free tier: 4 requests per minute
timeout=30
timeout=30,
session_config=session_config
)
self.base_url = "https://www.virustotal.com/vtapi/v2"
self.api_key = config.get_api_key('virustotal')
self.api_key = self.config.get_api_key('virustotal')
def is_available(self) -> bool:
"""Check if VirusTotal provider is available (has valid API key in this session)."""
return self.api_key is not None and len(self.api_key.strip()) > 0
def get_name(self) -> str:
"""Return the provider name."""
return "virustotal"
def is_available(self) -> bool:
"""
Check if VirusTotal provider is available (has valid API key).
"""
return self.api_key is not None and len(self.api_key.strip()) > 0
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query VirusTotal for domain information including passive DNS.
@ -46,7 +45,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
List of relationships discovered from VirusTotal data
"""
if not self._is_valid_domain(domain) or not self.is_available():
if not _is_valid_domain(domain) or not self.is_available():
return []
relationships = []
@ -71,7 +70,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
List of relationships discovered from VirusTotal IP data
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return []
relationships = []
@ -114,7 +113,7 @@ class VirusTotalProvider(BaseProvider):
ip_address = resolution.get('ip_address')
last_resolved = resolution.get('last_resolved')
if ip_address and self._is_valid_ip(ip_address):
if ip_address and _is_valid_ip(ip_address):
raw_data = {
'domain': domain,
'ip_address': ip_address,
@ -142,7 +141,7 @@ class VirusTotalProvider(BaseProvider):
# Extract subdomains
subdomains = data.get('subdomains', [])
for subdomain in subdomains:
if subdomain != domain and self._is_valid_domain(subdomain):
if subdomain != domain and _is_valid_domain(subdomain):
raw_data = {
'parent_domain': domain,
'subdomain': subdomain,
@ -200,7 +199,7 @@ class VirusTotalProvider(BaseProvider):
hostname = resolution.get('hostname')
last_resolved = resolution.get('last_resolved')
if hostname and self._is_valid_domain(hostname):
if hostname and _is_valid_domain(hostname):
raw_data = {
'ip_address': ip,
'hostname': hostname,
@ -254,7 +253,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
Dictionary containing reputation data
"""
if not self._is_valid_domain(domain) or not self.is_available():
if not _is_valid_domain(domain) or not self.is_available():
return {}
try:
@ -293,7 +292,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
Dictionary containing reputation data
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return {}
try:

View File

@ -318,17 +318,13 @@ input[type="text"]:focus, select:focus {
}
.graph-container {
height: 500px;
height: 800px;
position: relative;
background-color: #1a1a1a;
border-top: 1px solid #444;
transition: height 0.3s ease;
}
.graph-container.expanded {
height: 700px;
}
.graph-controls {
position: absolute;
top: 10px;
@ -535,29 +531,6 @@ input[type="text"]:focus, select:focus {
box-shadow: 0 4px 6px rgba(0,0,0,0.3);
}
.node-info-title {
color: #00ff41;
font-weight: bold;
margin-bottom: 0.5rem;
border-bottom: 1px solid #444;
padding-bottom: 0.25rem;
}
.node-info-detail {
margin-bottom: 0.25rem;
display: flex;
justify-content: space-between;
}
.node-info-label {
color: #999;
}
.node-info-value {
color: #c7c7c7;
font-weight: 500;
}
/* Footer */
.footer {
background-color: #0a0a0a;

View File

@ -233,7 +233,6 @@ class GraphManager {
const nodeId = params.node;
const node = this.nodes.get(nodeId);
if (node) {
this.showNodeInfoPopup(params.pointer.DOM, node);
this.highlightConnectedNodes(nodeId, true);
}
});
@ -243,19 +242,6 @@ class GraphManager {
this.clearHoverHighlights();
});
// Edge hover events
this.network.on('hoverEdge', (params) => {
const edgeId = params.edge;
const edge = this.edges.get(edgeId);
if (edge) {
this.showEdgeInfo(params.pointer.DOM, edge);
}
});
this.network.on('blurEdge', () => {
this.hideNodeInfoPopup();
});
// Double-click to focus on node
this.network.on('doubleClick', (params) => {
if (params.nodes.length > 0) {
@ -347,7 +333,6 @@ class GraphManager {
const processedNode = {
id: node.id,
label: this.formatNodeLabel(node.id, node.type),
title: this.createNodeTooltip(node),
color: this.getNodeColor(node.type),
size: this.getNodeSize(node.type),
borderColor: this.getNodeBorderColor(node.type),
@ -373,11 +358,14 @@ class GraphManager {
}
// Style based on certificate validity
if (node.has_valid_cert === true) {
processedNode.borderColor = '#00ff41'; // Green for valid cert
} else if (node.has_valid_cert === false) {
processedNode.borderColor = '#ff9900'; // Amber for expired/no cert
processedNode.borderDashes = [5, 5];
if (node.type === 'domain') {
if (node.metadata && node.metadata.has_valid_cert === true) {
processedNode.color = '#00ff41'; // Bright green for valid cert
processedNode.borderColor = '#00aa2e';
} else if (node.metadata && node.metadata.has_valid_cert === false) {
processedNode.color = '#888888'; // Muted grey color
processedNode.borderColor = '#666666'; // Darker grey border
}
}
return processedNode;
@ -457,9 +445,9 @@ class GraphManager {
const colors = {
'domain': '#00ff41', // Green
'ip': '#ff9900', // Amber
'certificate': '#c7c7c7', // Gray
'asn': '#00aaff', // Blue
'large_entity': '#ff6b6b' // Red for large entities
'large_entity': '#ff6b6b', // Red for large entities
'dns_record': '#999999'
};
return colors[nodeType] || '#ffffff';
}
@ -474,8 +462,8 @@ class GraphManager {
const borderColors = {
'domain': '#00aa2e',
'ip': '#cc7700',
'certificate': '#999999',
'asn': '#0088cc'
'asn': '#0088cc',
'dns_record': '#999999'
};
return borderColors[nodeType] || '#666666';
}
@ -489,8 +477,8 @@ class GraphManager {
const sizes = {
'domain': 12,
'ip': 14,
'certificate': 10,
'asn': 16
'asn': 16,
'dns_record': 8
};
return sizes[nodeType] || 12;
}
@ -504,8 +492,8 @@ class GraphManager {
const shapes = {
'domain': 'dot',
'ip': 'square',
'certificate': 'diamond',
'asn': 'triangle'
'asn': 'triangle',
'dns_record': 'hexagon'
};
return shapes[nodeType] || 'dot';
}
@ -541,26 +529,7 @@ class GraphManager {
}
/**
* Create node tooltip
* Builds a small HTML snippet shown on node hover: the node id in the
* accent color, its type, and a "Click for details" hint when metadata
* is present.
* @param {Object} node - Node data (expects id, type, optional metadata)
* @returns {string} HTML tooltip content
*/
createNodeTooltip(node) {
let tooltip = `<div style="font-family: 'Roboto Mono', monospace; font-size: 11px;">`;
tooltip += `<div style="color: #00ff41; font-weight: bold; margin-bottom: 4px;">${node.id}</div>`;
tooltip += `<div style="color: #999; margin-bottom: 2px;">Type: ${node.type}</div>`;
// Only advertise the detail view when there is metadata to show.
if (node.metadata && Object.keys(node.metadata).length > 0) {
tooltip += `<div style="color: #999; margin-top: 4px; border-top: 1px solid #444; padding-top: 4px;">`;
tooltip += `Click for details</div>`;
}
tooltip += `</div>`;
return tooltip;
}
/**
* Create edge tooltip
* Create edge tooltip with correct provider information
* @param {Object} edge - Edge data
* @returns {string} HTML tooltip content
*/
@ -570,7 +539,7 @@ class GraphManager {
tooltip += `<div style="color: #999; margin-bottom: 2px;">Confidence: ${(edge.confidence_score * 100).toFixed(1)}%</div>`;
if (edge.source_provider) {
tooltip += `<div style="color: #999; margin-bottom: 2px;">Source: ${edge.source_provider}</div>`;
tooltip += `<div style="color: #999; margin-bottom: 2px;">Provider: ${edge.source_provider}</div>`;
}
if (edge.discovery_timestamp) {
@ -610,69 +579,6 @@ class GraphManager {
document.dispatchEvent(event);
}
/**
* Show enhanced node info popup
* Renders node id/type (plus a details hint when metadata exists) into
* the shared popup element, positions it near the cursor, and nudges it
* back inside the viewport if it would overflow.
* @param {Object} position - Mouse position in DOM coordinates ({x, y})
* @param {Object} node - Node data (expects id, type, optional metadata)
*/
showNodeInfoPopup(position, node) {
// No-op when the popup element was never created.
if (!this.nodeInfoPopup) return;
const html = `
<div class="node-info-title">${node.id}</div>
<div class="node-info-detail">
<span class="node-info-label">Type:</span>
<span class="node-info-value">${node.type || 'Unknown'}</span>
</div>
${node.metadata && Object.keys(node.metadata).length > 0 ?
'<div class="node-info-detail"><span class="node-info-label">Details:</span><span class="node-info-value">Click for more</span></div>' :
''}
`;
this.nodeInfoPopup.innerHTML = html;
this.nodeInfoPopup.style.display = 'block';
// Offset from the cursor so the pointer does not cover the popup.
this.nodeInfoPopup.style.left = position.x + 15 + 'px';
this.nodeInfoPopup.style.top = position.y - 10 + 'px';
// Ensure popup stays in viewport: flip to the other side of the
// cursor when it would overflow the right or bottom edge.
const rect = this.nodeInfoPopup.getBoundingClientRect();
if (rect.right > window.innerWidth) {
this.nodeInfoPopup.style.left = position.x - rect.width - 15 + 'px';
}
if (rect.bottom > window.innerHeight) {
this.nodeInfoPopup.style.top = position.y - rect.height + 10 + 'px';
}
}
/**
* Show edge information tooltip
* Reuses the node info popup element to display an edge's relationship
* type, confidence percentage, and source provider near the cursor.
* NOTE(review): unlike showNodeInfoPopup, this does not clamp the popup
* to the viewport — confirm whether overflow handling is intended here.
* @param {Object} position - Mouse position in DOM coordinates ({x, y})
* @param {Object} edge - Edge data (reads edge.metadata.{relationship_type, confidence_score, source_provider})
*/
showEdgeInfo(position, edge) {
// No-op when the popup element was never created.
if (!this.nodeInfoPopup) return;
// Fall back to safe defaults when the edge carries no metadata.
const confidence = edge.metadata ? edge.metadata.confidence_score : 0;
const provider = edge.metadata ? edge.metadata.source_provider : 'Unknown';
const html = `
<div class="node-info-title">${edge.metadata ? edge.metadata.relationship_type : 'Relationship'}</div>
<div class="node-info-detail">
<span class="node-info-label">Confidence:</span>
<span class="node-info-value">${(confidence * 100).toFixed(1)}%</span>
</div>
<div class="node-info-detail">
<span class="node-info-label">Provider:</span>
<span class="node-info-value">${provider}</span>
</div>
`;
this.nodeInfoPopup.innerHTML = html;
this.nodeInfoPopup.style.display = 'block';
// Offset from the cursor so the pointer does not cover the popup.
this.nodeInfoPopup.style.left = position.x + 15 + 'px';
this.nodeInfoPopup.style.top = position.y - 10 + 'px';
}
/**
* Hide node info popup
*/

View File

@ -511,9 +511,18 @@ class DNSReconApp {
}
}
// Update session ID
if (this.currentSessionId && this.elements.sessionId) {
this.elements.sessionId.textContent = `Session: ${this.currentSessionId}`;
// Update session ID display with user session info
if (this.elements.sessionId) {
const scanSessionId = this.currentSessionId;
const userSessionId = status.user_session_id;
if (scanSessionId && userSessionId) {
this.elements.sessionId.textContent = `Session: ${userSessionId.substring(0, 8)}... | Scan: ${scanSessionId}`;
} else if (userSessionId) {
this.elements.sessionId.textContent = `User Session: ${userSessionId.substring(0, 8)}...`;
} else {
this.elements.sessionId.textContent = 'Session: Loading...';
}
}
console.log('Status display updated successfully');
@ -730,11 +739,13 @@ class DNSReconApp {
}
let detailsHtml = '';
const createDetailRow = (label, value) => {
const createDetailRow = (label, value, statusIcon = '') => {
const baseId = `detail-${label.replace(/[^a-zA-Z0-9]/g, '-')}`;
// Handle empty or undefined values by showing N/A
if (value === null || value === undefined || (Array.isArray(value) && value.length === 0)) {
// Handle empty or undefined values
if (value === null || value === undefined ||
(Array.isArray(value) && value.length === 0) ||
(typeof value === 'object' && Object.keys(value).length === 0)) {
return `
<div class="detail-row">
<span class="detail-label">${label} <span class="status-icon text-warning"></span></span>
@ -743,12 +754,11 @@ class DNSReconApp {
`;
}
// If value is an array, create a row for each item
// Handle arrays
if (Array.isArray(value)) {
return value.map((item, index) => {
const itemId = `${baseId}-${index}`;
// Only show the label for the first item in the list
const itemLabel = index === 0 ? label : '';
const itemLabel = index === 0 ? `${label} <span class="status-icon text-success">✓</span>` : '';
return `
<div class="detail-row">
<span class="detail-label">${itemLabel}</span>
@ -758,12 +768,13 @@ class DNSReconApp {
`;
}).join('');
}
// Handle objects and other primitive values in a single row
// Handle objects and primitives
else {
const valueId = `${baseId}-0`;
const icon = statusIcon || '<span class="status-icon text-success">✓</span>';
return `
<div class="detail-row">
<span class="detail-label">${label} <span class="status-icon text-success"></span></span>
<span class="detail-label">${label} ${icon}</span>
<span class="detail-value" id="${valueId}">${this.formatValue(value)}</span>
<button class="copy-btn" onclick="copyToClipboard('${valueId}')" title="Copy">📋</button>
</div>
@ -773,33 +784,44 @@ class DNSReconApp {
const metadata = node.metadata || {};
// Display data based on node type
switch (node.type) {
case 'domain':
detailsHtml += createDetailRow('DNS Records', metadata.dns_records);
detailsHtml += createDetailRow('Related Domains (SAN)', metadata.related_domains_san);
detailsHtml += createDetailRow('Certificate Data', metadata.certificate_data);
detailsHtml += createDetailRow('Passive DNS', metadata.passive_dns);
detailsHtml += createDetailRow('Shodan Data', metadata.shodan);
detailsHtml += createDetailRow('VirusTotal Data', metadata.virustotal);
detailsHtml += createDetailRow('ASN Information', metadata.asn_data);
break;
case 'ip':
detailsHtml += createDetailRow('DNS Records', metadata.dns_records);
detailsHtml += createDetailRow('Passive DNS', metadata.passive_dns);
detailsHtml += createDetailRow('Shodan Data', metadata.shodan);
detailsHtml += createDetailRow('VirusTotal Data', metadata.virustotal);
break;
case 'certificate':
detailsHtml += createDetailRow('Certificate Hash', metadata.hash);
detailsHtml += createDetailRow('SANs', metadata.sans);
detailsHtml += createDetailRow('Issuer', metadata.issuer);
detailsHtml += createDetailRow('Validity', `From: ${metadata.not_before || 'N/A'} To: ${metadata.not_after || 'N/A'}`);
detailsHtml += createDetailRow('ASN Information', metadata.asn_data);
break;
case 'asn':
detailsHtml += createDetailRow('ASN', metadata.asn);
detailsHtml += createDetailRow('Description', metadata.description);
detailsHtml += createDetailRow('ASN Information', metadata.asn_data);
detailsHtml += createDetailRow('Related IPs', metadata.passive_dns);
break;
case 'large_entity':
detailsHtml += createDetailRow('Entity Type', metadata.entity_type || 'Large Collection');
detailsHtml += createDetailRow('Item Count', metadata.count);
detailsHtml += createDetailRow('Discovered Domains', metadata.domains);
break;
}
// Add any additional metadata not covered above
for (const [key, value] of Object.entries(metadata)) {
if (!['dns_records', 'related_domains_san', 'certificate_data', 'passive_dns',
'shodan', 'virustotal', 'asn_data', 'hash', 'sans', 'issuer',
'not_before', 'not_after', 'entity_type', 'count', 'domains'].includes(key)) {
detailsHtml += createDetailRow(this.formatLabel(key), value);
}
}
if (this.elements.modalDetails) {
this.elements.modalDetails.innerHTML = detailsHtml;
}

View File

@ -137,6 +137,10 @@
<div class="legend-color" style="background-color: #c7c7c7;"></div>
<span>Certificates</span>
</div>
<div class="legend-item">
<div class="legend-color" style="background-color: #9d4edd;"></div>
<span>DNS Records</span>
</div>
<div class="legend-item">
<div class="legend-edge high-confidence"></div>
<span>High Confidence</span>

0
utils/__init__.py Normal file
View File

50
utils/helpers.py Normal file
View File

@ -0,0 +1,50 @@
def _is_valid_domain(domain: str) -> bool:
"""
Basic domain validation.
Args:
domain: Domain string to validate
Returns:
True if domain appears valid
"""
if not domain or len(domain) > 253:
return False
# Check for valid characters and structure
parts = domain.split('.')
if len(parts) < 2:
return False
for part in parts:
if not part or len(part) > 63:
return False
if not part.replace('-', '').replace('_', '').isalnum():
return False
return True
def _is_valid_ip(ip: str) -> bool:
"""
Basic IP address validation.
Args:
ip: IP address string to validate
Returns:
True if IP appears valid
"""
try:
parts = ip.split('.')
if len(parts) != 4:
return False
for part in parts:
num = int(part)
if not 0 <= num <= 255:
return False
return True
except (ValueError, AttributeError):
return False