# dnsrecon/core/scanner.py
"""
Main scanning orchestrator for DNSRecon.
Coordinates data gathering from multiple providers and builds the infrastructure graph.
"""
import hashlib
import threading
import traceback
from typing import List, Set, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
from collections import defaultdict
from core.graph_manager import GraphManager, NodeType, RelationshipType
from core.logger import get_forensic_logger, new_session
from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.crtsh_provider import CrtShProvider
from providers.dns_provider import DNSProvider
from providers.shodan_provider import ShodanProvider
from providers.virustotal_provider import VirusTotalProvider
from config import config
class ScanStatus:
"""Enumeration of scan statuses."""
IDLE = "idle"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
STOPPED = "stopped"
class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
    Supports per-session configuration for multi-user isolation.
"""
def __init__(self, session_config=None):
"""Initialize scanner with session-specific configuration."""
print("Initializing Scanner instance...")
try:
# Use provided session config or create default
if session_config is None:
from core.session_config import create_session_config
session_config = create_session_config()
self.config = session_config
self.graph = GraphManager()
self.providers = []
self.status = ScanStatus.IDLE
self.current_target = None
self.current_depth = 0
self.max_depth = 2
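        # Note: this shared Event is handed to every provider via set_stop_event()
        # and checked throughout the BFS loops below, so a single stop_scan() call
        # can cancel in-flight provider requests and pending domain/IP processing.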
self.stop_event = threading.Event()
self.scan_thread = None
# Scanning progress tracking
self.total_indicators_found = 0
self.indicators_processed = 0
self.current_indicator = ""
# Concurrent processing configuration
self.max_workers = self.config.max_concurrent_requests
self.executor = None
# Initialize providers with session config
print("Calling _initialize_providers with session config...")
self._initialize_providers()
# Initialize logger
print("Initializing forensic logger...")
self.logger = get_forensic_logger()
print("Scanner initialization complete")
except Exception as e:
print(f"ERROR: Scanner initialization failed: {e}")
traceback.print_exc()
raise
def _initialize_providers(self) -> None:
"""Initialize all available providers based on session configuration."""
self.providers = []
print("Initializing providers with session config...")
        # All providers to attempt, free ones first; the flag marks those that
        # require an API key so the "not available" message can say why.
        provider_classes = [
            ('crtsh', CrtShProvider, False),
            ('dns', DNSProvider, False),
            ('shodan', ShodanProvider, True),
            ('virustotal', VirusTotalProvider, True),
        ]
        for provider_name, provider_class, requires_api_key in provider_classes:
            if not self.config.is_provider_enabled(provider_name):
                continue
            try:
                # Pass session config to provider
                provider = provider_class(session_config=self.config)
                if provider.is_available():
                    # Set the stop event for cancellation support
                    provider.set_stop_event(self.stop_event)
                    self.providers.append(provider)
                    print(f"✓ {provider_name.title()} provider initialized successfully for session")
                else:
                    reason = " (API key required)" if requires_api_key else ""
                    print(f"✗ {provider_name.title()} provider is not available{reason}")
            except Exception as e:
                print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
                traceback.print_exc()
print(f"Initialized {len(self.providers)} providers for session")
def update_session_config(self, new_config) -> None:
"""
Update session configuration and reinitialize providers.
Args:
new_config: New SessionConfig instance
"""
print("Updating session configuration...")
self.config = new_config
self.max_workers = self.config.max_concurrent_requests
self._initialize_providers()
print("Session configuration updated")
def start_scan(self, target_domain: str, max_depth: int = 2) -> bool:
"""
        Start a new reconnaissance scan with concurrent processing.
Args:
target_domain: Initial domain to investigate
max_depth: Maximum recursion depth
Returns:
bool: True if scan started successfully
"""
print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
print(f"Scanner status: {self.status}")
print(f"Target domain: '{target_domain}', Max depth: {max_depth}")
print(f"Available providers: {len(self.providers) if hasattr(self, 'providers') else 0}")
try:
if self.status == ScanStatus.RUNNING:
print(f"ERROR: Scan already running in scanner {id(self)}, rejecting new scan")
print(f"Current target: {self.current_target}")
print(f"Current depth: {self.current_depth}")
return False
# Check if we have any providers
if not hasattr(self, 'providers') or not self.providers:
print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
return False
print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")
# Stop any existing scan thread
if self.scan_thread and self.scan_thread.is_alive():
print(f"Stopping existing scan thread in scanner {id(self)}...")
self.stop_event.set()
self.scan_thread.join(timeout=5.0)
if self.scan_thread.is_alive():
print(f"WARNING: Could not stop existing thread in scanner {id(self)}")
return False
# Reset state
print(f"Resetting scanner {id(self)} state...")
self.graph.clear()
self.current_target = target_domain.lower().strip()
self.max_depth = max_depth
self.current_depth = 0
self.stop_event.clear()
self.total_indicators_found = 0
self.indicators_processed = 0
self.current_indicator = self.current_target
# Start new forensic session
print(f"Starting new forensic session for scanner {id(self)}...")
self.logger = new_session()
# Start scan in separate thread
print(f"Starting scan thread for scanner {id(self)}...")
self.scan_thread = threading.Thread(
target=self._execute_scan_async,
args=(self.current_target, max_depth),
daemon=True
)
self.scan_thread.start()
print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===")
return True
except Exception as e:
print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
traceback.print_exc()
return False
    def _execute_scan_async(self, target_domain: str, max_depth: int) -> None:
        """
        Execute the reconnaissance scan in a background worker thread with
        concurrent provider queries. This must remain a regular (synchronous)
        method: it is used as a threading.Thread target, and an `async def`
        here would return an un-awaited coroutine instead of running the scan.

        Args:
            target_domain: Target domain to investigate
            max_depth: Maximum recursion depth
        """
        print(f"_execute_scan_async started for {target_domain} with depth {max_depth}")
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        # Initialized before the try block so the statistics printed in the
        # finally block cannot raise NameError if the scan fails early.
        processed_domains = set()
        all_discovered_ips = set()
        try:
print("Setting status to RUNNING")
self.status = ScanStatus.RUNNING
# Log scan start
enabled_providers = [provider.get_name() for provider in self.providers]
self.logger.log_scan_start(target_domain, max_depth, enabled_providers)
print(f"Logged scan start with providers: {enabled_providers}")
# Initialize with target domain
print(f"Adding target domain '{target_domain}' as initial node")
self.graph.add_node(target_domain, NodeType.DOMAIN)
# BFS-style exploration
current_level_domains = {target_domain}
print("Starting BFS exploration...")
for depth in range(max_depth + 1):
if self.stop_event.is_set():
print(f"Stop requested at depth {depth}")
break
self.current_depth = depth
print(f"Processing depth level {depth} with {len(current_level_domains)} domains")
if not current_level_domains:
print("No domains to process at this level")
break
self.total_indicators_found += len(current_level_domains)
next_level_domains = set()
domain_results = self._process_domains_concurrent(current_level_domains, processed_domains)
for domain, discovered_domains, discovered_ips in domain_results:
if self.stop_event.is_set():
break
processed_domains.add(domain)
all_discovered_ips.update(discovered_ips)
if depth < max_depth:
for discovered_domain in discovered_domains:
if discovered_domain not in processed_domains:
next_level_domains.add(discovered_domain)
print(f"Adding {discovered_domain} to next level")
if self.stop_event.is_set():
break
if all_discovered_ips:
print(f"Processing {len(all_discovered_ips)} discovered IP addresses")
self._process_ips_concurrent(all_discovered_ips)
current_level_domains = next_level_domains
print(f"Completed depth {depth}, {len(next_level_domains)} domains for next level")
except Exception as e:
print(f"ERROR: Scan execution failed with error: {e}")
traceback.print_exc()
self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}")
        finally:
            # Preserve a FAILED status set in the except block rather than
            # overwriting it with STOPPED/COMPLETED.
            if self.status != ScanStatus.FAILED:
                if self.stop_event.is_set():
                    self.status = ScanStatus.STOPPED
                    print("Scan completed with STOPPED status")
                else:
                    self.status = ScanStatus.COMPLETED
                    print("Scan completed with COMPLETED status")
self.logger.log_scan_complete()
self.executor.shutdown(wait=False, cancel_futures=True)
stats = self.graph.get_statistics()
print("Final scan statistics:")
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
print(f" - Domains processed: {len(processed_domains)}")
print(f" - IPs discovered: {len(all_discovered_ips)}")
def _process_domains_concurrent(self, domains: Set[str], processed_domains: Set[str]) -> List[Tuple[str, Set[str], Set[str]]]:
"""
Process multiple domains concurrently using thread pool.
"""
results = []
domains_to_process = domains - processed_domains
if not domains_to_process:
return results
print(f"Processing {len(domains_to_process)} domains concurrently with {self.max_workers} workers")
future_to_domain = {
self.executor.submit(self._query_providers_for_domain, domain): domain
for domain in domains_to_process
}
for future in as_completed(future_to_domain):
if self.stop_event.is_set():
future.cancel()
continue
domain = future_to_domain[future]
try:
discovered_domains, discovered_ips = future.result()
results.append((domain, discovered_domains, discovered_ips))
self.indicators_processed += 1
print(f"Completed processing domain: {domain} ({len(discovered_domains)} domains, {len(discovered_ips)} IPs)")
except (Exception, CancelledError) as e:
print(f"Error processing domain {domain}: {e}")
return results
def _process_ips_concurrent(self, ips: Set[str]) -> None:
"""
Process multiple IP addresses concurrently.
"""
if not ips or self.stop_event.is_set():
return
print(f"Processing {len(ips)} IP addresses concurrently")
future_to_ip = {
self.executor.submit(self._query_providers_for_ip, ip): ip
for ip in ips
}
for future in as_completed(future_to_ip):
if self.stop_event.is_set():
future.cancel()
continue
ip = future_to_ip[future]
try:
future.result() # Just wait for completion
print(f"Completed processing IP: {ip}")
except (Exception, CancelledError) as e:
print(f"Error processing IP {ip}: {e}")
def _query_providers_for_domain(self, domain: str) -> Tuple[Set[str], Set[str]]:
"""
Query all enabled providers for information about a domain and collect comprehensive metadata.
Creates appropriate node types and relationships based on discovered data.
"""
print(f"Querying {len(self.providers)} providers for domain: {domain}")
discovered_domains = set()
discovered_ips = set()
all_relationships = []
# Comprehensive metadata collection for this domain
domain_metadata = {
'dns_records': [],
'related_domains_san': [],
'shodan': {},
'virustotal': {},
'certificate_data': {},
'passive_dns': [],
}
if not self.providers or self.stop_event.is_set():
return discovered_domains, discovered_ips
# Query all providers concurrently
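        # Note: this is a short-lived, per-domain pool nested inside the
        # scanner-wide executor, so every provider is queried in parallel for
        # this domain while other domains are themselves processed in parallel
        # by _process_domains_concurrent.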
with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
provider_executor.submit(self._safe_provider_query_domain, provider, domain): provider
for provider in self.providers
}
for future in as_completed(future_to_provider):
if self.stop_event.is_set():
future.cancel()
continue
provider = future_to_provider[future]
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships")
# Process relationships and collect metadata
for rel in relationships:
source, target, rel_type, confidence, raw_data = rel
# Add provider info to the relationship
enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
all_relationships.append(enhanced_rel)
# Collect metadata based on provider and relationship type
self._collect_node_metadata(domain, provider.get_name(), rel_type, target, raw_data, domain_metadata)
except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for {domain}: {e}")
# Add the domain node with comprehensive metadata
self.graph.add_node(domain, NodeType.DOMAIN, metadata=domain_metadata)
# Group relationships by type for large entity handling
relationships_by_type = defaultdict(list)
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
relationships_by_type[rel_type].append((source, target, rel_type, confidence, raw_data, provider_name))
# Handle large entities (only for SAN certificates currently)
for rel_type, relationships in relationships_by_type.items():
if len(relationships) > config.large_entity_threshold and rel_type == RelationshipType.SAN_CERTIFICATE:
first_provider = relationships[0][5] if relationships else "multiple_providers"
self._handle_large_entity(domain, relationships, rel_type, first_provider)
                # Remove these relationships from further processing; they are
                # now represented by the aggregated large-entity node instead.
                all_relationships = [rel for rel in all_relationships if rel[2] != rel_type]
# Track DNS records to create (avoid duplicates)
dns_records_to_create = {}
# Process remaining relationships
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
if self.stop_event.is_set():
break
# Determine how to handle the target based on relationship type and content
if _is_valid_ip(target):
# Create IP node and relationship
self.graph.add_node(target, NodeType.IP)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")
# Add to recursion if it's a direct resolution
if rel_type in [RelationshipType.A_RECORD, RelationshipType.AAAA_RECORD]:
discovered_ips.add(target)
elif target.startswith('AS') and target[2:].isdigit():
# Create ASN node and relationship
self.graph.add_node(target, NodeType.ASN)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added ASN relationship: {source} -> {target} ({rel_type.relationship_name})")
elif _is_valid_domain(target):
# Create domain node and relationship
self.graph.add_node(target, NodeType.DOMAIN)
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added domain relationship: {source} -> {target} ({rel_type.relationship_name})")
# Add to recursion for specific relationship types
recurse_types = [
RelationshipType.CNAME_RECORD,
RelationshipType.MX_RECORD,
RelationshipType.SAN_CERTIFICATE,
RelationshipType.NS_RECORD,
RelationshipType.PASSIVE_DNS
]
if rel_type in recurse_types:
discovered_domains.add(target)
else:
# Handle DNS record content (TXT, SPF, CAA, etc.)
dns_record_types = [
RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD,
RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD,
RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD,
RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD,
RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD
]
if rel_type in dns_record_types:
                    # Create normalized DNS record identifier
                    record_type = rel_type.relationship_name.upper().replace('_RECORD', '')
                    record_content = target.strip()
                    # Create a deterministic identifier for this DNS record. hashlib is
                    # used instead of the built-in hash(), which is salted per process
                    # and would produce different record IDs on every run.
                    content_hash = int(hashlib.sha256(record_content.encode()).hexdigest()[:8], 16)
                    dns_record_id = f"{record_type}:{content_hash}"
# Track this DNS record for creation (avoid duplicates)
if dns_record_id not in dns_records_to_create:
dns_records_to_create[dns_record_id] = {
'content': record_content,
'type': record_type,
'domains': set(),
'raw_data': raw_data,
'provider_name': provider_name,
'confidence': confidence
}
# Add this domain to the DNS record's domain list
dns_records_to_create[dns_record_id]['domains'].add(source)
print(f"DNS record tracked: {source} -> {record_type} (content length: {len(record_content)})")
else:
# For other non-infrastructure targets, log but don't create nodes
print(f"Non-infrastructure relationship stored as metadata: {source} - {rel_type.relationship_name}: {target[:100]}")
# Create DNS record nodes and their relationships
for dns_record_id, record_info in dns_records_to_create.items():
if self.stop_event.is_set():
break
record_metadata = {
'record_type': record_info['type'],
'content': record_info['content'],
'content_hash': dns_record_id.split(':')[1],
'associated_domains': list(record_info['domains']),
'source_data': record_info['raw_data']
}
# Create the DNS record node
self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata)
# Connect each domain to this DNS record
for domain_name in record_info['domains']:
if self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD,
record_info['confidence'], record_info['provider_name'],
record_info['raw_data']):
print(f"Added DNS record relationship: {domain_name} -> {dns_record_id}")
print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs, {len(dns_records_to_create)} DNS records")
return discovered_domains, discovered_ips
def _collect_node_metadata(self, node_id: str, provider_name: str, rel_type: RelationshipType,
target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
"""
Collect and organize metadata for a node based on provider responses.
"""
if provider_name == 'dns':
record_type = raw_data.get('query_type', 'UNKNOWN')
value = raw_data.get('value', target)
            # All record types, including content records (TXT, SPF, CAA),
            # are stored uniformly as "TYPE: value" with their full content.
            dns_entry = f"{record_type}: {value}"
if dns_entry not in metadata['dns_records']:
metadata['dns_records'].append(dns_entry)
elif provider_name == 'crtsh':
if rel_type == RelationshipType.SAN_CERTIFICATE:
# Handle certificate data storage on domain nodes
domain_certs = raw_data.get('domain_certificates', {})
# Store certificate information for this domain
if node_id in domain_certs:
cert_summary = domain_certs[node_id]
# Update domain metadata with certificate information
metadata['certificate_data'] = cert_summary
metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False)
# Add related domains from shared certificates
if target not in metadata.get('related_domains_san', []):
if 'related_domains_san' not in metadata:
metadata['related_domains_san'] = []
metadata['related_domains_san'].append(target)
# Store shared certificate details for forensic analysis
shared_certs = raw_data.get('shared_certificates', [])
if shared_certs and 'shared_certificate_details' not in metadata:
metadata['shared_certificate_details'] = shared_certs
elif provider_name == 'shodan':
# Merge Shodan data (avoid overwriting)
for key, value in raw_data.items():
if key not in metadata['shodan'] or not metadata['shodan'][key]:
metadata['shodan'][key] = value
elif provider_name == 'virustotal':
# Merge VirusTotal data
for key, value in raw_data.items():
if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
metadata['virustotal'][key] = value
# Add passive DNS entries
if rel_type == RelationshipType.PASSIVE_DNS:
passive_entry = f"Passive DNS: {target}"
if passive_entry not in metadata['passive_dns']:
metadata['passive_dns'].append(passive_entry)
def _handle_large_entity(self, source_domain: str, relationships: list, rel_type: RelationshipType, provider_name: str):
"""
Handles the creation of a large entity node when a threshold is exceeded.
"""
print(f"Large number of {rel_type.name} relationships for {source_domain}. Creating a large entity node.")
entity_name = f"Large collection of {rel_type.name} for {source_domain}"
self.graph.add_node(entity_name, NodeType.LARGE_ENTITY, metadata={"count": len(relationships)})
self.graph.add_edge(source_domain, entity_name, rel_type, 0.9, provider_name, {"info": "Aggregated node"})
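        # Design note: collapsing a very large SAN fan-out into one aggregate node
        # keeps the graph readable and keeps hundreds of certificate-sharing
        # domains from flooding the BFS frontier.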
def _query_providers_for_ip(self, ip: str) -> None:
"""
Query all enabled providers for information about an IP address and collect comprehensive metadata.
"""
print(f"Querying {len(self.providers)} providers for IP: {ip}")
if not self.providers or self.stop_event.is_set():
return
# Comprehensive metadata collection for this IP
ip_metadata = {
'dns_records': [],
'passive_dns': [],
'shodan': {},
'virustotal': {},
'asn_data': {},
'hostnames': [],
}
all_relationships = [] # Store relationships with provider info
with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
provider_executor.submit(self._safe_provider_query_ip, provider, ip): provider
for provider in self.providers
}
for future in as_completed(future_to_provider):
if self.stop_event.is_set():
future.cancel()
continue
provider = future_to_provider[future]
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for IP {ip}")
for source, target, rel_type, confidence, raw_data in relationships:
# Add provider info to the relationship
enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
all_relationships.append(enhanced_rel)
# Collect metadata for the IP
self._collect_ip_metadata(ip, provider.get_name(), rel_type, target, raw_data, ip_metadata)
except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for IP {ip}: {e}")
# Update the IP node with comprehensive metadata
self.graph.add_node(ip, NodeType.IP, metadata=ip_metadata)
# Process relationships with correct provider attribution
for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
# Determine target node type
if _is_valid_domain(target):
target_node_type = NodeType.DOMAIN
elif target.startswith('AS'):
target_node_type = NodeType.ASN
else:
target_node_type = NodeType.IP
# Create/update target node
self.graph.add_node(target, target_node_type)
# Add relationship with correct provider attribution
if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name}) from {provider_name}")
def _collect_ip_metadata(self, ip: str, provider_name: str, rel_type: RelationshipType,
target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
"""
Collect and organize metadata for an IP node based on provider responses.
"""
if provider_name == 'dns':
if rel_type == RelationshipType.PTR_RECORD:
reverse_entry = f"PTR: {target}"
if reverse_entry not in metadata['dns_records']:
metadata['dns_records'].append(reverse_entry)
if target not in metadata['hostnames']:
metadata['hostnames'].append(target)
elif provider_name == 'shodan':
# Merge Shodan data
for key, value in raw_data.items():
if key not in metadata['shodan'] or not metadata['shodan'][key]:
metadata['shodan'][key] = value
# Collect hostname information
if 'hostname' in raw_data and raw_data['hostname'] not in metadata['hostnames']:
metadata['hostnames'].append(raw_data['hostname'])
if 'hostnames' in raw_data:
for hostname in raw_data['hostnames']:
if hostname not in metadata['hostnames']:
metadata['hostnames'].append(hostname)
elif provider_name == 'virustotal':
# Merge VirusTotal data
for key, value in raw_data.items():
if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
metadata['virustotal'][key] = value
# Add passive DNS entries
if rel_type == RelationshipType.PASSIVE_DNS:
passive_entry = f"Passive DNS: {target}"
if passive_entry not in metadata['passive_dns']:
metadata['passive_dns'].append(passive_entry)
# Handle ASN relationships
if rel_type == RelationshipType.ASN_MEMBERSHIP:
metadata['asn_data'] = {
'asn': target,
'description': raw_data.get('org', ''),
'isp': raw_data.get('isp', ''),
'country': raw_data.get('country', '')
}
def _safe_provider_query_domain(self, provider, domain: str):
"""Safely query provider for domain with error handling."""
if self.stop_event.is_set():
return []
try:
return provider.query_domain(domain)
except Exception as e:
print(f"Provider {provider.get_name()} query_domain failed: {e}")
return []
def _safe_provider_query_ip(self, provider, ip: str):
"""Safely query provider for IP with error handling."""
if self.stop_event.is_set():
return []
try:
return provider.query_ip(ip)
except Exception as e:
print(f"Provider {provider.get_name()} query_ip failed: {e}")
return []
def stop_scan(self) -> bool:
"""
Request immediate scan termination with aggressive cancellation.
"""
try:
if self.status == ScanStatus.RUNNING:
print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
# Signal all threads to stop
self.stop_event.set()
# Close HTTP sessions in all providers to terminate ongoing requests
for provider in self.providers:
try:
if hasattr(provider, 'session'):
provider.session.close()
print(f"Closed HTTP session for provider: {provider.get_name()}")
except Exception as e:
print(f"Error closing session for {provider.get_name()}: {e}")
# Shutdown executor immediately with cancel_futures=True
if self.executor:
print("Shutting down executor with immediate cancellation...")
self.executor.shutdown(wait=False, cancel_futures=True)
# Give threads a moment to respond to cancellation, then force status change
threading.Timer(2.0, self._force_stop_completion).start()
print("Immediate termination requested - ongoing requests will be cancelled")
return True
print("No active scan to stop")
return False
except Exception as e:
print(f"ERROR: Exception in stop_scan: {e}")
traceback.print_exc()
return False
def _force_stop_completion(self):
"""Force completion of stop operation after timeout."""
if self.status == ScanStatus.RUNNING:
print("Forcing scan termination after timeout")
self.status = ScanStatus.STOPPED
self.logger.log_scan_complete()
def get_scan_status(self) -> Dict[str, Any]:
"""
Get current scan status and progress.
"""
try:
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'total_indicators_found': self.total_indicators_found,
'indicators_processed': self.indicators_processed,
'progress_percentage': self._calculate_progress(),
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph_statistics': self.graph.get_statistics()
}
except Exception as e:
print(f"ERROR: Exception in get_scan_status: {e}")
traceback.print_exc()
return {
'status': 'error',
'target_domain': None,
'current_depth': 0,
'max_depth': 0,
'current_indicator': '',
'total_indicators_found': 0,
'indicators_processed': 0,
'progress_percentage': 0.0,
'enabled_providers': [],
'graph_statistics': {}
}
def _calculate_progress(self) -> float:
"""Calculate scan progress percentage."""
if self.total_indicators_found == 0:
return 0.0
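        # Note: total_indicators_found grows as each new BFS level is discovered,
        # so the reported percentage can temporarily decrease between polls
        # before converging to 100.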
return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100)
def get_graph_data(self) -> Dict[str, Any]:
"""
Get current graph data for visualization.
"""
return self.graph.get_graph_data()
def export_results(self) -> Dict[str, Any]:
"""
Export complete scan results including graph and audit trail.
"""
graph_data = self.graph.export_json()
audit_trail = self.logger.export_audit_trail()
provider_stats = {}
for provider in self.providers:
provider_stats[provider.get_name()] = provider.get_statistics()
export_data = {
'scan_metadata': {
'target_domain': self.current_target,
'max_depth': self.max_depth,
'final_status': self.status,
'total_indicators_processed': self.indicators_processed,
'enabled_providers': list(provider_stats.keys())
},
'graph_data': graph_data,
'forensic_audit': audit_trail,
'provider_statistics': provider_stats,
'scan_summary': self.logger.get_forensic_summary()
}
return export_data
def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
"""
Get statistics for all providers.
"""
stats = {}
for provider in self.providers:
stats[provider.get_name()] = provider.get_statistics()
return stats
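
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the scanner API):
# drives a Scanner end to end by polling get_scan_status(). Assumes the
# providers and config modules imported above are available and that at least
# the free providers (crtsh, dns) initialize for the session.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import time

    scanner = Scanner()
    if scanner.start_scan("example.com", max_depth=1):
        # Poll until the background scan thread leaves the RUNNING state.
        while scanner.get_scan_status()['status'] == ScanStatus.RUNNING:
            time.sleep(1)
        results = scanner.export_results()
        print(f"Final status: {results['scan_metadata']['final_status']}")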