This commit is contained in:
overcuriousity
2025-09-11 14:01:15 +02:00
parent 2d485c5703
commit d3e1fcf35f
18 changed files with 1806 additions and 843 deletions

View File

@@ -5,8 +5,10 @@ Phase 2: Enhanced with concurrent processing and real-time capabilities.
"""
from .graph_manager import GraphManager, NodeType, RelationshipType
from .scanner import Scanner, ScanStatus, scanner
from .scanner import Scanner, ScanStatus # Remove 'scanner' global instance
from .logger import ForensicLogger, get_forensic_logger, new_session
from .session_manager import session_manager # Add session manager
from .session_config import SessionConfig, create_session_config # Add session config
__all__ = [
'GraphManager',
@@ -14,10 +16,13 @@ __all__ = [
'RelationshipType',
'Scanner',
'ScanStatus',
'scanner',
# 'scanner', # Remove this - no more global scanner
'ForensicLogger',
'get_forensic_logger',
'new_session'
'new_session',
'session_manager', # Add this
'SessionConfig', # Add this
'create_session_config' # Add this
]
__version__ = "1.0.0-phase2"

View File

@@ -3,13 +3,10 @@ Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
"""
import json
import threading
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Set
from typing import Dict, List, Any, Optional, Tuple
from enum import Enum
from datetime import timezone
from collections import defaultdict
import networkx as nx
@@ -18,8 +15,8 @@ class NodeType(Enum):
"""Enumeration of supported node types."""
DOMAIN = "domain"
IP = "ip"
CERTIFICATE = "certificate"
ASN = "asn"
DNS_RECORD = "dns_record"
LARGE_ENTITY = "large_entity"
@@ -43,6 +40,7 @@ class RelationshipType(Enum):
TLSA_RECORD = ("tlsa_record", 0.7)
NAPTR_RECORD = ("naptr_record", 0.7)
SPF_RECORD = ("spf_record", 0.7)
DNS_RECORD = ("dns_record", 0.8)
PASSIVE_DNS = ("passive_dns", 0.6)
ASN_MEMBERSHIP = ("asn", 0.7)
@@ -115,8 +113,7 @@ class GraphManager:
Returns:
bool: True if edge was added, False if it already exists
"""
#with self.lock:
# Ensure both nodes exist
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
# If the target node is a subdomain, it should be added.
# The scanner will handle this logic.
@@ -149,12 +146,10 @@ class GraphManager:
def get_node_count(self) -> int:
"""Get total number of nodes in the graph."""
#with self.lock:
return self.graph.number_of_nodes()
def get_edge_count(self) -> int:
"""Get total number of edges in the graph."""
#with self.lock:
return self.graph.number_of_edges()
def get_nodes_by_type(self, node_type: NodeType) -> List[str]:
@@ -167,7 +162,6 @@ class GraphManager:
Returns:
List of node identifiers
"""
#with self.lock:
return [
node_id for node_id, attributes in self.graph.nodes(data=True)
if attributes.get('type') == node_type.value
@@ -183,7 +177,6 @@ class GraphManager:
Returns:
List of neighboring node identifiers
"""
#with self.lock:
if not self.graph.has_node(node_id):
return []
@@ -201,7 +194,6 @@ class GraphManager:
Returns:
List of tuples (source, target, attributes)
"""
#with self.lock:
return [
(source, target, attributes)
for source, target, attributes in self.graph.edges(data=True)
@@ -211,46 +203,12 @@ class GraphManager:
def get_graph_data(self) -> Dict[str, Any]:
"""
Export graph data for visualization.
Returns:
Dictionary containing nodes and edges for frontend visualization
Uses comprehensive metadata collected during scanning.
"""
#with self.lock:
nodes = []
edges = []
# Create a dictionary to hold aggregated data for each node
node_details = defaultdict(lambda: defaultdict(list))
for source, target, attributes in self.graph.edges(data=True):
provider = attributes.get('source_provider', 'unknown')
raw_data = attributes.get('raw_data', {})
if provider == 'dns':
record_type = raw_data.get('query_type', 'UNKNOWN')
value = raw_data.get('value', target)
# DNS data is always about the source node of the query
node_details[source]['dns_records'].append(f"{record_type}: {value}")
elif provider == 'crtsh':
# Data from crt.sh are domain names found in certificates (SANs)
node_details[source]['related_domains_san'].append(target)
elif provider == 'shodan':
# Shodan data is about the IP, which can be either the source or target
source_node_type = self.graph.nodes[source].get('type')
target_node_type = self.graph.nodes[target].get('type')
if source_node_type == 'ip':
node_details[source]['shodan'] = raw_data
elif target_node_type == 'ip':
node_details[target]['shodan'] = raw_data
elif provider == 'virustotal':
# VirusTotal data is about the source node of the query
node_details[source]['virustotal'] = raw_data
# Format nodes for visualization
# Create nodes with the comprehensive metadata already collected
for node_id, attributes in self.graph.nodes(data=True):
node_data = {
'id': node_id,
@@ -260,18 +218,15 @@ class GraphManager:
'added_timestamp': attributes.get('added_timestamp')
}
# Add the aggregated details to the metadata
if node_id in node_details:
for key, value in node_details[node_id].items():
# Use a set to avoid adding duplicate entries to lists
if key in node_data['metadata'] and isinstance(node_data['metadata'][key], list):
existing_values = set(node_data['metadata'][key])
new_values = [v for v in value if v not in existing_values]
node_data['metadata'][key].extend(new_values)
else:
node_data['metadata'][key] = value
# Handle certificate node labeling
if node_id.startswith('cert_'):
# For certificate nodes, create a more informative label
cert_metadata = node_data['metadata']
issuer = cert_metadata.get('issuer_name', 'Unknown')
valid_status = "" if cert_metadata.get('is_currently_valid') else ""
node_data['label'] = f"Certificate {valid_status}\n{issuer[:30]}..."
# Color coding by type - now returns color objects for enhanced visualization
# Color coding by type
type_colors = {
'domain': {
'background': '#00ff41',
@@ -285,18 +240,18 @@ class GraphManager:
'highlight': {'background': '#ffbb44', 'border': '#ff9900'},
'hover': {'background': '#ffaa22', 'border': '#dd8800'}
},
'certificate': {
'background': '#c7c7c7',
'border': '#999999',
'highlight': {'background': '#e0e0e0', 'border': '#c7c7c7'},
'hover': {'background': '#d4d4d4', 'border': '#aaaaaa'}
},
'asn': {
'background': '#00aaff',
'border': '#0088cc',
'highlight': {'background': '#44ccff', 'border': '#00aaff'},
'hover': {'background': '#22bbff', 'border': '#0099dd'}
},
'dns_record': {
'background': '#9d4edd',
'border': '#7b2cbf',
'highlight': {'background': '#c77dff', 'border': '#9d4edd'},
'hover': {'background': '#b392f0', 'border': '#8b5cf6'}
},
'large_entity': {
'background': '#ff6b6b',
'border': '#cc3a3a',
@@ -306,15 +261,17 @@ class GraphManager:
}
node_color_config = type_colors.get(attributes.get('type', 'unknown'), type_colors['domain'])
node_data['color'] = node_color_config
# Pass the has_valid_cert metadata for styling
if 'metadata' in attributes and 'has_valid_cert' in attributes['metadata']:
node_data['has_valid_cert'] = attributes['metadata']['has_valid_cert']
# Add certificate validity indicator if available
metadata = node_data['metadata']
if 'certificate_data' in metadata and 'has_valid_cert' in metadata['certificate_data']:
node_data['has_valid_cert'] = metadata['certificate_data']['has_valid_cert']
nodes.append(node_data)
# Format edges for visualization
# Create edges (unchanged from original)
for source, target, attributes in self.graph.edges(data=True):
edge_data = {
'from': source,
@@ -376,7 +333,6 @@ class GraphManager:
Returns:
Dictionary containing complete graph data with metadata
"""
#with self.lock:
# Get basic graph data
graph_data = self.get_graph_data()
@@ -427,7 +383,6 @@ class GraphManager:
Returns:
Dictionary containing various graph metrics
"""
#with self.lock:
stats = {
'basic_metrics': {
'total_nodes': self.graph.number_of_nodes(),
@@ -462,7 +417,6 @@ class GraphManager:
def clear(self) -> None:
"""Clear all nodes and edges from the graph."""
#with self.lock:
self.graph.clear()
self.creation_time = datetime.now(timezone.utc).isoformat()
self.last_modified = self.creation_time
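A short sketch of consuming the visualization export above, assuming the returned dict uses 'nodes' and 'edges' keys as the get_graph_data docstring indicates (a GraphManager would normally be populated by a scan first):

from core.graph_manager import GraphManager

graph = GraphManager()
data = graph.get_graph_data()                        # {'nodes': [...], 'edges': [...]} per the docstring
print(graph.get_node_count(), "nodes,", graph.get_edge_count(), "edges")
for node in data['nodes']:
    print(node['id'], node.get('type'), node['color']['background'])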

View File

@@ -3,7 +3,6 @@ Forensic logging system for DNSRecon tool.
Provides structured audit trail for all reconnaissance activities.
"""
import json
import logging
import threading
from datetime import datetime
@@ -109,7 +108,6 @@ class ForensicLogger:
target_indicator: The indicator being investigated
discovery_context: Context of how this indicator was discovered
"""
#with self.lock:
api_request = APIRequest(
timestamp=datetime.now(timezone.utc).isoformat(),
provider=provider,
@@ -152,7 +150,6 @@ class ForensicLogger:
raw_data: Raw data from provider response
discovery_method: Method used to discover relationship
"""
#with self.lock:
relationship = RelationshipDiscovery(
timestamp=datetime.now(timezone.utc).isoformat(),
source_node=source_node,
@@ -178,12 +175,10 @@ class ForensicLogger:
self.logger.info(f"Scan Started - Target: {target_domain}, Depth: {recursion_depth}")
self.logger.info(f"Enabled Providers: {', '.join(enabled_providers)}")
#with self.lock:
self.session_metadata['target_domains'].add(target_domain)
def log_scan_complete(self) -> None:
"""Log the completion of a reconnaissance scan."""
#with self.lock:
self.session_metadata['end_time'] = datetime.now(timezone.utc).isoformat()
self.session_metadata['providers_used'] = list(self.session_metadata['providers_used'])
self.session_metadata['target_domains'] = list(self.session_metadata['target_domains'])
@@ -199,7 +194,6 @@ class ForensicLogger:
Returns:
Dictionary containing complete session audit trail
"""
#with self.lock:
return {
'session_metadata': self.session_metadata.copy(),
'api_requests': [asdict(req) for req in self.api_requests],
@@ -214,7 +208,6 @@ class ForensicLogger:
Returns:
Dictionary containing summary statistics
"""
#with self.lock:
provider_stats = {}
for provider in self.session_metadata['providers_used']:
provider_requests = [req for req in self.api_requests if req.provider == provider]
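A sketch of the forensic-logging calls the scanner makes (argument names mirror the call site in core/scanner.py; the nodes and values are illustrative):

from core.logger import get_forensic_logger

logger = get_forensic_logger()
logger.log_relationship_discovery(
    source_node="example.com",            # illustrative indicator
    target_node="93.184.216.34",
    relationship_type="a_record",
    confidence_score=0.8,
    provider="dns",
    raw_data={},
    discovery_method="dns_a_record",
)
logger.log_scan_complete()                # finalizes session metadata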

View File

@@ -4,14 +4,14 @@ Coordinates data gathering from multiple providers and builds the infrastructure
"""
import threading
import time
import traceback
import hashlib
from typing import List, Set, Dict, Any, Optional, Tuple
from typing import List, Set, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
from collections import defaultdict
from core.graph_manager import GraphManager, NodeType, RelationshipType
from core.logger import get_forensic_logger, new_session
from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.crtsh_provider import CrtShProvider
from providers.dns_provider import DNSProvider
from providers.shodan_provider import ShodanProvider
@@ -31,21 +31,27 @@ class ScanStatus:
class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
Manages multi-provider data gathering and graph construction with concurrent processing.
Now supports per-session configuration for multi-user isolation.
"""
def __init__(self):
"""Initialize scanner with all available providers and empty graph."""
def __init__(self, session_config=None):
"""Initialize scanner with session-specific configuration."""
print("Initializing Scanner instance...")
try:
# Use provided session config or create default
if session_config is None:
from core.session_config import create_session_config
session_config = create_session_config()
self.config = session_config
self.graph = GraphManager()
self.providers = []
self.status = ScanStatus.IDLE
self.current_target = None
self.current_depth = 0
self.max_depth = 2
self.stop_event = threading.Event() # Use a threading.Event for safer signaling
self.stop_event = threading.Event()
self.scan_thread = None
# Scanning progress tracking
@@ -54,11 +60,11 @@ class Scanner:
self.current_indicator = ""
# Concurrent processing configuration
self.max_workers = config.max_concurrent_requests
self.executor = None # Keep a reference to the executor
self.max_workers = self.config.max_concurrent_requests
self.executor = None
# Initialize providers
print("Calling _initialize_providers...")
# Initialize providers with session config
print("Calling _initialize_providers with session config...")
self._initialize_providers()
# Initialize logger
@@ -73,10 +79,10 @@ class Scanner:
raise
def _initialize_providers(self) -> None:
"""Initialize all available providers based on configuration."""
"""Initialize all available providers based on session configuration."""
self.providers = []
print("Initializing providers...")
print("Initializing providers with session config...")
# Always add free providers
free_providers = [
@@ -85,12 +91,15 @@ class Scanner:
]
for provider_name, provider_class in free_providers:
if config.is_provider_enabled(provider_name):
if self.config.is_provider_enabled(provider_name):
try:
provider = provider_class()
# Pass session config to provider
provider = provider_class(session_config=self.config)
if provider.is_available():
# Set the stop event for cancellation support
provider.set_stop_event(self.stop_event)
self.providers.append(provider)
print(f"{provider_name.title()} provider initialized successfully")
print(f"{provider_name.title()} provider initialized successfully for session")
else:
print(f"{provider_name.title()} provider is not available")
except Exception as e:
@@ -104,23 +113,41 @@ class Scanner:
]
for provider_name, provider_class in api_providers:
if config.is_provider_enabled(provider_name):
if self.config.is_provider_enabled(provider_name):
try:
provider = provider_class()
# Pass session config to provider
provider = provider_class(session_config=self.config)
if provider.is_available():
# Set the stop event for cancellation support
provider.set_stop_event(self.stop_event)
self.providers.append(provider)
print(f"{provider_name.title()} provider initialized successfully")
print(f"{provider_name.title()} provider initialized successfully for session")
else:
print(f"{provider_name.title()} provider is not available (API key required)")
except Exception as e:
print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
traceback.print_exc()
print(f"Initialized {len(self.providers)} providers")
print(f"Initialized {len(self.providers)} providers for session")
def update_session_config(self, new_config) -> None:
"""
Update session configuration and reinitialize providers.
Args:
new_config: New SessionConfig instance
"""
print("Updating session configuration...")
self.config = new_config
self.max_workers = self.config.max_concurrent_requests
self._initialize_providers()
print("Session configuration updated")
def start_scan(self, target_domain: str, max_depth: int = 2) -> bool:
"""
Start a new reconnaissance scan with concurrent processing.
Enhanced with better debugging and state validation.
Args:
target_domain: Initial domain to investigate
@@ -129,29 +156,36 @@ class Scanner:
Returns:
bool: True if scan started successfully
"""
print(f"Scanner.start_scan called with target='{target_domain}', depth={max_depth}")
print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
print(f"Scanner status: {self.status}")
print(f"Target domain: '{target_domain}', Max depth: {max_depth}")
print(f"Available providers: {len(self.providers) if hasattr(self, 'providers') else 0}")
try:
if self.status == ScanStatus.RUNNING:
print("Scan already running, rejecting new scan")
print(f"ERROR: Scan already running in scanner {id(self)}, rejecting new scan")
print(f"Current target: {self.current_target}")
print(f"Current depth: {self.current_depth}")
return False
# Check if we have any providers
if not self.providers:
print("No providers available, cannot start scan")
if not hasattr(self, 'providers') or not self.providers:
print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
return False
print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")
# Stop any existing scan thread
if self.scan_thread and self.scan_thread.is_alive():
print("Stopping existing scan thread...")
print(f"Stopping existing scan thread in scanner {id(self)}...")
self.stop_event.set()
self.scan_thread.join(timeout=5.0)
if self.scan_thread.is_alive():
print("WARNING: Could not stop existing thread")
print(f"WARNING: Could not stop existing thread in scanner {id(self)}")
return False
# Reset state
print("Resetting scanner state...")
print(f"Resetting scanner {id(self)} state...")
self.graph.clear()
self.current_target = target_domain.lower().strip()
self.max_depth = max_depth
@@ -162,11 +196,11 @@ class Scanner:
self.current_indicator = self.current_target
# Start new forensic session
print("Starting new forensic session...")
print(f"Starting new forensic session for scanner {id(self)}...")
self.logger = new_session()
# Start scan in separate thread
print("Starting scan thread...")
print(f"Starting scan thread for scanner {id(self)}...")
self.scan_thread = threading.Thread(
target=self._execute_scan_async,
args=(self.current_target, max_depth),
@@ -174,14 +208,15 @@ class Scanner:
)
self.scan_thread.start()
print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===")
return True
except Exception as e:
print(f"ERROR: Exception in start_scan: {e}")
print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
traceback.print_exc()
return False
def _execute_scan_async(self, target_domain: str, max_depth: int) -> None:
"""
Execute the reconnaissance scan asynchronously with concurrent provider queries.
@@ -210,7 +245,7 @@ class Scanner:
processed_domains = set()
all_discovered_ips = set()
print(f"Starting BFS exploration...")
print("Starting BFS exploration...")
for depth in range(max_depth + 1):
if self.stop_event.is_set():
@@ -269,7 +304,7 @@ class Scanner:
self.executor.shutdown(wait=False, cancel_futures=True)
stats = self.graph.get_statistics()
print(f"Final scan statistics:")
print("Final scan statistics:")
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
print(f" - Domains processed: {len(processed_domains)}")
@@ -330,87 +365,243 @@ class Scanner:
def _query_providers_for_domain(self, domain: str) -> Tuple[Set[str], Set[str]]:
"""
Query all enabled providers for information about a domain.
Query all enabled providers for information about a domain and collect comprehensive metadata.
Creates appropriate node types and relationships based on discovered data.
"""
print(f"Querying {len(self.providers)} providers for domain: {domain}")
discovered_domains = set()
discovered_ips = set()
relationships_by_type = defaultdict(list)
all_relationships = []
# Comprehensive metadata collection for this domain
domain_metadata = {
'dns_records': [],
'related_domains_san': [],
'shodan': {},
'virustotal': {},
'certificate_data': {},
'passive_dns': [],
}
if not self.providers or self.stop_event.is_set():
return discovered_domains, discovered_ips
# Query all providers concurrently
with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
provider_executor.submit(self._safe_provider_query_domain, provider, domain): provider
for provider in self.providers
}
for future in as_completed(future_to_provider):
if self.stop_event.is_set():
future.cancel()
continue
provider = future_to_provider[future]
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships")
# Process relationships and collect metadata
for rel in relationships:
relationships_by_type[rel[2]].append(rel)
source, target, rel_type, confidence, raw_data = rel
# Add provider info to the relationship
enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
all_relationships.append(enhanced_rel)
# Collect metadata based on provider and relationship type
self._collect_node_metadata(domain, provider.get_name(), rel_type, target, raw_data, domain_metadata)
except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for {domain}: {e}")
# Add the domain node with comprehensive metadata
self.graph.add_node(domain, NodeType.DOMAIN, metadata=domain_metadata)
# Group relationships by type for large entity handling
relationships_by_type = defaultdict(list)
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
relationships_by_type[rel_type].append((source, target, rel_type, confidence, raw_data, provider_name))
# Handle large entities (only for SAN certificates currently)
for rel_type, relationships in relationships_by_type.items():
if len(relationships) > self.config.large_entity_threshold and rel_type == RelationshipType.SAN_CERTIFICATE:
self._handle_large_entity(domain, relationships, rel_type, provider.get_name())
first_provider = relationships[0][5] if relationships else "multiple_providers"
self._handle_large_entity(domain, relationships, rel_type, first_provider)
# Remove these relationships from further processing
all_relationships = [rel for rel in all_relationships if not (rel[2] == rel_type and len(relationships_by_type[rel_type]) > self.config.large_entity_threshold)]
# Track DNS records to create (avoid duplicates)
dns_records_to_create = {}
# Process remaining relationships
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
if self.stop_event.is_set():
break
# Determine how to handle the target based on relationship type and content
if _is_valid_ip(target):
# Create IP node and relationship
self.graph.add_node(target, NodeType.IP)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")
# Add to recursion if it's a direct resolution
if rel_type in [RelationshipType.A_RECORD, RelationshipType.AAAA_RECORD]:
discovered_ips.add(target)
elif target.startswith('AS') and target[2:].isdigit():
# Create ASN node and relationship
self.graph.add_node(target, NodeType.ASN)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added ASN relationship: {source} -> {target} ({rel_type.relationship_name})")
elif _is_valid_domain(target):
# Create domain node and relationship
self.graph.add_node(target, NodeType.DOMAIN)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added domain relationship: {source} -> {target} ({rel_type.relationship_name})")
# Add to recursion for specific relationship types
recurse_types = [
RelationshipType.CNAME_RECORD,
RelationshipType.MX_RECORD,
RelationshipType.SAN_CERTIFICATE,
RelationshipType.NS_RECORD,
RelationshipType.PASSIVE_DNS
]
if rel_type in recurse_types:
discovered_domains.add(target)
else:
for source, target, rel_type, confidence, raw_data in relationships:
# Determine if the target should create a new node
create_node = rel_type in [
RelationshipType.A_RECORD,
RelationshipType.AAAA_RECORD,
RelationshipType.CNAME_RECORD,
RelationshipType.MX_RECORD,
RelationshipType.NS_RECORD,
RelationshipType.PTR_RECORD,
RelationshipType.SAN_CERTIFICATE
]
# Handle DNS record content (TXT, SPF, CAA, etc.)
dns_record_types = [
RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD,
RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD,
RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD,
RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD,
RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD
]
if rel_type in dns_record_types:
# Create normalized DNS record identifier
record_type = rel_type.relationship_name.upper().replace('_RECORD', '')
record_content = target.strip()
# Create a stable identifier for this DNS record (hashlib is deterministic across runs, unlike the per-process-salted built-in hash())
content_hash = hashlib.sha256(record_content.encode()).hexdigest()[:12]
dns_record_id = f"{record_type}:{content_hash}"
# Track this DNS record for creation (avoid duplicates)
if dns_record_id not in dns_records_to_create:
dns_records_to_create[dns_record_id] = {
'content': record_content,
'type': record_type,
'domains': set(),
'raw_data': raw_data,
'provider_name': provider_name,
'confidence': confidence
}
# Add this domain to the DNS record's domain list
dns_records_to_create[dns_record_id]['domains'].add(source)
print(f"DNS record tracked: {source} -> {record_type} (content length: {len(record_content)})")
else:
# For other non-infrastructure targets, log but don't create nodes
print(f"Non-infrastructure relationship stored as metadata: {source} - {rel_type.relationship_name}: {target[:100]}")
# Determine if the target should be subject to recursion
recurse = rel_type in [
RelationshipType.A_RECORD,
RelationshipType.AAAA_RECORD,
RelationshipType.CNAME_RECORD,
RelationshipType.MX_RECORD,
RelationshipType.SAN_CERTIFICATE
]
# Create DNS record nodes and their relationships
for dns_record_id, record_info in dns_records_to_create.items():
if self.stop_event.is_set():
break
record_metadata = {
'record_type': record_info['type'],
'content': record_info['content'],
'content_hash': dns_record_id.split(':')[1],
'associated_domains': list(record_info['domains']),
'source_data': record_info['raw_data']
}
# Create the DNS record node
self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata)
# Connect each domain to this DNS record
for domain_name in record_info['domains']:
if self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD,
record_info['confidence'], record_info['provider_name'],
record_info['raw_data']):
print(f"Added DNS record relationship: {domain_name} -> {dns_record_id}")
if create_node:
target_node_type = NodeType.IP if self._is_valid_ip(target) else NodeType.DOMAIN
self.graph.add_node(target, target_node_type)
if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data):
print(f"Added relationship: {source} -> {target} ({rel_type.relationship_name})")
else:
# For records that don't create nodes, we still want to log the relationship
self.logger.log_relationship_discovery(
source_node=source,
target_node=target,
relationship_type=rel_type.relationship_name,
confidence_score=confidence,
provider=provider.name,
raw_data=raw_data,
discovery_method=f"dns_{rel_type.name.lower()}_record"
)
if recurse:
if self._is_valid_ip(target):
discovered_ips.add(target)
elif self._is_valid_domain(target):
discovered_domains.add(target)
print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs")
print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs, {len(dns_records_to_create)} DNS records")
return discovered_domains, discovered_ips
def _collect_node_metadata(self, node_id: str, provider_name: str, rel_type: RelationshipType,
target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
"""
Collect and organize metadata for a node based on provider responses.
"""
if provider_name == 'dns':
record_type = raw_data.get('query_type', 'UNKNOWN')
value = raw_data.get('value', target)
# Store the full record content for every DNS record type
dns_entry = f"{record_type}: {value}"
if dns_entry not in metadata['dns_records']:
metadata['dns_records'].append(dns_entry)
elif provider_name == 'crtsh':
if rel_type == RelationshipType.SAN_CERTIFICATE:
# Handle certificate data storage on domain nodes
domain_certs = raw_data.get('domain_certificates', {})
# Store certificate information for this domain
if node_id in domain_certs:
cert_summary = domain_certs[node_id]
# Update domain metadata with certificate information
metadata['certificate_data'] = cert_summary
metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False)
# Add related domains from shared certificates
if target not in metadata.get('related_domains_san', []):
if 'related_domains_san' not in metadata:
metadata['related_domains_san'] = []
metadata['related_domains_san'].append(target)
# Store shared certificate details for forensic analysis
shared_certs = raw_data.get('shared_certificates', [])
if shared_certs and 'shared_certificate_details' not in metadata:
metadata['shared_certificate_details'] = shared_certs
elif provider_name == 'shodan':
# Merge Shodan data (avoid overwriting)
for key, value in raw_data.items():
if key not in metadata['shodan'] or not metadata['shodan'][key]:
metadata['shodan'][key] = value
elif provider_name == 'virustotal':
# Merge VirusTotal data
for key, value in raw_data.items():
if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
metadata['virustotal'][key] = value
# Add passive DNS entries
if rel_type == RelationshipType.PASSIVE_DNS:
passive_entry = f"Passive DNS: {target}"
if passive_entry not in metadata['passive_dns']:
metadata['passive_dns'].append(passive_entry)
def _handle_large_entity(self, source_domain: str, relationships: list, rel_type: RelationshipType, provider_name: str):
"""
Handles the creation of a large entity node when a threshold is exceeded.
@@ -422,12 +613,24 @@ class Scanner:
def _query_providers_for_ip(self, ip: str) -> None:
"""
Query all enabled providers for information about an IP address.
Query all enabled providers for information about an IP address and collect comprehensive metadata.
"""
print(f"Querying {len(self.providers)} providers for IP: {ip}")
if not self.providers or self.stop_event.is_set():
return
# Comprehensive metadata collection for this IP
ip_metadata = {
'dns_records': [],
'passive_dns': [],
'shodan': {},
'virustotal': {},
'asn_data': {},
'hostnames': [],
}
all_relationships = [] # Store relationships with provider info
with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
provider_executor.submit(self._safe_provider_query_ip, provider, ip): provider
@@ -441,20 +644,86 @@ class Scanner:
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for IP {ip}")
for source, target, rel_type, confidence, raw_data in relationships:
if self._is_valid_domain(target):
target_node_type = NodeType.DOMAIN
elif target.startswith('AS'):
target_node_type = NodeType.ASN
else:
target_node_type = NodeType.IP
self.graph.add_node(source, NodeType.IP)
self.graph.add_node(target, target_node_type)
if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")
# Add provider info to the relationship
enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
all_relationships.append(enhanced_rel)
# Collect metadata for the IP
self._collect_ip_metadata(ip, provider.get_name(), rel_type, target, raw_data, ip_metadata)
except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for IP {ip}: {e}")
# Update the IP node with comprehensive metadata
self.graph.add_node(ip, NodeType.IP, metadata=ip_metadata)
# Process relationships with correct provider attribution
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
# Determine target node type
if _is_valid_domain(target):
target_node_type = NodeType.DOMAIN
elif target.startswith('AS'):
target_node_type = NodeType.ASN
else:
target_node_type = NodeType.IP
# Create/update target node
self.graph.add_node(target, target_node_type)
# Add relationship with correct provider attribution
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name}) from {provider_name}")
def _collect_ip_metadata(self, ip: str, provider_name: str, rel_type: RelationshipType,
target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
"""
Collect and organize metadata for an IP node based on provider responses.
"""
if provider_name == 'dns':
if rel_type == RelationshipType.PTR_RECORD:
reverse_entry = f"PTR: {target}"
if reverse_entry not in metadata['dns_records']:
metadata['dns_records'].append(reverse_entry)
if target not in metadata['hostnames']:
metadata['hostnames'].append(target)
elif provider_name == 'shodan':
# Merge Shodan data
for key, value in raw_data.items():
if key not in metadata['shodan'] or not metadata['shodan'][key]:
metadata['shodan'][key] = value
# Collect hostname information
if 'hostname' in raw_data and raw_data['hostname'] not in metadata['hostnames']:
metadata['hostnames'].append(raw_data['hostname'])
if 'hostnames' in raw_data:
for hostname in raw_data['hostnames']:
if hostname not in metadata['hostnames']:
metadata['hostnames'].append(hostname)
elif provider_name == 'virustotal':
# Merge VirusTotal data
for key, value in raw_data.items():
if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
metadata['virustotal'][key] = value
# Add passive DNS entries
if rel_type == RelationshipType.PASSIVE_DNS:
passive_entry = f"Passive DNS: {target}"
if passive_entry not in metadata['passive_dns']:
metadata['passive_dns'].append(passive_entry)
# Handle ASN relationships
if rel_type == RelationshipType.ASN_MEMBERSHIP:
metadata['asn_data'] = {
'asn': target,
'description': raw_data.get('org', ''),
'isp': raw_data.get('isp', ''),
'country': raw_data.get('country', '')
}
def _safe_provider_query_domain(self, provider, domain: str):
"""Safely query provider for domain with error handling."""
@@ -478,12 +747,33 @@ class Scanner:
def stop_scan(self) -> bool:
"""
Request scan termination.
Request immediate scan termination with aggressive cancellation.
"""
try:
if self.status == ScanStatus.RUNNING:
print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
# Signal all threads to stop
self.stop_event.set()
print("Scan stop requested")
# Close HTTP sessions in all providers to terminate ongoing requests
for provider in self.providers:
try:
if hasattr(provider, 'session'):
provider.session.close()
print(f"Closed HTTP session for provider: {provider.get_name()}")
except Exception as e:
print(f"Error closing session for {provider.get_name()}: {e}")
# Shutdown executor immediately with cancel_futures=True
if self.executor:
print("Shutting down executor with immediate cancellation...")
self.executor.shutdown(wait=False, cancel_futures=True)
# Give threads a moment to respond to cancellation, then force status change
threading.Timer(2.0, self._force_stop_completion).start()
print("Immediate termination requested - ongoing requests will be cancelled")
return True
print("No active scan to stop")
return False
@@ -492,6 +782,13 @@ class Scanner:
traceback.print_exc()
return False
def _force_stop_completion(self):
"""Force completion of stop operation after timeout."""
if self.status == ScanStatus.RUNNING:
print("Forcing scan termination after timeout")
self.status = ScanStatus.STOPPED
self.logger.log_scan_complete()
def get_scan_status(self) -> Dict[str, Any]:
"""
Get current scan status and progress.
@@ -561,16 +858,6 @@ class Scanner:
}
return export_data
def remove_provider(self, provider_name: str) -> bool:
"""
Remove a provider from the scanner.
"""
for i, provider in enumerate(self.providers):
if provider.get_name() == provider_name:
self.providers.pop(i)
return True
return False
def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
"""
Get statistics for all providers.
@@ -578,53 +865,4 @@ class Scanner:
stats = {}
for provider in self.providers:
stats[provider.get_name()] = provider.get_statistics()
return stats
def _is_valid_domain(self, domain: str) -> bool:
"""
Basic domain validation.
"""
if not domain or len(domain) > 253:
return False
parts = domain.split('.')
if len(parts) < 2:
return False
for part in parts:
if not part or len(part) > 63:
return False
if not part.replace('-', '').replace('_', '').isalnum():
return False
return True
def _is_valid_ip(self, ip: str) -> bool:
"""
Basic IP address validation.
"""
try:
parts = ip.split('.')
if len(parts) != 4:
return False
for part in parts:
num = int(part)
if not 0 <= num <= 255:
return False
return True
except (ValueError, AttributeError):
return False
class ScannerProxy:
def __init__(self):
self._scanner = None
print("ScannerProxy initialized")
def __getattr__(self, name):
if self._scanner is None:
print("Creating new Scanner instance...")
self._scanner = Scanner()
print("Scanner instance created")
return getattr(self._scanner, name)
# Global scanner instance
scanner = ScannerProxy()
return stats
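A sketch of the session-scoped scanner lifecycle defined above (the API key value and target are placeholders):

from core.scanner import Scanner
from core.session_config import create_session_config

scanner = Scanner()                                  # builds a default SessionConfig internally
cfg = create_session_config()
cfg.set_api_key('shodan', 'YOUR-SHODAN-KEY')         # placeholder key
scanner.update_session_config(cfg)                   # swaps config and re-initializes providers
scanner.start_scan("example.com", max_depth=2)       # illustrative target
print(scanner.get_scan_status())
scanner.stop_scan()                                  # triggers the aggressive cancellation path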

core/session_config.py (new file, 126 lines)
View File

@@ -0,0 +1,126 @@
"""
Per-session configuration management for DNSRecon.
Provides isolated configuration instances for each user session.
"""
import os
from typing import Dict, Optional
class SessionConfig:
"""
Session-specific configuration that inherits from global config
but maintains isolated API keys and provider settings.
"""
def __init__(self):
"""Initialize session config with global defaults."""
# Copy all attributes from global config
self.api_keys: Dict[str, Optional[str]] = {
'shodan': None,
'virustotal': None
}
# Default settings (copied from global config)
self.default_recursion_depth = 2
self.default_timeout = 30
self.max_concurrent_requests = 5
self.large_entity_threshold = 100
# Rate limiting settings (per session)
self.rate_limits = {
'crtsh': 60,
'virustotal': 4,
'shodan': 60,
'dns': 100
}
# Provider settings (per session)
self.enabled_providers = {
'crtsh': True,
'dns': True,
'virustotal': False,
'shodan': False
}
# Logging configuration
self.log_level = 'INFO'
self.log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
# Flask configuration (shared)
self.flask_host = '127.0.0.1'
self.flask_port = 5000
self.flask_debug = True
def set_api_key(self, provider: str, api_key: str) -> bool:
"""
Set API key for a provider in this session.
Args:
provider: Provider name (shodan, virustotal)
api_key: API key string
Returns:
bool: True if key was set successfully
"""
if provider in self.api_keys:
self.api_keys[provider] = api_key
self.enabled_providers[provider] = bool(api_key)
return True
return False
def get_api_key(self, provider: str) -> Optional[str]:
"""
Get API key for a provider in this session.
Args:
provider: Provider name
Returns:
API key or None if not set
"""
return self.api_keys.get(provider)
def is_provider_enabled(self, provider: str) -> bool:
"""
Check if a provider is enabled in this session.
Args:
provider: Provider name
Returns:
bool: True if provider is enabled
"""
return self.enabled_providers.get(provider, False)
def get_rate_limit(self, provider: str) -> int:
"""
Get rate limit for a provider in this session.
Args:
provider: Provider name
Returns:
Rate limit in requests per minute
"""
return self.rate_limits.get(provider, 60)
def load_from_env(self):
"""Load configuration from environment variables (only if not already set)."""
if os.getenv('VIRUSTOTAL_API_KEY') and not self.api_keys['virustotal']:
self.set_api_key('virustotal', os.getenv('VIRUSTOTAL_API_KEY'))
if os.getenv('SHODAN_API_KEY') and not self.api_keys['shodan']:
self.set_api_key('shodan', os.getenv('SHODAN_API_KEY'))
# Override recursion depth from the environment; timeout and concurrency keep their defaults
self.default_recursion_depth = int(os.getenv('DEFAULT_RECURSION_DEPTH', '2'))
self.default_timeout = 30
self.max_concurrent_requests = 5
def create_session_config() -> SessionConfig:
"""Create a new session configuration instance."""
session_config = SessionConfig()
session_config.load_from_env()
return session_config
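A brief usage sketch of the session configuration (the key is a placeholder):

from core.session_config import create_session_config

cfg = create_session_config()                        # also pulls *_API_KEY values from the environment
cfg.set_api_key('virustotal', 'YOUR-VT-KEY')         # setting a key enables the provider
assert cfg.is_provider_enabled('virustotal')
print(cfg.get_rate_limit('virustotal'))              # 4 requests/minute per the defaults above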

core/session_manager.py (new file, 281 lines)
View File

@@ -0,0 +1,281 @@
"""
Session manager for DNSRecon multi-user support.
Manages individual scanner instances per user session with automatic cleanup.
"""
import threading
import time
import uuid
from typing import Dict, Optional, Any
from datetime import datetime, timezone
from core.scanner import Scanner
class SessionManager:
"""
Manages multiple scanner instances for concurrent user sessions.
Provides session isolation and automatic cleanup of inactive sessions.
"""
def __init__(self, session_timeout_minutes: int = 60):
"""
Initialize session manager.
Args:
session_timeout_minutes: Minutes of inactivity before session cleanup
"""
self.sessions: Dict[str, Dict[str, Any]] = {}
self.session_timeout = session_timeout_minutes * 60 # Convert to seconds
self.lock = threading.Lock()
# Start cleanup thread
self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
self.cleanup_thread.start()
print(f"SessionManager initialized with {session_timeout_minutes}min timeout")
def create_session(self) -> str:
"""
Create a new user session with dedicated scanner instance and configuration.
Enhanced with better debugging and race condition protection.
Returns:
Unique session ID
"""
session_id = str(uuid.uuid4())
print(f"=== CREATING SESSION {session_id} ===")
try:
# Create session-specific configuration
from core.session_config import create_session_config
session_config = create_session_config()
print(f"Created session config for {session_id}")
# Create scanner with session config
from core.scanner import Scanner
scanner_instance = Scanner(session_config=session_config)
print(f"Created scanner instance {id(scanner_instance)} for session {session_id}")
print(f"Initial scanner status: {scanner_instance.status}")
with self.lock:
self.sessions[session_id] = {
'scanner': scanner_instance,
'config': session_config,
'created_at': time.time(),
'last_activity': time.time(),
'user_agent': '',
'status': 'active'
}
print(f"Session {session_id} stored in session manager")
print(f"Total active sessions: {len([s for s in self.sessions.values() if s['status'] == 'active'])}")
print(f"=== SESSION {session_id} CREATED SUCCESSFULLY ===")
return session_id
except Exception as e:
print(f"ERROR: Failed to create session {session_id}: {e}")
raise
def get_session(self, session_id: str) -> Optional[Scanner]:
"""
Get scanner instance for a session with enhanced debugging.
Args:
session_id: Session identifier
Returns:
Scanner instance or None if session doesn't exist
"""
if not session_id:
print("get_session called with empty session_id")
return None
with self.lock:
if session_id not in self.sessions:
print(f"Session {session_id} not found in session manager")
print(f"Available sessions: {list(self.sessions.keys())}")
return None
session_data = self.sessions[session_id]
# Check if session is still active
if session_data['status'] != 'active':
print(f"Session {session_id} is not active (status: {session_data['status']})")
return None
# Update last activity
session_data['last_activity'] = time.time()
scanner = session_data['scanner']
print(f"Retrieved scanner {id(scanner)} for session {session_id}")
print(f"Scanner status: {scanner.status}")
return scanner
def get_or_create_session(self, session_id: Optional[str] = None) -> tuple[str, Scanner]:
"""
Get existing session or create new one.
Args:
session_id: Optional existing session ID
Returns:
Tuple of (session_id, scanner_instance)
"""
if session_id:
    scanner = self.get_session(session_id)
    if scanner:
        return session_id, scanner
new_session_id = self.create_session()
return new_session_id, self.get_session(new_session_id)
def terminate_session(self, session_id: str) -> bool:
"""
Terminate a specific session and cleanup resources.
Args:
session_id: Session to terminate
Returns:
True if session was terminated successfully
"""
with self.lock:
if session_id not in self.sessions:
return False
session_data = self.sessions[session_id]
scanner = session_data['scanner']
# Stop any running scan
try:
if scanner.status == 'running':
scanner.stop_scan()
print(f"Stopped scan for session: {session_id}")
except Exception as e:
print(f"Error stopping scan for session {session_id}: {e}")
# Mark as terminated
session_data['status'] = 'terminated'
session_data['terminated_at'] = time.time()
# Remove from active sessions after a brief delay to allow cleanup
threading.Timer(5.0, lambda: self._remove_session(session_id)).start()
print(f"Terminated session: {session_id}")
return True
def _remove_session(self, session_id: str) -> None:
"""Remove session from memory."""
with self.lock:
if session_id in self.sessions:
del self.sessions[session_id]
print(f"Removed session from memory: {session_id}")
def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
Get session information without updating activity.
Args:
session_id: Session identifier
Returns:
Session information dictionary or None
"""
with self.lock:
if session_id not in self.sessions:
return None
session_data = self.sessions[session_id]
scanner = session_data['scanner']
return {
'session_id': session_id,
'created_at': datetime.fromtimestamp(session_data['created_at'], timezone.utc).isoformat(),
'last_activity': datetime.fromtimestamp(session_data['last_activity'], timezone.utc).isoformat(),
'status': session_data['status'],
'scan_status': scanner.status,
'current_target': scanner.current_target,
'uptime_seconds': time.time() - session_data['created_at']
}
def list_active_sessions(self) -> Dict[str, Dict[str, Any]]:
"""
List all active sessions with enhanced debugging info.
Returns:
Dictionary of session information
"""
active_sessions = {}
with self.lock:
for session_id, session_data in self.sessions.items():
if session_data['status'] == 'active':
scanner = session_data['scanner']
active_sessions[session_id] = {
'session_id': session_id,
'created_at': datetime.fromtimestamp(session_data['created_at'], timezone.utc).isoformat(),
'last_activity': datetime.fromtimestamp(session_data['last_activity'], timezone.utc).isoformat(),
'status': session_data['status'],
'scan_status': scanner.status,
'current_target': scanner.current_target,
'uptime_seconds': time.time() - session_data['created_at'],
'scanner_object_id': id(scanner)
}
return active_sessions
def _cleanup_loop(self) -> None:
"""Background thread to cleanup inactive sessions."""
while True:
try:
current_time = time.time()
sessions_to_cleanup = []
with self.lock:
for session_id, session_data in self.sessions.items():
if session_data['status'] != 'active':
continue
inactive_time = current_time - session_data['last_activity']
if inactive_time > self.session_timeout:
sessions_to_cleanup.append(session_id)
# Cleanup outside of lock to avoid deadlock
for session_id in sessions_to_cleanup:
print(f"Cleaning up inactive session: {session_id}")
self.terminate_session(session_id)
# Sleep for 5 minutes between cleanup cycles
time.sleep(300)
except Exception as e:
print(f"Error in session cleanup loop: {e}")
time.sleep(60) # Sleep for 1 minute on error
def get_statistics(self) -> Dict[str, Any]:
"""
Get session manager statistics.
Returns:
Statistics dictionary
"""
with self.lock:
active_count = sum(1 for s in self.sessions.values() if s['status'] == 'active')
running_scans = sum(1 for s in self.sessions.values()
if s['status'] == 'active' and s['scanner'].status == 'running')
return {
'total_sessions': len(self.sessions),
'active_sessions': active_count,
'running_scans': running_scans,
'session_timeout_minutes': self.session_timeout / 60
}
# Global session manager instance
session_manager = SessionManager(session_timeout_minutes=60)
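A minimal sketch of driving the session manager directly (session IDs are generated UUIDs):

from core.session_manager import session_manager

sid, scanner = session_manager.get_or_create_session()
print(session_manager.get_session_info(sid))         # created_at, scan_status, uptime, ...
print(session_manager.get_statistics())              # session totals and running-scan count
session_manager.terminate_session(sid)               # stops any running scan, then schedules removal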