Scanner: safer cancellation via threading.Event, a shared ThreadPoolExecutor, and large-entity node aggregation

This commit is contained in:
overcuriousity
2025-09-10 22:34:58 +02:00
parent 709d3b9f3d
commit db2101d814
7 changed files with 192 additions and 237 deletions

View File

@@ -19,6 +19,7 @@ class NodeType(Enum):
IP = "ip"
CERTIFICATE = "certificate"
ASN = "asn"
+ LARGE_ENTITY = "large_entity"
class RelationshipType(Enum):
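For orientation, the NodeType enum after this hunk would look roughly like the sketch below (a reconstruction from the visible context; the DOMAIN member and its string value are assumptions based on how the scanner uses the enum elsewhere in this diff):

from enum import Enum

class NodeType(Enum):
    DOMAIN = "domain"              # assumed; NodeType.DOMAIN is used throughout the scanner
    IP = "ip"
    CERTIFICATE = "certificate"
    ASN = "asn"
    LARGE_ENTITY = "large_entity"  # new in this commit: stands in for oversized result sets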

View File

@@ -7,7 +7,7 @@ import threading
import time
import traceback
from typing import List, Set, Dict, Any, Optional, Tuple
- from concurrent.futures import ThreadPoolExecutor, as_completed
+ from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
from core.graph_manager import GraphManager, NodeType, RelationshipType
from core.logger import get_forensic_logger, new_session
@@ -36,7 +36,7 @@ class Scanner:
def __init__(self):
"""Initialize scanner with all available providers and empty graph."""
print("Initializing Scanner instance...")
try:
self.graph = GraphManager()
self.providers = []
@@ -44,16 +44,17 @@ class Scanner:
self.current_target = None
self.current_depth = 0
self.max_depth = 2
- self.stop_requested = False
+ self.stop_event = threading.Event() # Use a threading.Event for safer signaling
self.scan_thread = None
# Scanning progress tracking
self.total_indicators_found = 0
self.indicators_processed = 0
self.current_indicator = ""
# Concurrent processing configuration
self.max_workers = config.max_concurrent_requests
+ self.executor = None # Keep a reference to the executor
# Initialize providers
print("Calling _initialize_providers...")
@@ -62,9 +63,9 @@ class Scanner:
# Initialize logger
print("Initializing forensic logger...")
self.logger = get_forensic_logger()
print("Scanner initialization complete")
except Exception as e:
print(f"ERROR: Scanner initialization failed: {e}")
traceback.print_exc()
@@ -81,7 +82,7 @@ class Scanner:
('crtsh', CrtShProvider),
('dns', DNSProvider)
]
for provider_name, provider_class in free_providers:
if config.is_provider_enabled(provider_name):
try:
@@ -100,7 +101,7 @@ class Scanner:
('shodan', ShodanProvider),
('virustotal', VirusTotalProvider)
]
for provider_name, provider_class in api_providers:
if config.is_provider_enabled(provider_name):
try:
@@ -128,7 +129,7 @@ class Scanner:
bool: True if scan started successfully
"""
print(f"Scanner.start_scan called with target='{target_domain}', depth={max_depth}")
try:
if self.status == ScanStatus.RUNNING:
print("Scan already running, rejecting new scan")
@@ -142,8 +143,8 @@ class Scanner:
# Stop any existing scan thread
if self.scan_thread and self.scan_thread.is_alive():
print("Stopping existing scan thread...")
- self.stop_requested = True
- self.scan_thread.join(timeout=2.0)
+ self.stop_event.set()
+ self.scan_thread.join(timeout=5.0)
if self.scan_thread.is_alive():
print("WARNING: Could not stop existing thread")
return False
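Replacing the stop_requested boolean with threading.Event is the core of this commit's cancellation rework: an Event is designed for cross-thread signaling, and its wait(timeout) method gives workers an interruptible sleep that a polled flag cannot. A minimal sketch of the pattern, independent of this codebase:

import threading

stop_event = threading.Event()

def worker():
    # wait() doubles as an interruptible sleep: it returns False on timeout
    # and True the moment the event is set, so the loop exits promptly.
    while not stop_event.wait(timeout=0.1):
        pass  # placeholder for one unit of scan work

t = threading.Thread(target=worker, daemon=True)
t.start()

stop_event.set()      # signal the worker from the controlling thread
t.join(timeout=5.0)   # bounded wait, mirroring the join(timeout=5.0) above
if t.is_alive():
    print("WARNING: Could not stop existing thread")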
@@ -154,7 +155,7 @@ class Scanner:
self.current_target = target_domain.lower().strip()
self.max_depth = max_depth
self.current_depth = 0
- self.stop_requested = False
+ self.stop_event.clear()
self.total_indicators_found = 0
self.indicators_processed = 0
self.current_indicator = self.current_target
@@ -163,7 +164,7 @@ class Scanner:
print("Starting new forensic session...")
self.logger = new_session()
- # Start scan in separate thread for Phase 2
+ # Start scan in separate thread
print("Starting scan thread...")
self.scan_thread = threading.Thread(
target=self._execute_scan_async,
@@ -171,9 +172,9 @@ class Scanner:
daemon=True
)
self.scan_thread.start()
return True
except Exception as e:
print(f"ERROR: Exception in start_scan: {e}")
traceback.print_exc()
@@ -188,6 +189,7 @@ class Scanner:
max_depth: Maximum recursion depth
"""
print(f"_execute_scan_async started for {target_domain} with depth {max_depth}")
+ self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
try:
print("Setting status to RUNNING")
@@ -202,7 +204,7 @@ class Scanner:
print(f"Adding target domain '{target_domain}' as initial node")
self.graph.add_node(target_domain, NodeType.DOMAIN)
- # BFS-style exploration with depth limiting and concurrent processing
+ # BFS-style exploration
current_level_domains = {target_domain}
processed_domains = set()
all_discovered_ips = set()
@@ -210,7 +212,7 @@ class Scanner:
print(f"Starting BFS exploration...")
for depth in range(max_depth + 1):
- if self.stop_requested:
+ if self.stop_event.is_set():
print(f"Stop requested at depth {depth}")
break
@@ -221,28 +223,27 @@ class Scanner:
print("No domains to process at this level")
break
# Update progress tracking
self.total_indicators_found += len(current_level_domains)
next_level_domains = set()
# Process domains at current depth level with concurrent queries
domain_results = self._process_domains_concurrent(current_level_domains, processed_domains)
for domain, discovered_domains, discovered_ips in domain_results:
- if self.stop_requested:
+ if self.stop_event.is_set():
break
processed_domains.add(domain)
all_discovered_ips.update(discovered_ips)
# Add discovered domains to next level if not at max depth
if depth < max_depth:
for discovered_domain in discovered_domains:
if discovered_domain not in processed_domains:
next_level_domains.add(discovered_domain)
print(f"Adding {discovered_domain} to next level")
# Process discovered IPs concurrently
+ if self.stop_event.is_set():
+ break
if all_discovered_ips:
print(f"Processing {len(all_discovered_ips)} discovered IP addresses")
self._process_ips_concurrent(all_discovered_ips)
@@ -250,8 +251,13 @@ class Scanner:
current_level_domains = next_level_domains
print(f"Completed depth {depth}, {len(next_level_domains)} domains for next level")
# Finalize scan
- if self.stop_requested:
+ except Exception as e:
+ print(f"ERROR: Scan execution failed with error: {e}")
+ traceback.print_exc()
+ self.status = ScanStatus.FAILED
+ self.logger.logger.error(f"Scan failed: {e}")
+ finally:
+ if self.stop_event.is_set():
self.status = ScanStatus.STOPPED
print("Scan completed with STOPPED status")
else:
@@ -259,8 +265,8 @@ class Scanner:
print("Scan completed with COMPLETED status")
self.logger.log_scan_complete()
+ self.executor.shutdown(wait=False, cancel_futures=True)
# Print final statistics
stats = self.graph.get_statistics()
print(f"Final scan statistics:")
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
@@ -268,132 +274,97 @@ class Scanner:
print(f" - Domains processed: {len(processed_domains)}")
print(f" - IPs discovered: {len(all_discovered_ips)}")
- except Exception as e:
- print(f"ERROR: Scan execution failed with error: {e}")
- traceback.print_exc()
- self.status = ScanStatus.FAILED
- self.logger.logger.error(f"Scan failed: {e}")
def _process_domains_concurrent(self, domains: Set[str], processed_domains: Set[str]) -> List[Tuple[str, Set[str], Set[str]]]:
"""
Process multiple domains concurrently using thread pool.
Args:
domains: Set of domains to process
processed_domains: Set of already processed domains
Returns:
List of tuples (domain, discovered_domains, discovered_ips)
"""
results = []
# Filter out already processed domains
domains_to_process = domains - processed_domains
if not domains_to_process:
return results
print(f"Processing {len(domains_to_process)} domains concurrently with {self.max_workers} workers")
- with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
- # Submit all domain processing tasks
- future_to_domain = {
- executor.submit(self._query_providers_for_domain, domain): domain
- for domain in domains_to_process
- }
- # Collect results as they complete
- for future in as_completed(future_to_domain):
- if self.stop_requested:
- break
- domain = future_to_domain[future]
- try:
- discovered_domains, discovered_ips = future.result()
- results.append((domain, discovered_domains, discovered_ips))
- self.indicators_processed += 1
- print(f"Completed processing domain: {domain} ({len(discovered_domains)} domains, {len(discovered_ips)} IPs)")
- except Exception as e:
- print(f"Error processing domain {domain}: {e}")
- traceback.print_exc()
+ future_to_domain = {
+ self.executor.submit(self._query_providers_for_domain, domain): domain
+ for domain in domains_to_process
+ }
+ for future in as_completed(future_to_domain):
+ if self.stop_event.is_set():
+ future.cancel()
+ continue
+ domain = future_to_domain[future]
+ try:
+ discovered_domains, discovered_ips = future.result()
+ results.append((domain, discovered_domains, discovered_ips))
+ self.indicators_processed += 1
+ print(f"Completed processing domain: {domain} ({len(discovered_domains)} domains, {len(discovered_ips)} IPs)")
+ except (Exception, CancelledError) as e:
+ print(f"Error processing domain {domain}: {e}")
return results
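Two caveats about this loop are worth spelling out. First, a future yielded by as_completed() has already finished, so the future.cancel() calls here return False and are effectively no-ops; the real cancellation work is done by the stop_event checks inside the tasks and by shutdown(cancel_futures=True). Second, since Python 3.8 concurrent.futures.CancelledError is an alias of asyncio.CancelledError and derives from BaseException, which is why the diff catches it explicitly alongside Exception. A condensed sketch of the submit/as_completed pattern, with process() as a hypothetical stand-in for _query_providers_for_domain:

from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
import threading

stop_event = threading.Event()

def process(item):
    if stop_event.is_set():  # cheap early exit for tasks that start after a stop
        return None
    return item.upper()

with ThreadPoolExecutor(max_workers=4) as executor:
    future_to_item = {executor.submit(process, d): d for d in ("a.example", "b.example")}
    for future in as_completed(future_to_item):
        try:
            print(future_to_item[future], "->", future.result())
        except (Exception, CancelledError) as e:  # CancelledError is not an Exception subclass on 3.8+
            print("failed:", e)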
def _process_ips_concurrent(self, ips: Set[str]) -> None:
"""
Process multiple IP addresses concurrently.
Args:
ips: Set of IP addresses to process
"""
- if not ips:
+ if not ips or self.stop_event.is_set():
return
print(f"Processing {len(ips)} IP addresses concurrently")
- with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
- # Submit all IP processing tasks
- future_to_ip = {
- executor.submit(self._query_providers_for_ip, ip): ip
- for ip in ips
- }
- # Collect results as they complete
- for future in as_completed(future_to_ip):
- if self.stop_requested:
- break
- ip = future_to_ip[future]
- try:
- future.result() # Just wait for completion
- print(f"Completed processing IP: {ip}")
- except Exception as e:
- print(f"Error processing IP {ip}: {e}")
- traceback.print_exc()
+ future_to_ip = {
+ self.executor.submit(self._query_providers_for_ip, ip): ip
+ for ip in ips
+ }
+ for future in as_completed(future_to_ip):
+ if self.stop_event.is_set():
+ future.cancel()
+ continue
+ ip = future_to_ip[future]
+ try:
+ future.result() # Just wait for completion
+ print(f"Completed processing IP: {ip}")
+ except (Exception, CancelledError) as e:
+ print(f"Error processing IP {ip}: {e}")
def _query_providers_for_domain(self, domain: str) -> Tuple[Set[str], Set[str]]:
"""
Query all enabled providers for information about a domain.
Args:
domain: Domain to investigate
Returns:
Tuple of (discovered_domains, discovered_ips)
"""
print(f"Querying {len(self.providers)} providers for domain: {domain}")
discovered_domains = set()
discovered_ips = set()
+ # Define a threshold for creating a "large entity" node
+ LARGE_ENTITY_THRESHOLD = 50
- if not self.providers:
- print("No providers available")
+ if not self.providers or self.stop_event.is_set():
return discovered_domains, discovered_ips
# Query providers concurrently for better performance
- with ThreadPoolExecutor(max_workers=len(self.providers)) as executor:
- # Submit queries for all providers
+ with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
- executor.submit(self._safe_provider_query_domain, provider, domain): provider
+ provider_executor.submit(self._safe_provider_query_domain, provider, domain): provider
for provider in self.providers
}
# Collect results as they complete
for future in as_completed(future_to_provider):
- if self.stop_requested:
- break
+ if self.stop_event.is_set():
+ future.cancel()
+ continue
provider = future_to_provider[future]
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships")
+ # Check if the number of relationships exceeds the threshold
+ if len(relationships) > LARGE_ENTITY_THRESHOLD:
+ # Create a single "large entity" node
+ large_entity_id = f"large_entity_{provider.get_name()}_{domain}"
+ self.graph.add_node(large_entity_id, NodeType.LARGE_ENTITY, metadata={'count': len(relationships), 'provider': provider.get_name()})
+ self.graph.add_edge(domain, large_entity_id, RelationshipType.PASSIVE_DNS, 1.0, provider.get_name(), {})
+ print(f"Created large entity node for {domain} from {provider.get_name()} with {len(relationships)} relationships")
+ continue # Skip adding individual nodes
for source, target, rel_type, confidence, raw_data in relationships:
# Determine node type based on target
if self._is_valid_ip(target):
target_node_type = NodeType.IP
discovered_ips.add(target)
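The new LARGE_ENTITY branch caps graph growth: any single provider response with more than LARGE_ENTITY_THRESHOLD (50) relationships is collapsed into one aggregate node instead of dozens of individual ones. Note that the aggregate edge type is hard-coded to RelationshipType.PASSIVE_DNS regardless of which provider produced the results. A condensed restatement of the branch as a standalone helper (hypothetical function name; GraphManager signatures as they appear in this diff):

from core.graph_manager import GraphManager, NodeType, RelationshipType

LARGE_ENTITY_THRESHOLD = 50  # value used in the diff

def add_provider_results(graph: GraphManager, domain: str, provider_name: str, relationships) -> None:
    # Hypothetical helper mirroring the thresholding branch above.
    if len(relationships) > LARGE_ENTITY_THRESHOLD:
        large_entity_id = f"large_entity_{provider_name}_{domain}"
        graph.add_node(large_entity_id, NodeType.LARGE_ENTITY,
                       metadata={'count': len(relationships), 'provider': provider_name})
        graph.add_edge(domain, large_entity_id, RelationshipType.PASSIVE_DNS,
                       1.0, provider_name, {})
        return  # skip adding the individual nodes
    for source, target, rel_type, confidence, raw_data in relationships:
        graph.add_node(source, NodeType.DOMAIN)
        graph.add_edge(source, target, rel_type, confidence, provider_name, raw_data)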
@@ -401,22 +372,13 @@ class Scanner:
target_node_type = NodeType.DOMAIN
discovered_domains.add(target)
else:
- # Could be ASN or certificate
target_node_type = NodeType.ASN if target.startswith('AS') else NodeType.CERTIFICATE
- # Add nodes and relationship to graph
self.graph.add_node(source, NodeType.DOMAIN)
self.graph.add_node(target, target_node_type)
- success = self.graph.add_edge(
- source, target, rel_type, confidence,
- provider.get_name(), raw_data
- )
- if success:
+ if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data):
print(f"Added relationship: {source} -> {target} ({rel_type.relationship_name})")
- except Exception as e:
+ except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for {domain}: {e}")
print(f"Domain {domain}: discovered {len(discovered_domains)} domains, {len(discovered_ips)} IPs")
@@ -425,61 +387,43 @@ class Scanner:
def _query_providers_for_ip(self, ip: str) -> None:
"""
Query all enabled providers for information about an IP address.
Args:
ip: IP address to investigate
"""
print(f"Querying {len(self.providers)} providers for IP: {ip}")
- if not self.providers:
- print("No providers available")
+ if not self.providers or self.stop_event.is_set():
return
# Query providers concurrently
- with ThreadPoolExecutor(max_workers=len(self.providers)) as executor:
- # Submit queries for all providers
+ with ThreadPoolExecutor(max_workers=len(self.providers)) as provider_executor:
future_to_provider = {
- executor.submit(self._safe_provider_query_ip, provider, ip): provider
+ provider_executor.submit(self._safe_provider_query_ip, provider, ip): provider
for provider in self.providers
}
# Collect results as they complete
for future in as_completed(future_to_provider):
- if self.stop_requested:
- break
+ if self.stop_event.is_set():
+ future.cancel()
+ continue
provider = future_to_provider[future]
try:
relationships = future.result()
print(f"Provider {provider.get_name()} returned {len(relationships)} relationships for IP {ip}")
for source, target, rel_type, confidence, raw_data in relationships:
# Determine node type based on target
if self._is_valid_domain(target):
target_node_type = NodeType.DOMAIN
elif target.startswith('AS'):
target_node_type = NodeType.ASN
else:
target_node_type = NodeType.IP
- # Add nodes and relationship to graph
self.graph.add_node(source, NodeType.IP)
self.graph.add_node(target, target_node_type)
- success = self.graph.add_edge(
- source, target, rel_type, confidence,
- provider.get_name(), raw_data
- )
- if success:
+ if self.graph.add_edge(source, target, rel_type, confidence, provider.get_name(), raw_data):
print(f"Added IP relationship: {source} -> {target} ({rel_type.relationship_name})")
- except Exception as e:
+ except (Exception, CancelledError) as e:
print(f"Provider {provider.get_name()} failed for IP {ip}: {e}")
def _safe_provider_query_domain(self, provider, domain: str):
"""Safely query provider for domain with error handling."""
+ if self.stop_event.is_set():
+ return []
try:
return provider.query_domain(domain)
except Exception as e:
@@ -488,6 +432,8 @@ class Scanner:
def _safe_provider_query_ip(self, provider, ip: str):
"""Safely query provider for IP with error handling."""
+ if self.stop_event.is_set():
+ return []
try:
return provider.query_ip(ip)
except Exception as e:
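The stop_event guard added at the top of both _safe_provider_query_* helpers is what makes a stop request take effect for queued work: a pool task that has already started cannot be cancelled, but one that starts afterwards can return before touching the network. A condensed sketch (the except body is an assumption, since the original is truncated in this hunk):

import threading

stop_event = threading.Event()

def safe_query(provider, ip):
    # Tasks dequeued after a stop request return immediately instead of
    # issuing a network call; already-running tasks cannot be cancelled.
    if stop_event.is_set():
        return []
    try:
        return provider.query_ip(ip)
    except Exception as e:
        print(f"Provider query failed for {ip}: {e}")
        return []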
@@ -497,13 +443,10 @@ class Scanner:
def stop_scan(self) -> bool:
"""
Request scan termination.
- Returns:
- bool: True if stop request was accepted
"""
try:
if self.status == ScanStatus.RUNNING:
- self.stop_requested = True
+ self.stop_event.set()
print("Scan stop requested")
return True
print("No active scan to stop")
@@ -516,9 +459,6 @@ class Scanner:
def get_scan_status(self) -> Dict[str, Any]:
"""
Get current scan status and progress.
- Returns:
- Dictionary containing scan status information
"""
try:
return {
@@ -558,31 +498,18 @@ class Scanner:
def get_graph_data(self) -> Dict[str, Any]:
"""
Get current graph data for visualization.
- Returns:
- Graph data formatted for frontend
"""
return self.graph.get_graph_data()
def export_results(self) -> Dict[str, Any]:
"""
Export complete scan results including graph and audit trail.
- Returns:
- Dictionary containing complete scan results
"""
- # Get graph data
graph_data = self.graph.export_json()
- # Get forensic audit trail
audit_trail = self.logger.export_audit_trail()
- # Get provider statistics
provider_stats = {}
for provider in self.providers:
provider_stats[provider.get_name()] = provider.get_statistics()
- # Combine all results
export_data = {
'scan_metadata': {
'target_domain': self.current_target,
@@ -596,18 +523,11 @@ class Scanner:
'provider_statistics': provider_stats,
'scan_summary': self.logger.get_forensic_summary()
}
return export_data
def remove_provider(self, provider_name: str) -> bool:
"""
Remove a provider from the scanner.
- Args:
- provider_name: Name of provider to remove
- Returns:
- bool: True if provider was removed
"""
for i, provider in enumerate(self.providers):
if provider.get_name() == provider_name:
@@ -618,63 +538,41 @@ class Scanner:
def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
"""
Get statistics for all providers.
- Returns:
- Dictionary mapping provider names to their statistics
"""
stats = {}
for provider in self.providers:
stats[provider.get_name()] = provider.get_statistics()
return stats
def _is_valid_domain(self, domain: str) -> bool:
"""
Basic domain validation.
- Args:
- domain: Domain string to validate
- Returns:
- True if domain appears valid
"""
if not domain or len(domain) > 253:
return False
- # Check for valid characters and structure
parts = domain.split('.')
if len(parts) < 2:
return False
for part in parts:
if not part or len(part) > 63:
return False
if not part.replace('-', '').replace('_', '').isalnum():
return False
return True
def _is_valid_ip(self, ip: str) -> bool:
"""
Basic IP address validation.
- Args:
- ip: IP address string to validate
- Returns:
- True if IP appears valid
"""
try:
parts = ip.split('.')
if len(parts) != 4:
return False
for part in parts:
num = int(part)
if not 0 <= num <= 255:
return False
return True
except (ValueError, AttributeError):
return False
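The hand-rolled _is_valid_ip above validates dotted-quad IPv4 only, and because int() tolerates leading zeros it also accepts strings like "01.2.3.4". For comparison, the standard library's ipaddress module covers the same ground plus IPv6, and rejects leading-zero octets on current Python versions (3.9.5+):

import ipaddress

def is_valid_ip(ip: str) -> bool:
    """Validate IPv4 or IPv6 using the standard library instead of manual parsing."""
    try:
        ipaddress.ip_address(ip)
        return True
    except ValueError:
        return False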