overcuriousity 2025-09-12 10:08:03 +02:00
parent df4e1703c4
commit f445187025
2 changed files with 390 additions and 272 deletions


@@ -1,6 +1,7 @@
"""
Main scanning orchestrator for DNSRecon.
Coordinates data gathering from multiple providers and builds the infrastructure graph.
REFACTORED: Simplified recursion with forensic provider state tracking.
"""
import threading
@@ -8,6 +9,7 @@ import traceback
from typing import List, Set, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError
from collections import defaultdict
from datetime import datetime, timezone

from core.graph_manager import GraphManager, NodeType, RelationshipType
from core.logger import get_forensic_logger, new_session
@@ -15,7 +17,6 @@ from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.crtsh_provider import CrtShProvider
from providers.dns_provider import DNSProvider
from providers.shodan_provider import ShodanProvider
class ScanStatus:
@@ -30,7 +31,7 @@ class ScanStatus:
class Scanner:
    """
    Main scanning orchestrator for DNSRecon passive reconnaissance.
    REFACTORED: Simplified recursion with forensic provider state tracking.
    """

    def __init__(self, session_config=None):
@@ -62,6 +63,14 @@ class Scanner:
        self.max_workers = self.config.max_concurrent_requests
        self.executor = None

        # Provider eligibility mapping: which indicator types each provider accepts
        self.provider_eligibility = {
            'dns': {'domains': True, 'ips': True},
            'crtsh': {'domains': True, 'ips': False},
            'shodan': {'domains': True, 'ips': True},
            'virustotal': {'domains': False, 'ips': False}  # disabled as requested
        }

        # Initialize providers with session config
        print("Calling _initialize_providers with session config...")
        self._initialize_providers()
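The eligibility map above is the single gate deciding which providers run against a given indicator; `_get_eligible_providers` further down reads it. A minimal standalone sketch of that gate (names here are illustrative, not part of the commit):

# Sketch of the eligibility gate, assuming the mapping shown above
eligibility = {
    'dns': {'domains': True, 'ips': True},
    'crtsh': {'domains': True, 'ips': False},
}

def is_eligible(provider_name: str, is_ip: bool) -> bool:
    key = 'ips' if is_ip else 'domains'
    return eligibility.get(provider_name, {}).get(key, False)

assert is_eligible('crtsh', is_ip=False)      # crt.sh handles domains
assert not is_eligible('crtsh', is_ip=True)   # ...but never IPs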
@@ -80,22 +89,21 @@ class Scanner:
    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration."""
        self.providers = []
        print("Initializing providers with session config...")

        # Provider classes mapping
        provider_classes = {
            'dns': DNSProvider,
            'crtsh': CrtShProvider,
            'shodan': ShodanProvider,
            # virustotal is skipped as requested
        }

        for provider_name, provider_class in provider_classes.items():
            if self.config.is_provider_enabled(provider_name):
                try:
                    provider = provider_class(session_config=self.config)
                    if provider.is_available():
                        provider.set_stop_event(self.stop_event)
                        self.providers.append(provider)
                        print(f"{provider_name.title()} provider initialized successfully for session")
@@ -105,70 +113,38 @@ class Scanner:
                except Exception as e:
                    print(f"✗ Failed to initialize {provider_name.title()} provider: {e}")
                    traceback.print_exc()

        print(f"Initialized {len(self.providers)} providers for session")

    def update_session_config(self, new_config) -> None:
        """Update session configuration and reinitialize providers."""
        print("Updating session configuration...")
        self.config = new_config
        self.max_workers = self.config.max_concurrent_requests
        self._initialize_providers()
        print("Session configuration updated")

    def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool:
        """Start a new reconnaissance scan with forensic tracking."""
        print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
        print(f"Initial scanner status: {self.status}")

        # Clean up the previous scan thread if needed
        if self.scan_thread and self.scan_thread.is_alive():
            print("A previous scan thread is still alive. Sending termination signal and waiting...")
            self.stop_scan()
            self.scan_thread.join(10.0)
            if self.scan_thread.is_alive():
                print("ERROR: The previous scan thread is unresponsive and could not be stopped.")
                self.status = ScanStatus.FAILED
                return False
            print("Previous scan thread terminated successfully.")

        # Reset state for the new scan
        self.status = ScanStatus.IDLE
        print("Scanner state is now clean for a new scan.")

        try:
            if not hasattr(self, 'providers') or not self.providers:
                print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
                return False
@@ -208,16 +184,13 @@ class Scanner:
            return False

    def _execute_scan(self, target_domain: str, max_depth: int) -> None:
        """Execute the reconnaissance scan with simplified recursion and forensic tracking."""
        print(f"_execute_scan started for {target_domain} with depth {max_depth}")
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)

        # Initialize before the try block so the summary code can reference it even on early failure
        processed_targets = set()

        try:
            print("Setting status to RUNNING")
            self.status = ScanStatus.RUNNING
@@ -227,15 +200,16 @@ class Scanner:
            self.logger.log_scan_start(target_domain, max_depth, enabled_providers)
            print(f"Logged scan start with providers: {enabled_providers}")

            # Initialize with the target domain and track it
            print(f"Adding target domain '{target_domain}' as initial node")
            self.graph.add_node(target_domain, NodeType.DOMAIN)
            self._initialize_provider_states(target_domain)

            # BFS-style exploration with simplified recursion
            current_level_targets = {target_domain}
            all_discovered_targets = set()  # track every discovered target for large-entity checks

            print("Starting BFS exploration with simplified recursion...")

            for depth in range(max_depth + 1):
                if self.stop_event.is_set():
@@ -251,14 +225,20 @@ class Scanner:
                self.total_indicators_found += len(current_level_targets)

                # Process targets and collect newly discovered ones
                target_results = self._process_targets_concurrent_forensic(
                    current_level_targets, processed_targets, all_discovered_targets, depth
                )

                next_level_targets = set()
                for target, new_targets in target_results:
                    processed_targets.add(target)
                    all_discovered_targets.update(new_targets)

                    # Simple recursion rule: only valid IPs and domains within the depth limit
                    if depth < max_depth:
                        for new_target in new_targets:
                            if self._should_recurse_on_target(new_target, processed_targets, all_discovered_targets):
                                next_level_targets.add(new_target)

                current_level_targets = next_level_targets
@@ -286,18 +266,58 @@ class Scanner:
            print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
            print(f" - Targets processed: {len(processed_targets)}")

    def _initialize_provider_states(self, target: str) -> None:
        """Initialize provider states for forensic tracking."""
        # GraphManager wraps a networkx graph, so node access goes through .graph
        if not self.graph.graph.has_node(target):
            return

        node_data = self.graph.graph.nodes[target]
        if 'metadata' not in node_data:
            node_data['metadata'] = {}
        if 'provider_states' not in node_data['metadata']:
            node_data['metadata']['provider_states'] = {}

    def _should_recurse_on_target(self, target: str, processed_targets: Set[str], all_discovered: Set[str]) -> bool:
        """
        Simplified recursion logic: only recurse on valid IPs and domains that haven't been processed.
        FORENSIC PRINCIPLE: clear, simple rules for what gets recursed.
        """
        # Don't recurse on already-processed targets
        if target in processed_targets:
            return False

        # Only recurse on valid IPs and domains
        if not (_is_valid_ip(target) or _is_valid_domain(target)):
            return False

        # Don't recurse on targets contained in large entities
        if self._is_in_large_entity(target):
            return False

        return True

    def _is_in_large_entity(self, target: str) -> bool:
        """Check if a target is contained within a large entity node."""
        for node_id, node_data in self.graph.graph.nodes(data=True):
            if node_data.get('type') == NodeType.LARGE_ENTITY.value:
                metadata = node_data.get('metadata', {})
                contained_nodes = metadata.get('nodes', [])
                if target in contained_nodes:
                    return True
        return False
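For intuition, the gate above behaves like this on a few sample inputs (hypothetical calls against an initialized Scanner instance named scanner):

# Hypothetical sanity checks for the recursion gate
scanner._should_recurse_on_target('api.example.com', set(), set())
# -> True: valid domain, not yet processed, not inside a large entity
scanner._should_recurse_on_target('api.example.com', {'api.example.com'}, set())
# -> False: already processed
scanner._should_recurse_on_target('not a domain!', set(), set())
# -> False: fails both _is_valid_ip and _is_valid_domain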
    def _process_targets_concurrent_forensic(self, targets: Set[str], processed_targets: Set[str],
                                             all_discovered: Set[str], current_depth: int) -> List[Tuple[str, Set[str]]]:
        """Process multiple targets concurrently with forensic provider state tracking."""
        results = []
        targets_to_process = targets - processed_targets

        if not targets_to_process:
            return results

        print(f"Processing {len(targets_to_process)} targets concurrently with forensic tracking")

        future_to_target = {
            self.executor.submit(self._query_providers_forensic, target, current_depth): target
            for target in targets_to_process
        }
@@ -313,29 +333,37 @@ class Scanner:
                print(f"Completed processing target: {target} (found {len(new_targets)} new targets)")
            except (Exception, CancelledError) as e:
                print(f"Error processing target {target}: {e}")
                self._log_target_processing_error(target, str(e))

        return results

    def _query_providers_forensic(self, target: str, current_depth: int) -> Set[str]:
        """
        Query providers for a target with forensic state tracking and simplified recursion.
        REFACTORED: Simplified logic with complete forensic audit trail.
        """
        is_ip = _is_valid_ip(target)
        target_type = NodeType.IP if is_ip else NodeType.DOMAIN
        print(f"Querying providers for {target_type.value}: {target} at depth {current_depth}")

        # Initialize the node and its provider states
        self.graph.add_node(target, target_type)
        self._initialize_provider_states(target)

        new_targets = set()
        target_metadata = defaultdict(lambda: defaultdict(list))

        # Determine eligible providers for this target
        eligible_providers = self._get_eligible_providers(target, is_ip)
        if not eligible_providers:
            self._log_no_eligible_providers(target, is_ip)
            return new_targets

        # Query each eligible provider with forensic tracking
        with ThreadPoolExecutor(max_workers=len(eligible_providers)) as provider_executor:
            future_to_provider = {
                provider_executor.submit(self._query_single_provider_forensic, provider, target, is_ip, current_depth): provider
                for provider in eligible_providers
            }

            for future in as_completed(future_to_provider):
                provider = future_to_provider[future]
                try:
                    provider_results = future.result()
                    if provider_results:
                        discovered_targets = self._process_provider_results_forensic(
                            target, provider, provider_results, target_metadata, current_depth
                        )
                        new_targets.update(discovered_targets)
                except (Exception, CancelledError) as e:
                    self._log_provider_error(target, provider.get_name(), str(e))

        # Update the node with collected metadata
        if target_metadata[target]:
            self.graph.add_node(target, target_type, metadata=dict(target_metadata[target]))

        return new_targets

    def _get_eligible_providers(self, target: str, is_ip: bool) -> List:
        """Get providers eligible for querying this target."""
        eligible = []
        target_key = 'ips' if is_ip else 'domains'

        for provider in self.providers:
            provider_name = provider.get_name()
            if provider_name in self.provider_eligibility:
                if self.provider_eligibility[provider_name][target_key]:
                    # Skip if we already queried this provider for this target
                    if not self._already_queried_provider(target, provider_name):
                        eligible.append(provider)
                    else:
                        print(f"Skipping {provider_name} for {target} - already queried")

        return eligible
    def _already_queried_provider(self, target: str, provider_name: str) -> bool:
        """Check if we already queried a provider for a target."""
        # GraphManager wraps a networkx graph, so membership checks go through .graph
        if not self.graph.graph.has_node(target):
            return False

        node_data = self.graph.graph.nodes[target]
        provider_states = node_data.get('metadata', {}).get('provider_states', {})
        return provider_name in provider_states
    def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> List:
        """Query a single provider with complete forensic logging."""
        provider_name = provider.get_name()
        start_time = datetime.now(timezone.utc)

        print(f"Querying {provider_name} for {target}")

        # Log the attempt
        self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}")

        try:
            # Perform the query
            if is_ip:
                results = provider.query_ip(target)
            else:
                results = provider.query_domain(target)

            # Record the successful state
            self._update_provider_state(target, provider_name, 'success', len(results), None, start_time)

            print(f"{provider_name} returned {len(results)} results for {target}")
            return results

        except Exception as e:
            # Record the failed state
            self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
            print(f"{provider_name} failed for {target}: {e}")
            raise
    def _update_provider_state(self, target: str, provider_name: str, status: str,
                               results_count: int, error: str, start_time: datetime) -> None:
        """Update provider state in node metadata for forensic tracking."""
        if not self.graph.graph.has_node(target):
            return

        node_data = self.graph.graph.nodes[target]
        if 'metadata' not in node_data:
            node_data['metadata'] = {}
        if 'provider_states' not in node_data['metadata']:
            node_data['metadata']['provider_states'] = {}

        node_data['metadata']['provider_states'][provider_name] = {
            'status': status,
            'timestamp': start_time.isoformat(),
            'results_count': results_count,
            'error': error,
            'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
        }

        # Log to the forensic trail
        self.logger.logger.info(f"Provider state updated: {target} -> {provider_name} -> {status} ({results_count} results)")
    def _process_provider_results_forensic(self, target: str, provider, results: List,
                                           target_metadata: Dict, current_depth: int) -> Set[str]:
        """Process provider results with large entity detection and forensic logging."""
        provider_name = provider.get_name()
        discovered_targets = set()

        # Check the large-entity threshold per provider
        if len(results) > self.config.large_entity_threshold:
            print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}")
            self._create_large_entity(target, provider_name, results, current_depth)
            # Large entities block recursion - return an empty set
            return discovered_targets

        # Process each relationship
        dns_records_to_create = {}
        for source, rel_target, rel_type, confidence, raw_data in results:
            if self.stop_event.is_set():
                break

            # Enhanced forensic logging for each relationship
            self.logger.log_relationship_discovery(
                source_node=source,
                target_node=rel_target,
                relationship_type=rel_type.relationship_name,
                confidence_score=confidence,
                provider=provider_name,
                raw_data=raw_data,
                discovery_method=f"{provider_name}_query_depth_{current_depth}"
            )

            # Collect metadata for the source node
            self._collect_node_metadata_forensic(source, provider_name, rel_type, rel_target, raw_data, target_metadata[source])

            # Add nodes and edges based on target type
            if _is_valid_ip(rel_target):
                self.graph.add_node(rel_target, NodeType.IP)
                if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added IP relationship: {source} -> {rel_target} ({rel_type.relationship_name})")
                discovered_targets.add(rel_target)

            elif rel_target.startswith('AS') and rel_target[2:].isdigit():
                self.graph.add_node(rel_target, NodeType.ASN)
@@ -409,71 +516,71 @@ class Scanner:
            elif _is_valid_domain(rel_target):
                self.graph.add_node(rel_target, NodeType.DOMAIN)
                if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                    print(f"Added domain relationship: {source} -> {rel_target} ({rel_type.relationship_name})")
                discovered_targets.add(rel_target)

                # Enrich the newly discovered domain as well
                self._collect_node_metadata_forensic(rel_target, provider_name, rel_type, source, raw_data, target_metadata[rel_target])

            else:
                # Handle DNS record content
                self._handle_dns_record_content(source, rel_target, rel_type, confidence, raw_data, provider_name, dns_records_to_create)

        # Create DNS record nodes
        self._create_dns_record_nodes(dns_records_to_create)

        return discovered_targets
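Each element of `results` unpacked in the loop above is a 5-tuple of (source, target, relationship type, confidence, raw data); a DNS A-record relationship, for example, would look roughly like this (confidence value illustrative):

# Illustrative provider result tuple consumed by _process_provider_results_forensic
('example.com', '93.184.216.34', RelationshipType.A_RECORD, 0.8,
 {'query_type': 'A', 'value': '93.184.216.34'})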
    def _create_large_entity(self, source: str, provider_name: str, results: List, current_depth: int) -> None:
        """Create a large entity node for forensic tracking."""
        entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}"

        # Extract targets from results
        targets = [rel[1] for rel in results if len(rel) > 1]

        # Determine node type
        node_type = 'unknown'
        if targets:
            if _is_valid_domain(targets[0]):
                node_type = 'domain'
            elif _is_valid_ip(targets[0]):
                node_type = 'ip'

        # Create large entity metadata
        metadata = {
            'count': len(targets),
            'nodes': targets,
            'node_type': node_type,
            'source_provider': provider_name,
            'discovery_depth': current_depth,
            'threshold_exceeded': self.config.large_entity_threshold,
            'forensic_note': f'Large entity created due to {len(targets)} results from {provider_name}'
        }

        # Create the node and edge
        self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, metadata=metadata)

        # Use the first result's relationship type for the edge
        if results:
            rel_type = results[0][2]
            self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
                                {'large_entity_info': f'Contains {len(targets)} {node_type}s'})

        # Forensic logging
        self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")
        print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}")

    def _collect_node_metadata_forensic(self, node_id: str, provider_name: str, rel_type: RelationshipType,
                                        target: str, raw_data: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        """Collect and organize metadata for forensic tracking with enhanced logging."""
        # Log metadata collection
        self.logger.logger.debug(f"Collecting metadata for {node_id} from {provider_name}: {rel_type.relationship_name}")

        if provider_name == 'dns':
            record_type = raw_data.get('query_type', 'UNKNOWN')
            value = raw_data.get('value', target)
            dns_entry = f"{record_type}: {value}"
            if dns_entry not in metadata.get('dns_records', []):
                metadata.setdefault('dns_records', []).append(dns_entry)
@@ -486,24 +593,13 @@ class Scanner:
                metadata['has_valid_cert'] = cert_summary.get('has_valid_cert', False)
                if target not in metadata.get('related_domains_san', []):
                    metadata.setdefault('related_domains_san', []).append(target)

        elif provider_name == 'shodan':
            for key, value in raw_data.items():
                if key not in metadata.get('shodan', {}) or not metadata.get('shodan', {}).get(key):
                    metadata.setdefault('shodan', {})[key] = value

        # Track ASN data
        if rel_type == RelationshipType.ASN_MEMBERSHIP:
            metadata['asn_data'] = {
                'asn': target,
@@ -512,48 +608,82 @@ class Scanner:
                'country': raw_data.get('country', '')
            }
    def _handle_dns_record_content(self, source: str, rel_target: str, rel_type: RelationshipType,
                                   confidence: float, raw_data: Dict[str, Any], provider_name: str,
                                   dns_records: Dict) -> None:
        """Handle DNS record content with forensic tracking."""
        dns_record_types = [
            RelationshipType.TXT_RECORD, RelationshipType.SPF_RECORD,
            RelationshipType.CAA_RECORD, RelationshipType.SRV_RECORD,
            RelationshipType.DNSKEY_RECORD, RelationshipType.DS_RECORD,
            RelationshipType.RRSIG_RECORD, RelationshipType.SSHFP_RECORD,
            RelationshipType.TLSA_RECORD, RelationshipType.NAPTR_RECORD
        ]

        if rel_type in dns_record_types:
            record_type = rel_type.relationship_name.upper().replace('_RECORD', '')
            record_content = rel_target.strip()
            content_hash = hash(record_content) & 0x7FFFFFFF
            dns_record_id = f"{record_type}:{content_hash}"

            if dns_record_id not in dns_records:
                dns_records[dns_record_id] = {
                    'content': record_content,
                    'type': record_type,
                    'domains': set(),
                    'raw_data': raw_data,
                    'provider_name': provider_name,
                    'confidence': confidence
                }

            dns_records[dns_record_id]['domains'].add(source)
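One caveat worth flagging on the `TYPE:<hash>` ids built above (and the `large_entity_*` ids earlier): Python salts str hashes per interpreter process, so these ids are only stable within a single run unless PYTHONHASHSEED is pinned. A run-stable alternative would derive the id from a digest; a sketch as a suggestion, not committed code:

import hashlib

def stable_dns_record_id(record_type: str, record_content: str) -> str:
    # sha256 is deterministic across processes, unlike the built-in hash()
    digest = hashlib.sha256(record_content.encode('utf-8')).hexdigest()[:12]
    return f"{record_type}:{digest}"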
    def _create_dns_record_nodes(self, dns_records: Dict) -> None:
        """Create DNS record nodes with forensic metadata."""
        for dns_record_id, record_info in dns_records.items():
            record_metadata = {
                'record_type': record_info['type'],
                'content': record_info['content'],
                'content_hash': dns_record_id.split(':')[1],
                'associated_domains': list(record_info['domains']),
                'source_data': record_info['raw_data'],
                'forensic_note': f"DNS record created from {record_info['provider_name']} query"
            }

            self.graph.add_node(dns_record_id, NodeType.DNS_RECORD, metadata=record_metadata)

            for domain_name in record_info['domains']:
                self.graph.add_edge(domain_name, dns_record_id, RelationshipType.DNS_RECORD,
                                    record_info['confidence'], record_info['provider_name'],
                                    record_info['raw_data'])

            # Forensic logging for DNS record creation
            self.logger.logger.info(f"DNS record node created: {dns_record_id} for {len(record_info['domains'])} domains")

    def _log_target_processing_error(self, target: str, error: str) -> None:
        """Log target processing errors for the forensic trail."""
        self.logger.logger.error(f"Target processing failed for {target}: {error}")

    def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
        """Log provider query errors for the forensic trail."""
        self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")

    def _log_no_eligible_providers(self, target: str, is_ip: bool) -> None:
        """Log when no providers are eligible for a target."""
        target_type = 'IP' if is_ip else 'domain'
        self.logger.logger.warning(f"No eligible providers for {target_type}: {target}")
    def stop_scan(self) -> bool:
        """Request immediate scan termination with forensic logging."""
        try:
            if not self.scan_thread or not self.scan_thread.is_alive():
                print("No active scan thread to stop.")
                # Clean up state if it is inconsistent
                if self.status == ScanStatus.RUNNING:
                    self.status = ScanStatus.STOPPED
                return False

            print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
            self.logger.logger.info("Scan termination requested by user")
            self.status = ScanStatus.STOPPED
            self.stop_event.set()
@@ -563,22 +693,15 @@ class Scanner:
            print("Termination signal sent. The scan thread will stop shortly.")
            return True

        except Exception as e:
            print(f"ERROR: Exception in stop_scan: {e}")
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False

    def get_scan_status(self) -> Dict[str, Any]:
        """Get current scan status with forensic information."""
        try:
            return {
                'status': self.status,
@@ -615,27 +738,25 @@ class Scanner:
            return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100)

    def get_graph_data(self) -> Dict[str, Any]:
        """Get current graph data for visualization."""
        return self.graph.get_graph_data()

    def export_results(self) -> Dict[str, Any]:
        """Export complete scan results with forensic audit trail."""
        graph_data = self.graph.export_json()
        audit_trail = self.logger.export_audit_trail()

        provider_stats = {}
        for provider in self.providers:
            provider_stats[provider.get_name()] = provider.get_statistics()

        export_data = {
            'scan_metadata': {
                'target_domain': self.current_target,
                'max_depth': self.max_depth,
                'final_status': self.status,
                'total_indicators_processed': self.indicators_processed,
                'enabled_providers': list(provider_stats.keys()),
                'forensic_note': 'Refactored scanner with simplified recursion and forensic tracking'
            },
            'graph_data': graph_data,
            'forensic_audit': audit_trail,
@@ -645,9 +766,7 @@ class Scanner:
        return export_data

    def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Get statistics for all providers with forensic information."""
        stats = {}
        for provider in self.providers:
            stats[provider.get_name()] = provider.get_statistics()


@@ -478,34 +478,28 @@ class CrtShProvider(BaseProvider):
        common_name = cert_data.get('common_name', '')
        if common_name:
            cleaned_cn = self._clean_domain_name(common_name)
            if cleaned_cn:
                domains.update(cleaned_cn)

        # Extract from the name_value field (contains SANs)
        name_value = cert_data.get('name_value', '')
        if name_value:
            # Split by newlines and clean each domain
            for line in name_value.split('\n'):
                cleaned_domains = self._clean_domain_name(line.strip())
                if cleaned_domains:
                    domains.update(cleaned_domains)

        return domains

    def _clean_domain_name(self, domain_name: str) -> List[str]:
        """
        Clean and normalize a domain name from certificate data.
        Returns a list so a wildcard entry can yield both the wildcard and its base domain.
        """
        if not domain_name:
            return []

        domain = domain_name.strip().lower()

        # Remove protocol if present
@@ -521,14 +515,19 @@ class CrtShProvider(BaseProvider):
            domain = domain.split(':', 1)[0]

        # Handle wildcard domains
        cleaned_domains = []
        if domain.startswith('*.'):
            # Keep both the wildcard entry and its base domain
            cleaned_domains.append(domain)
            cleaned_domains.append(domain[2:])
        else:
            cleaned_domains.append(domain)

        # Remove any remaining invalid characters and validate
        final_domains = []
        for d in cleaned_domains:
            d = re.sub(r'[^\w\-.*]', '', d)  # '*' is kept so the wildcard form reaches validation
            if d and not d.startswith(('.', '-')) and not d.endswith(('.', '-')):
                final_domains.append(d)

        return [d for d in final_domains if _is_valid_domain(d)]