"""
|
|
Main scanning orchestrator for DNSRecon.
|
|
Coordinates data gathering from multiple providers and builds the infrastructure graph.
|
|
"""
|
|
|
|
import threading
|
|
import traceback
|
|
from typing import List, Set, Dict, Any, Tuple
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
|
|
from collections import defaultdict
|
|
|
|
from core.graph_manager import GraphManager, NodeType, RelationshipType
|
|
from core.logger import get_forensic_logger, new_session
|
|
from utils.helpers import _is_valid_ip, _is_valid_domain
|
|
from providers.crtsh_provider import CrtShProvider
|
|
from providers.dns_provider import DNSProvider
|
|
from providers.shodan_provider import ShodanProvider
|
|
from providers.virustotal_provider import VirusTotalProvider
|
|
from config import config
|
|
|
|
|
|
class ScanStatus:
    """Enumeration of scan statuses."""
    IDLE = "idle"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"


class Scanner:
    """
    Main scanning orchestrator for DNSRecon passive reconnaissance.
    Supports per-session configuration for multi-user isolation.
    """

    def __init__(self, session_config=None):
        """Initialize scanner with session-specific configuration."""
        print("Initializing Scanner instance...")

        try:
            # Use provided session config or create default
            if session_config is None:
                from core.session_config import create_session_config
                session_config = create_session_config()

            self.config = session_config
            self.graph = GraphManager()
            self.providers = []
            self.status = ScanStatus.IDLE
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()
            self.scan_thread = None

            # Scanning progress tracking
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = ""

            # Concurrent processing configuration
            self.max_workers = self.config.max_concurrent_requests
            self.executor = None

            # Initialize providers with session config
            print("Calling _initialize_providers with session config...")
            self._initialize_providers()

            # Initialize logger
            print("Initializing forensic logger...")
            self.logger = get_forensic_logger()

            print("Scanner initialization complete")

        except Exception as e:
            print(f"ERROR: Scanner initialization failed: {e}")
            traceback.print_exc()
            raise

    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration."""
        self.providers = []

        print("Initializing providers with session config...")

        # (name, class, requires_api_key) tuples: free providers are always
        # eligible, API-backed providers additionally need a key to be available.
        provider_classes = [
            ('crtsh', CrtShProvider, False),
            ('dns', DNSProvider, False),
            ('shodan', ShodanProvider, True),
            ('virustotal', VirusTotalProvider, True),
        ]

        for provider_name, provider_class, requires_api_key in provider_classes:
            if not self.config.is_provider_enabled(provider_name):
                continue
            try:
                # Pass session config to provider
                provider = provider_class(session_config=self.config)
                if provider.is_available():
                    # Set the stop event for cancellation support
                    provider.set_stop_event(self.stop_event)
                    self.providers.append(provider)
                    print(f"✓ {provider_name.title()} provider initialized successfully for session")
                elif requires_api_key:
                    print(f"✗ {provider_name.title()} provider is not available (API key required)")
                else:
                    print(f"✗ {provider_name.title()} provider is not available")
            except Exception as e:
                print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
                traceback.print_exc()

        print(f"Initialized {len(self.providers)} providers for session")

    def update_session_config(self, new_config) -> None:
        """
        Update session configuration and reinitialize providers.

        Args:
            new_config: New SessionConfig instance
        """
        print("Updating session configuration...")
        self.config = new_config
        self.max_workers = self.config.max_concurrent_requests
        self._initialize_providers()
        print("Session configuration updated")

    def start_scan(self, target_domain: str, max_depth: int = 2) -> bool:
        """
        Start a new reconnaissance scan with concurrent processing.

        Args:
            target_domain: Initial domain to investigate
            max_depth: Maximum recursion depth

        Returns:
            bool: True if scan started successfully
        """
        print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
        print(f"Scanner status: {self.status}")
        print(f"Target domain: '{target_domain}', Max depth: {max_depth}")
        print(f"Available providers: {len(self.providers) if hasattr(self, 'providers') else 0}")

        try:
            if self.status == ScanStatus.RUNNING:
                print(f"ERROR: Scan already running in scanner {id(self)}, rejecting new scan")
                print(f"Current target: {self.current_target}")
                print(f"Current depth: {self.current_depth}")
                return False

            # Check if we have any providers
            if not hasattr(self, 'providers') or not self.providers:
                print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
                return False

            print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")

            # Stop any existing scan thread
            if self.scan_thread and self.scan_thread.is_alive():
                print(f"Stopping existing scan thread in scanner {id(self)}...")
                self.stop_event.set()
                self.scan_thread.join(timeout=5.0)
                if self.scan_thread.is_alive():
                    print(f"WARNING: Could not stop existing thread in scanner {id(self)}")
                    return False

            # Reset state
            print(f"Resetting scanner {id(self)} state...")
            self.graph.clear()
            self.current_target = target_domain.lower().strip()
            self.max_depth = max_depth
            self.current_depth = 0
            self.stop_event.clear()
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = self.current_target

            # Start new forensic session
            print(f"Starting new forensic session for scanner {id(self)}...")
            self.logger = new_session()

            # Start scan in separate thread
            print(f"Starting scan thread for scanner {id(self)}...")
            self.scan_thread = threading.Thread(
                target=self._execute_scan_async,
                args=(self.current_target, max_depth),
                daemon=True
            )
            self.scan_thread.start()

            print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===")
            return True

        except Exception as e:
            print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
            traceback.print_exc()
            return False

    def _execute_scan_async(self, target_domain: str, max_depth: int) -> None:
        """
        Execute the reconnaissance scan with concurrent provider queries.

        Despite the name, this must be a plain (non-async) function: it runs as
        a threading.Thread target, which would never schedule a coroutine.

        Args:
            target_domain: Target domain to investigate
            max_depth: Maximum recursion depth
        """
        print(f"_execute_scan_async started for {target_domain} with depth {max_depth}")
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)

        # Initialized before the try block so the summary in the finally block
        # cannot raise a NameError if the scan fails early.
        processed_domains = set()
        all_discovered_ips = set()

        try:
            print("Setting status to RUNNING")
            self.status = ScanStatus.RUNNING

            # Log scan start
            enabled_providers = [provider.get_name() for provider in self.providers]
            self.logger.log_scan_start(target_domain, max_depth, enabled_providers)
            print(f"Logged scan start with providers: {enabled_providers}")

            # Initialize with target domain
            print(f"Adding target domain '{target_domain}' as initial node")
            self.graph.add_node(target_domain, NodeType.DOMAIN)

            # BFS-style exploration
            current_level_domains = {target_domain}

            print("Starting BFS exploration...")

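            # Each BFS level holds only domains first seen at the previous
            # level (processed_domains guards against revisits), so every
            # domain is queried at most once and depth reflects its discovery
            # distance from the seed.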
            for depth in range(max_depth + 1):
                if self.stop_event.is_set():
                    print(f"Stop requested at depth {depth}")
                    break

                self.current_depth = depth
                print(f"Processing depth level {depth} with {len(current_level_domains)} domains")

                if not current_level_domains:
                    print("No domains to process at this level")
                    break

                self.total_indicators_found += len(current_level_domains)
                next_level_domains = set()

                domain_results = self._process_domains_concurrent(current_level_domains, processed_domains)

                for domain, discovered_domains, discovered_ips in domain_results:
                    if self.stop_event.is_set():
                        break

                    processed_domains.add(domain)
                    all_discovered_ips.update(discovered_ips)

                    if depth < max_depth:
                        for discovered_domain in discovered_domains:
                            if discovered_domain not in processed_domains:
                                next_level_domains.add(discovered_domain)
                                print(f"Adding {discovered_domain} to next level")

                if self.stop_event.is_set():
                    break

                if all_discovered_ips:
                    print(f"Processing {len(all_discovered_ips)} discovered IP addresses")
                    self._process_ips_concurrent(all_discovered_ips)

                current_level_domains = next_level_domains
                print(f"Completed depth {depth}, {len(next_level_domains)} domains for next level")

        except Exception as e:
            print(f"ERROR: Scan execution failed with error: {e}")
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self.logger.logger.error(f"Scan failed: {e}")
        finally:
            # Preserve a FAILED status from the except block; otherwise report
            # STOPPED or COMPLETED depending on whether a stop was requested.
            if self.status != ScanStatus.FAILED:
                if self.stop_event.is_set():
                    self.status = ScanStatus.STOPPED
                    print("Scan completed with STOPPED status")
                else:
                    self.status = ScanStatus.COMPLETED
                    print("Scan completed with COMPLETED status")

            self.logger.log_scan_complete()
            self.executor.shutdown(wait=False, cancel_futures=True)

            stats = self.graph.get_statistics()
            print("Final scan statistics:")
            print(f"  - Total nodes: {stats['basic_metrics']['total_nodes']}")
            print(f"  - Total edges: {stats['basic_metrics']['total_edges']}")
            print(f"  - Domains processed: {len(processed_domains)}")
            print(f"  - IPs discovered: {len(all_discovered_ips)}")

    def _process_domains_concurrent(self, domains: Set[str], processed_domains: Set[str]) -> List[Tuple[str, Set[str], Set[str]]]:
        """
        Process multiple domains concurrently using the shared thread pool.
        """
        results = []
        domains_to_process = domains - processed_domains
        if not domains_to_process:
            return results

        print(f"Processing {len(domains_to_process)} domains concurrently with {self.max_workers} workers")

        future_to_domain = {
            self.executor.submit(self._query_providers_for_domain, domain): domain
            for domain in domains_to_process
        }

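        # Note: Future.cancel() only prevents futures that have not started
        # yet; already-running provider queries finish on their own (bounded
        # by the providers' stop_event checks) and are simply skipped here.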
        for future in as_completed(future_to_domain):
            if self.stop_event.is_set():
                future.cancel()
                continue
            domain = future_to_domain[future]
            try:
                discovered_domains, discovered_ips = future.result()
                results.append((domain, discovered_domains, discovered_ips))
                self.indicators_processed += 1
                print(f"Completed processing domain: {domain} ({len(discovered_domains)} domains, {len(discovered_ips)} IPs)")
            except (Exception, CancelledError) as e:
                print(f"Error processing domain {domain}: {e}")
        return results

    def _process_ips_concurrent(self, ips: Set[str]) -> None:
        """
        Process multiple IP addresses concurrently.
        """
        if not ips or self.stop_event.is_set():
            return
        print(f"Processing {len(ips)} IP addresses concurrently")
        future_to_ip = {
            self.executor.submit(self._query_providers_for_ip, ip): ip
            for ip in ips
        }
        for future in as_completed(future_to_ip):
            if self.stop_event.is_set():
                future.cancel()
                continue
            ip = future_to_ip[future]
            try:
                future.result()  # Just wait for completion
                print(f"Completed processing IP: {ip}")
            except (Exception, CancelledError) as e:
                print(f"Error processing IP {ip}: {e}")

    def _query_providers_for_domain(self, domain: str) -> Tuple[Set[str], Set[str]]:
        """
        Query all enabled providers for information about a domain and collect comprehensive metadata.
        Creates appropriate node types and relationships based on discovered data.
        """
        print(f"Querying {len(self.providers)} providers for domain: {domain}")
        discovered_domains = set()
        discovered_ips = set()
        all_relationships = []

        # Comprehensive metadata collection for this domain
        domain_metadata = {
            'dns_records': [],
            'related_domains_san': [],
            'shodan': {},
            'virustotal': {},
            'certificate_data': {},
            'passive_dns': [],
        }

        if not self.providers or self.stop_event.is_set():
            return discovered_domains, discovered_ips

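        # Each domain gets its own short-lived executor sized to the provider
        # count, so one slow provider does not serialize the others.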
        # Query all providers concurrently
        with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
            future_to_provider = {
                provider_executor.submit(self._safe_provider_query_domain, provider, domain): provider
                for provider in self.providers
            }

            for future in as_completed(future_to_provider):
                if self.stop_event.is_set():
                    future.cancel()
                    continue

                provider = future_to_provider[future]
                try:
                    relationships = future.result()
                    print(f"Provider {provider.get_name()} returned {len(relationships)} relationships")

                    # Process relationships and collect metadata
                    for rel in relationships:
                        source, target, rel_type, confidence, raw_data = rel

                        # Add provider info to the relationship
                        enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
                        all_relationships.append(enhanced_rel)

                        # Collect metadata based on provider and relationship type
                        self._collect_node_metadata(domain, provider.get_name(), rel_type, target, raw_data, domain_metadata)

                except (Exception, CancelledError) as e:
                    print(f"Provider {provider.get_name()} failed for {domain}: {e}")

        # Add the domain node with comprehensive metadata
        self.graph.add_node(domain, NodeType.DOMAIN, metadata=domain_metadata)

        # Group relationships by type for large entity handling
        relationships_by_type = defaultdict(list)
        for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
            relationships_by_type[rel_type].append((source, target, rel_type, confidence, raw_data, provider_name))

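        # When one domain fans out to a huge number of targets of a single
        # relationship type (typically SAN lists on shared certificates), the
        # results are collapsed into one aggregate node so the graph stays
        # readable instead of exploding into hundreds of leaf nodes.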
        # Handle large entities (only for SAN certificates currently)
        for rel_type, relationships in relationships_by_type.items():
            if len(relationships) > config.large_entity_threshold and rel_type == RelationshipType.SAN_CERTIFICATE:
                first_provider = relationships[0][5] if relationships else "multiple_providers"
                self._handle_large_entity(domain, relationships, rel_type, first_provider)
                # Remove these relationships from further processing
                all_relationships = [rel for rel in all_relationships if rel[2] != rel_type]

        # Track DNS records to create (avoid duplicates)
        dns_records_to_create = {}

        # Process remaining relationships
        for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
            if self.stop_event.is_set():
                break

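            # Branch order matters below: IP literals are tested before domain
            # names, since a dotted IPv4 string could otherwise slip through a
            # loose domain check; anything that matches none of the
            # infrastructure shapes falls through to raw record content.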
            # Determine how to handle the target based on relationship type and content
            if _is_valid_ip(target):
                # Create IP node and relationship
                self.graph.add_node(target, NodeType.IP)

                if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")

                # Add to recursion if it's a direct resolution
                if rel_type in [RelationshipType.A_RECORD, RelationshipType.AAAA_RECORD]:
                    discovered_ips.add(target)

            elif target.startswith('AS') and target[2:].isdigit():
                # Create ASN node and relationship
                self.graph.add_node(target, NodeType.ASN)

                if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added ASN relationship: {source} -> {target} ({rel_type.relationship_name})")

            elif _is_valid_domain(target):
                # Create domain node and relationship
                self.graph.add_node(target, NodeType.DOMAIN)

                if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added domain relationship: {source} -> {target} ({rel_type.relationship_name})")

                # Add to recursion for specific relationship types
                recurse_types = [
                    RelationshipType.CNAME_RECORD,
                    RelationshipType.MX_RECORD,
                    RelationshipType.SAN_CERTIFICATE,
                    RelationshipType.NS_RECORD,
                    RelationshipType.PASSIVE_DNS
                ]
                if rel_type in recurse_types:
                    discovered_domains.add(target)

            else:
                # Handle DNS record content (TXT, SPF, CAA, etc.)
                dns_record_types = [
                    RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD,
                    RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD,
                    RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD,
                    RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD,
                    RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD
                ]

                if rel_type in dns_record_types:
                    # Create normalized DNS record identifier
                    record_type = rel_type.relationship_name.upper().replace('_RECORD', '')
                    record_content = target.strip()

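                    # Caveat: hash() on strings is salted per process
                    # (PYTHONHASHSEED), so these record IDs are stable within a
                    # single scan but not across runs; switch to hashlib if
                    # cross-run stable identifiers are ever needed.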
                    # Create a unique identifier for this DNS record
                    content_hash = hash(record_content) & 0x7FFFFFFF
                    dns_record_id = f"{record_type}:{content_hash}"

                    # Track this DNS record for creation (avoid duplicates)
                    if dns_record_id not in dns_records_to_create:
                        dns_records_to_create[dns_record_id] = {
                            'content': record_content,
                            'type': record_type,
                            'domains': set(),
                            'raw_data': raw_data,
                            'provider_name': provider_name,
                            'confidence': confidence
                        }

                    # Add this domain to the DNS record's domain list
                    dns_records_to_create[dns_record_id]['domains'].add(source)

                    print(f"DNS record tracked: {source} -> {record_type} (content length: {len(record_content)})")
                else:
                    # For other non-infrastructure targets, log but don't create nodes
                    print(f"Non-infrastructure relationship stored as metadata: {source} - {rel_type.relationship_name}: {target[:100]}")

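        # One node is created per unique record content; every domain that
        # serves that content links to the same node, which makes shared
        # TXT/SPF infrastructure visible in the graph.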
        # Create DNS record nodes and their relationships
        for dns_record_id, record_info in dns_records_to_create.items():
            if self.stop_event.is_set():
                break

            record_metadata = {
                'record_type': record_info['type'],
                'content': record_info['content'],
                'content_hash': dns_record_id.split(':')[1],
                'associated_domains': list(record_info['domains']),
                'source_data': record_info['raw_data']
            }

            # Create the DNS record node
            self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata)

            # Connect each domain to this DNS record
            for domain_name in record_info['domains']:
                if self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD,
                                       record_info['confidence'], record_info['provider_name'],
                                       record_info['raw_data']):
                    print(f"Added DNS record relationship: {domain_name} -> {dns_record_id}")

        print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs, {len(dns_records_to_create)} DNS records")
        return discovered_domains, discovered_ips

    def _collect_node_metadata(self, node_id: str, provider_name: str, rel_type: RelationshipType,
                               target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        """
        Collect and organize metadata for a node based on provider responses.
        """
        if provider_name == 'dns':
            record_type = raw_data.get('query_type', 'UNKNOWN')
            value = raw_data.get('value', target)

            # Store the full record content for every record type
            dns_entry = f"{record_type}: {value}"

            if dns_entry not in metadata['dns_records']:
                metadata['dns_records'].append(dns_entry)

        elif provider_name == 'crtsh':
            if rel_type == RelationshipType.SAN_CERTIFICATE:
                # Handle certificate data storage on domain nodes
                domain_certs = raw_data.get('domain_certificates', {})

                # Store certificate information for this domain
                if node_id in domain_certs:
                    cert_summary = domain_certs[node_id]

                    # Update domain metadata with certificate information
                    metadata['certificate_data'] = cert_summary
                    metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False)

                # Add related domains from shared certificates
                if target not in metadata.get('related_domains_san', []):
                    if 'related_domains_san' not in metadata:
                        metadata['related_domains_san'] = []
                    metadata['related_domains_san'].append(target)

                # Store shared certificate details for forensic analysis
                shared_certs = raw_data.get('shared_certificates', [])
                if shared_certs and 'shared_certificate_details' not in metadata:
                    metadata['shared_certificate_details'] = shared_certs

        elif provider_name == 'shodan':
            # Merge Shodan data (avoid overwriting)
            for key, value in raw_data.items():
                if key not in metadata['shodan'] or not metadata['shodan'][key]:
                    metadata['shodan'][key] = value

        elif provider_name == 'virustotal':
            # Merge VirusTotal data
            for key, value in raw_data.items():
                if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
                    metadata['virustotal'][key] = value

        # Add passive DNS entries
        if rel_type == RelationshipType.PASSIVE_DNS:
            passive_entry = f"Passive DNS: {target}"
            if passive_entry not in metadata['passive_dns']:
                metadata['passive_dns'].append(passive_entry)

    def _handle_large_entity(self, source_domain: str, relationships: list, rel_type: RelationshipType, provider_name: str):
        """
        Handles the creation of a large entity node when a threshold is exceeded.
        """
        print(f"Large number of {rel_type.name} relationships for {source_domain}. Creating a large entity node.")
        entity_name = f"Large collection of {rel_type.name} for {source_domain}"
        self.graph.add_node(entity_name, NodeType.LARGE_ENTITY, metadata={"count": len(relationships)})
        self.graph.add_edge(source_domain, entity_name, rel_type, 0.9, provider_name, {"info": "Aggregated node"})

    def _query_providers_for_ip(self, ip: str) -> None:
        """
        Query all enabled providers for information about an IP address and collect comprehensive metadata.
        """
        print(f"Querying {len(self.providers)} providers for IP: {ip}")
        if not self.providers or self.stop_event.is_set():
            return

        # Comprehensive metadata collection for this IP
        ip_metadata = {
            'dns_records': [],
            'passive_dns': [],
            'shodan': {},
            'virustotal': {},
            'asn_data': {},
            'hostnames': [],
        }

        all_relationships = []  # Store relationships with provider info

        with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
            future_to_provider = {
                provider_executor.submit(self._safe_provider_query_ip, provider, ip): provider
                for provider in self.providers
            }
            for future in as_completed(future_to_provider):
                if self.stop_event.is_set():
                    future.cancel()
                    continue
                provider = future_to_provider[future]
                try:
                    relationships = future.result()
                    print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for IP {ip}")

                    for source, target, rel_type, confidence, raw_data in relationships:
                        # Add provider info to the relationship
                        enhanced_rel = (source, target, rel_type, confidence, raw_data, provider.get_name())
                        all_relationships.append(enhanced_rel)

                        # Collect metadata for the IP
                        self._collect_ip_metadata(ip, provider.get_name(), rel_type, target, raw_data, ip_metadata)

                except (Exception, CancelledError) as e:
                    print(f"Provider {provider.get_name()} failed for IP {ip}: {e}")

        # Update the IP node with comprehensive metadata
        self.graph.add_node(ip, NodeType.IP, metadata=ip_metadata)

        # Process relationships with correct provider attribution
        for source, target, rel_type, confidence, raw_data, provider_name in all_relationships:
            # Determine target node type (same ASN test as the domain path)
            if _is_valid_domain(target):
                target_node_type = NodeType.DOMAIN
            elif target.startswith('AS') and target[2:].isdigit():
                target_node_type = NodeType.ASN
            else:
                target_node_type = NodeType.IP

            # Create/update target node
            self.graph.add_node(target, target_node_type)

            # Add relationship with correct provider attribution
            if self.graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data):
                print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name}) from {provider_name}")

    def _collect_ip_metadata(self, ip: str, provider_name: str, rel_type: RelationshipType,
                             target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        """
        Collect and organize metadata for an IP node based on provider responses.
        """
        if provider_name == 'dns':
            if rel_type == RelationshipType.PTR_RECORD:
                reverse_entry = f"PTR: {target}"
                if reverse_entry not in metadata['dns_records']:
                    metadata['dns_records'].append(reverse_entry)
                if target not in metadata['hostnames']:
                    metadata['hostnames'].append(target)

        elif provider_name == 'shodan':
            # Merge Shodan data
            for key, value in raw_data.items():
                if key not in metadata['shodan'] or not metadata['shodan'][key]:
                    metadata['shodan'][key] = value

            # Collect hostname information
            if 'hostname' in raw_data and raw_data['hostname'] not in metadata['hostnames']:
                metadata['hostnames'].append(raw_data['hostname'])
            if 'hostnames' in raw_data:
                for hostname in raw_data['hostnames']:
                    if hostname not in metadata['hostnames']:
                        metadata['hostnames'].append(hostname)

        elif provider_name == 'virustotal':
            # Merge VirusTotal data
            for key, value in raw_data.items():
                if key not in metadata['virustotal'] or not metadata['virustotal'][key]:
                    metadata['virustotal'][key] = value

        # Add passive DNS entries
        if rel_type == RelationshipType.PASSIVE_DNS:
            passive_entry = f"Passive DNS: {target}"
            if passive_entry not in metadata['passive_dns']:
                metadata['passive_dns'].append(passive_entry)

        # Handle ASN relationships
        if rel_type == RelationshipType.ASN_MEMBERSHIP:
            metadata['asn_data'] = {
                'asn': target,
                'description': raw_data.get('org', ''),
                'isp': raw_data.get('isp', ''),
                'country': raw_data.get('country', '')
            }

    def _safe_provider_query_domain(self, provider, domain: str):
        """Safely query provider for domain with error handling."""
        if self.stop_event.is_set():
            return []
        try:
            return provider.query_domain(domain)
        except Exception as e:
            print(f"Provider {provider.get_name()} query_domain failed: {e}")
            return []

    def _safe_provider_query_ip(self, provider, ip: str):
        """Safely query provider for IP with error handling."""
        if self.stop_event.is_set():
            return []
        try:
            return provider.query_ip(ip)
        except Exception as e:
            print(f"Provider {provider.get_name()} query_ip failed: {e}")
            return []

    def stop_scan(self) -> bool:
        """
        Request immediate scan termination with aggressive cancellation.
        """
        try:
            if self.status == ScanStatus.RUNNING:
                print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")

                # Signal all threads to stop
                self.stop_event.set()

                # Close HTTP sessions in all providers to terminate ongoing requests
                for provider in self.providers:
                    try:
                        if hasattr(provider, 'session'):
                            provider.session.close()
                            print(f"Closed HTTP session for provider: {provider.get_name()}")
                    except Exception as e:
                        print(f"Error closing session for {provider.get_name()}: {e}")

                # Shutdown executor immediately with cancel_futures=True
                if self.executor:
                    print("Shutting down executor with immediate cancellation...")
                    self.executor.shutdown(wait=False, cancel_futures=True)

                # Give threads a moment to respond to cancellation, then force status change
                threading.Timer(2.0, self._force_stop_completion).start()

                print("Immediate termination requested - ongoing requests will be cancelled")
                return True
            print("No active scan to stop")
            return False
        except Exception as e:
            print(f"ERROR: Exception in stop_scan: {e}")
            traceback.print_exc()
            return False

    def _force_stop_completion(self):
        """Force completion of stop operation after timeout."""
        if self.status == ScanStatus.RUNNING:
            print("Forcing scan termination after timeout")
            self.status = ScanStatus.STOPPED
            self.logger.log_scan_complete()

    def get_scan_status(self) -> Dict[str, Any]:
        """
        Get current scan status and progress.
        """
        try:
            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'total_indicators_found': self.total_indicators_found,
                'indicators_processed': self.indicators_processed,
                'progress_percentage': self._calculate_progress(),
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics()
            }
        except Exception as e:
            print(f"ERROR: Exception in get_scan_status: {e}")
            traceback.print_exc()
            return {
                'status': 'error',
                'target_domain': None,
                'current_depth': 0,
                'max_depth': 0,
                'current_indicator': '',
                'total_indicators_found': 0,
                'indicators_processed': 0,
                'progress_percentage': 0.0,
                'enabled_providers': [],
                'graph_statistics': {}
            }

    def _calculate_progress(self) -> float:
        """Calculate scan progress percentage."""
        if self.total_indicators_found == 0:
            return 0.0
        return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100)

    def get_graph_data(self) -> Dict[str, Any]:
        """
        Get current graph data for visualization.
        """
        return self.graph.get_graph_data()

    def export_results(self) -> Dict[str, Any]:
        """
        Export complete scan results including graph and audit trail.
        """
        graph_data = self.graph.export_json()
        audit_trail = self.logger.export_audit_trail()
        provider_stats = {}
        for provider in self.providers:
            provider_stats[provider.get_name()] = provider.get_statistics()
        export_data = {
            'scan_metadata': {
                'target_domain': self.current_target,
                'max_depth': self.max_depth,
                'final_status': self.status,
                'total_indicators_processed': self.indicators_processed,
                'enabled_providers': list(provider_stats.keys())
            },
            'graph_data': graph_data,
            'forensic_audit': audit_trail,
            'provider_statistics': provider_stats,
            'scan_summary': self.logger.get_forensic_summary()
        }
        return export_data

    def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
        """
        Get statistics for all providers.
        """
        stats = {}
        for provider in self.providers:
            stats[provider.get_name()] = provider.get_statistics()
        return stats
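

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the scanner API).
# Assumptions: the sibling core/, providers/ and utils/ packages are on the
# import path, and "example.com" stands in for a domain you are authorized to
# investigate; the 60-second cap is an arbitrary choice for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import time

    scanner = Scanner()
    if scanner.start_scan("example.com", max_depth=1):
        time.sleep(1)  # give the worker thread a moment to reach RUNNING
        deadline = time.time() + 60
        while scanner.get_scan_status()['status'] == ScanStatus.RUNNING and time.time() < deadline:
            time.sleep(1)
        if scanner.get_scan_status()['status'] == ScanStatus.RUNNING:
            scanner.stop_scan()  # request aggressive cancellation on timeout
        print(scanner.get_scan_status()['graph_statistics'])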