dnsrecon/core/scanner.py
2025-09-16 00:57:24 +02:00

1029 lines
47 KiB
Python

# dnsrecon-reduced/core/scanner.py
import threading
import traceback
import time
import os
import importlib
import redis
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future
from collections import defaultdict
from queue import PriorityQueue
from datetime import datetime, timezone
from core.graph_manager import GraphManager, NodeType
from core.logger import get_forensic_logger, new_session
from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.base_provider import BaseProvider
from core.rate_limiter import GlobalRateLimiter
class ScanStatus:
    """
    String constants naming the lifecycle states of a scan.

    The values are plain strings, so they are stored on the scanner and
    reported in status payloads unchanged.
    """
    # No scan started yet, or state was reset for a new scan.
    IDLE = "idle"
    # A scan thread is working through the task queue.
    RUNNING = "running"
    # The queue drained with no stop request and no retry exhaustion.
    COMPLETED = "completed"
    # The scan aborted on an error or a task ran out of retries.
    FAILED = "failed"
    # Termination was requested.
    STOPPED = "stopped"
class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
"""
def __init__(self, session_config=None):
    """
    Initialize scanner with session-specific configuration.

    Args:
        session_config: Session configuration object. When ``None`` a default
            one is created with ``core.session_config.create_session_config()``.
            ``max_concurrent_requests`` is read from it here, and it is handed
            to every provider during provider initialization.

    Raises:
        Exception: any error during initialization is printed together with
            its traceback and then re-raised.
    """
    print("Initializing Scanner instance...")
    try:
        # Use provided session config or create default
        if session_config is None:
            from core.session_config import create_session_config
            session_config = create_session_config()
        self.config = session_config
        self.graph = GraphManager()
        self.providers = []
        self.status = ScanStatus.IDLE
        self.current_target = None
        self.current_depth = 0
        self.max_depth = 2
        # Local stop flag. It must exist before _initialize_providers(), which
        # hands it to each provider via set_stop_event().
        self.stop_event = threading.Event()
        self.scan_thread = None
        self.session_id: Optional[str] = None  # Will be set by session manager
        # Items are (priority, (provider_name, target, depth)); lowest priority first.
        self.task_queue = PriorityQueue()
        # Retry counters keyed by (provider_name, target).
        self.target_retries = defaultdict(int)
        self.scan_failed_due_to_retries = False
        # Track currently processing targets so stop_scan() can report and clear them.
        self.currently_processing = set()
        self.processing_lock = threading.Lock()
        # Scanning progress tracking
        self.total_indicators_found = 0
        self.indicators_processed = 0
        self.indicators_completed = 0
        self.tasks_re_enqueued = 0
        self.total_tasks_ever_enqueued = 0
        self.current_indicator = ""
        # Concurrent processing configuration
        self.max_workers = self.config.max_concurrent_requests
        self.executor = None
        # Initialize providers with session config
        print("Calling _initialize_providers with session config...")
        self._initialize_providers()
        # Initialize logger
        print("Initializing forensic logger...")
        self.logger = get_forensic_logger()
        # Initialize global rate limiter (Redis-backed, database 0)
        self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
        print("Scanner initialization complete")
    except Exception as e:
        print(f"ERROR: Scanner initialization failed: {e}")
        traceback.print_exc()
        raise
def _is_stop_requested(self) -> bool:
"""
Check if stop is requested using both local and Redis-based signals.
This ensures reliable termination across process boundaries.
"""
# Check local threading event first (fastest)
if self.stop_event.is_set():
return True
# Check Redis-based stop signal if session ID is available
if self.session_id:
try:
from core.session_manager import session_manager
return session_manager.is_stop_requested(self.session_id)
except Exception as e:
print(f"Error checking Redis stop signal: {e}")
# Fall back to local event
return self.stop_event.is_set()
return False
def _set_stop_signal(self) -> None:
"""
Set stop signal both locally and in Redis.
"""
# Set local event
self.stop_event.set()
# Set Redis signal if session ID is available
if self.session_id:
try:
from core.session_manager import session_manager
session_manager.set_stop_signal(self.session_id)
except Exception as e:
print(f"Error setting Redis stop signal: {e}")
def __getstate__(self):
"""Prepare object for pickling by excluding unpicklable attributes."""
state = self.__dict__.copy()
# Remove unpicklable threading objects
unpicklable_attrs = [
'stop_event',
'scan_thread',
'executor',
'processing_lock',
'task_queue',
'rate_limiter',
'logger'
]
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
# Handle providers separately to ensure they're picklable
if 'providers' in state:
for provider in state['providers']:
if hasattr(provider, '_stop_event'):
provider._stop_event = None
return state
def __setstate__(self, state):
    """
    Rebuild a scanner from the dict produced by ``__getstate__``.

    Saved attributes are restored first. Everything that was dropped for
    pickling is then recreated fresh: stop event, lock, empty task queue,
    rate limiter and logger, with no scan thread or executor. Providers are
    re-created if none were saved, and every provider is re-attached to the
    new stop event.
    """
    self.__dict__.update(state)

    self.stop_event = threading.Event()
    self.scan_thread = None
    self.executor = None
    self.processing_lock = threading.Lock()
    self.task_queue = PriorityQueue()
    self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
    self.logger = get_forensic_logger()

    if not getattr(self, 'providers', None):
        print("Providers not found after loading session, re-initializing...")
        self._initialize_providers()

    if not hasattr(self, 'currently_processing'):
        self.currently_processing = set()

    # Providers lost their stop event when pickled; point them at the new one.
    for provider in getattr(self, 'providers', []):
        if hasattr(provider, 'set_stop_event'):
            provider.set_stop_event(self.stop_event)
def _initialize_providers(self) -> None:
    """
    Discover, instantiate and register the enabled, available providers.

    Every ``providers/*_provider.py`` module (except ``base*``) is imported,
    and each ``BaseProvider`` subclass found in it is instantiated with its
    attribute name and this scanner's session config. A provider is kept only
    if the config enables it and ``is_available()`` returns true. Kept
    providers are wired to ``self.stop_event``. An import or construction
    error is printed, and the rest of that module is skipped.
    """
    self.providers = []
    print("Initializing providers with session config...")
    provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
    # A provider class imported into another provider module also appears in
    # that module's dir(); instantiate each class only once.
    seen_classes = set()
    # os.listdir() order is arbitrary; sort so the provider order is reproducible.
    for filename in sorted(os.listdir(provider_dir)):
        if not filename.endswith('_provider.py') or filename.startswith('base'):
            continue
        module_name = f"providers.{filename[:-3]}"
        try:
            module = importlib.import_module(module_name)
            for attribute_name in dir(module):
                attribute = getattr(module, attribute_name)
                if not (isinstance(attribute, type) and issubclass(attribute, BaseProvider)):
                    continue
                if attribute is BaseProvider or attribute in seen_classes:
                    continue
                seen_classes.add(attribute)
                provider_class = attribute
                provider = provider_class(name=attribute_name, session_config=self.config)
                provider_name = provider.get_name()
                if self.config.is_provider_enabled(provider_name):
                    if provider.is_available():
                        provider.set_stop_event(self.stop_event)
                        self.providers.append(provider)
                        print(f"{provider.get_display_name()} provider initialized successfully for session")
                    else:
                        print(f"{provider.get_display_name()} provider is not available")
        except Exception as e:
            print(f"✗ Failed to initialize provider from {filename}: {e}")
            traceback.print_exc()
    print(f"Initialized {len(self.providers)} providers for session")
def update_session_config(self, new_config) -> None:
    """
    Swap in a new session configuration.

    The worker count is re-read from the new config, and the provider list is
    rebuilt from scratch so provider enablement follows the new settings.
    """
    print("Updating session configuration...")
    self.config = new_config
    self.max_workers = new_config.max_concurrent_requests
    self._initialize_providers()
    print("Session configuration updated")
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
    """
    Start a new reconnaissance scan with proper cleanup of previous scans.

    If a previous scan thread is alive, it is signalled to stop, its queue and
    executor are discarded, and it gets up to 5 seconds to exit. Scanner state
    is then reset: stop signals, queue, retry counters and progress counters.
    Finally ``_execute_scan`` is started on a new daemon thread.

    Args:
        target: Root domain or IP; lower-cased and stripped before use.
        max_depth: Maximum discovery depth for the scan.
        clear_graph: If True, the existing graph is cleared first.
        force_rescan_target: If this node exists in the graph, its
            ``metadata['provider_states']`` is reset so that providers are
            queried for it again.

    Returns:
        True if the scan thread was started. False if no providers are
        available, or if an exception occurred (status is then FAILED).
    """
    print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
    print(f"Session ID: {self.session_id}")
    print(f"Initial scanner status: {self.status}")
    self.total_tasks_ever_enqueued = 0
    # Tear down any previous scan that is still running.
    if self.scan_thread and self.scan_thread.is_alive():
        print("A previous scan thread is still alive. Forcing termination...")
        # Set stop signals immediately
        self._set_stop_signal()
        self.status = ScanStatus.STOPPED
        # Clear all processing state
        with self.processing_lock:
            self.currently_processing.clear()
            self.task_queue = PriorityQueue()
        # Shutdown executor aggressively
        if self.executor:
            print("Shutting down executor forcefully...")
            self.executor.shutdown(wait=False, cancel_futures=True)
            self.executor = None
        # Wait for thread termination with shorter timeout
        print("Waiting for previous scan thread to terminate...")
        self.scan_thread.join(5.0)  # Bounded wait so the caller is not blocked indefinitely
        if self.scan_thread.is_alive():
            print("WARNING: Previous scan thread is still alive after 5 seconds")
            # Continue anyway, but log the issue.
            # NOTE(review): the stop event is cleared just below, so a thread
            # that outlived this join may resume its loop and consume tasks
            # alongside the new scan thread. Confirm whether that is acceptable.
            self.logger.logger.warning("Previous scan thread failed to terminate cleanly")
    # Reset state for new scan with proper forensic logging
    print("Resetting scanner state for new scan...")
    self.status = ScanStatus.IDLE
    self.stop_event.clear()
    # Also clear the session's shared (Redis) stop signal; otherwise
    # _is_stop_requested() would still report a stop from the previous scan.
    if self.session_id:
        from core.session_manager import session_manager
        session_manager.clear_stop_signal(self.session_id)
    with self.processing_lock:
        self.currently_processing.clear()
        self.task_queue = PriorityQueue()
    self.target_retries.clear()
    self.scan_failed_due_to_retries = False
    # Update session state immediately for GUI feedback
    self._update_session_state()
    print("Scanner state reset complete.")
    try:
        if not hasattr(self, 'providers') or not self.providers:
            print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
            return False
        print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")
        if clear_graph:
            self.graph.clear()
        # Dropping provider_states makes _already_queried_provider() return
        # False for this node, so its providers are eligible again.
        if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
            print(f"Forcing rescan of {force_rescan_target}, clearing provider states.")
            node_data = self.graph.graph.nodes[force_rescan_target]
            if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
                node_data['metadata']['provider_states'] = {}
        self.current_target = target.lower().strip()
        self.max_depth = max_depth
        self.current_depth = 0
        self.total_indicators_found = 0
        self.indicators_processed = 0
        self.indicators_completed = 0
        self.tasks_re_enqueued = 0
        self.current_indicator = self.current_target
        # Update GUI with scan preparation state
        self._update_session_state()
        # Start new forensic session
        print(f"Starting new forensic session for scanner {id(self)}...")
        self.logger = new_session()
        # Start scan in a separate daemon thread
        print(f"Starting scan thread for scanner {id(self)}...")
        self.scan_thread = threading.Thread(
            target=self._execute_scan,
            args=(self.current_target, max_depth),
            daemon=True
        )
        self.scan_thread.start()
        print(f"=== SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===")
        return True
    except Exception as e:
        print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
        traceback.print_exc()
        self.status = ScanStatus.FAILED
        self._update_session_state()
        return False
def _get_priority(self, provider_name):
rate_limit = self.config.get_rate_limit(provider_name)
if rate_limit > 90:
return 1 # Highest priority
elif rate_limit > 50:
return 2
else:
return 3 # Lowest priority
def _execute_scan(self, target: str, max_depth: int) -> None:
    """
    Run the reconnaissance loop for ``target`` on the calling (scan) thread.

    Queue items are ``(priority, (provider_name, target_item, depth))``. A task
    is skipped if its (provider, target) pair was already processed or if its
    depth exceeds ``max_depth``. A rate-limited task is re-queued one priority
    level lower.

    After a successful query, follow-up tasks are enqueued for every newly
    discovered target (at depth + 1) and every large-entity member (at the
    same depth). Failed queries are retried up to
    ``config.max_retries_per_target`` times.

    Final status:
        STOPPED if a stop was requested; FAILED if an unexpected exception
        aborted the scan or a task exhausted its retries; COMPLETED otherwise.

    Args:
        target: Root indicator (domain or IP) used to seed the queue.
        max_depth: Maximum task depth that will be processed.
    """
    from queue import Empty

    print(f"_execute_scan started for {target} with depth {max_depth}")
    # NOTE(review): nothing in this method submits work to the executor; it is
    # only created here and shut down in the finally block.
    self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
    processed_tasks = set()
    # Set when an unexpected exception aborts the scan, so the finally block
    # reports FAILED instead of overwriting it with COMPLETED.
    scan_errored = False
    try:
        # Seed the queue with every eligible provider for the root target. This
        # runs inside the try so that a failure here still shuts the executor
        # down and publishes a final status.
        is_ip = _is_valid_ip(target)
        initial_providers = self._get_eligible_providers(target, is_ip, False)
        for provider in initial_providers:
            provider_name = provider.get_name()
            self.task_queue.put((self._get_priority(provider_name), (provider_name, target, 0)))
            self.total_tasks_ever_enqueued += 1

        self.status = ScanStatus.RUNNING
        self._update_session_state()
        enabled_providers = [provider.get_name() for provider in self.providers]
        self.logger.log_scan_start(target, max_depth, enabled_providers)
        # Determine initial node type
        node_type = NodeType.IP if is_ip else NodeType.DOMAIN
        self.graph.add_node(target, node_type)
        self._initialize_provider_states(target)

        while not self.task_queue.empty() and not self._is_stop_requested():
            try:
                # Non-blocking: stop_scan() drains and replaces the queue from
                # another thread, so a blocking get() issued after the empty()
                # check could hang this thread forever.
                priority, (provider_name, target_item, depth) = self.task_queue.get_nowait()
            except Empty:
                break
            task_tuple = (provider_name, target_item)
            if task_tuple in processed_tasks:
                continue
            if depth > max_depth:
                continue
            if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
                # Postpone by demoting the task one priority level.
                # NOTE(review): there is no back-off; if every pending task is
                # rate-limited, this loop re-polls without sleeping.
                self.task_queue.put((priority + 1, (provider_name, target_item, depth)))
                continue
            with self.processing_lock:
                if self._is_stop_requested():
                    print(f"Stop requested before processing {target_item}")
                    break
                self.currently_processing.add(target_item)
            try:
                self.current_depth = depth
                self.current_indicator = target_item
                self._update_session_state()
                if self._is_stop_requested():
                    print(f"Stop requested during processing setup for {target_item}")
                    break
                provider = next((p for p in self.providers if p.get_name() == provider_name), None)
                if provider:
                    new_targets, large_entity_members, success = self._query_single_provider_for_target(provider, target_item, depth)
                    if self._is_stop_requested():
                        print(f"Stop requested after querying providers for {target_item}")
                        break
                    if not success:
                        self.target_retries[task_tuple] += 1
                        if self.target_retries[task_tuple] <= self.config.max_retries_per_target:
                            print(f"Re-queueing task {task_tuple} (attempt {self.target_retries[task_tuple]})")
                            self.task_queue.put((priority, (provider_name, target_item, depth)))
                            self.tasks_re_enqueued += 1
                            self.total_tasks_ever_enqueued += 1
                        else:
                            print(f"ERROR: Max retries exceeded for task {task_tuple}")
                            self.scan_failed_due_to_retries = True
                            self._log_target_processing_error(str(task_tuple), "Max retries exceeded")
                    else:
                        processed_tasks.add(task_tuple)
                        self.indicators_completed += 1
                        if not self._is_stop_requested():
                            # Large-entity members stay at the current depth;
                            # ordinary discoveries go one level deeper.
                            all_new_targets = new_targets.union(large_entity_members)
                            for new_target in all_new_targets:
                                is_ip_new = _is_valid_ip(new_target)
                                eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
                                for p_new in eligible_providers_new:
                                    p_name_new = p_new.get_name()
                                    if (p_name_new, new_target) not in processed_tasks:
                                        new_depth = depth + 1 if new_target in new_targets else depth
                                        self.task_queue.put((self._get_priority(p_name_new), (p_name_new, new_target, new_depth)))
                                        self.total_tasks_ever_enqueued += 1
            finally:
                with self.processing_lock:
                    self.currently_processing.discard(target_item)

        if self._is_stop_requested():
            print("Scan terminated due to stop request")
            self.logger.logger.info("Scan terminated by user request")
        elif self.task_queue.empty():
            print("Scan completed - no more targets to process")
            self.logger.logger.info("Scan completed - all targets processed")
    except Exception as e:
        scan_errored = True
        print(f"ERROR: Scan execution failed with error: {e}")
        traceback.print_exc()
        self.status = ScanStatus.FAILED
        self.logger.logger.error(f"Scan failed: {e}")
    finally:
        with self.processing_lock:
            self.currently_processing.clear()
        if self._is_stop_requested():
            self.status = ScanStatus.STOPPED
        elif scan_errored or self.scan_failed_due_to_retries:
            self.status = ScanStatus.FAILED
        else:
            self.status = ScanStatus.COMPLETED
        self._update_session_state()
        self.logger.log_scan_complete()
        if self.executor:
            self.executor.shutdown(wait=False, cancel_futures=True)
            self.executor = None
        stats = self.graph.get_statistics()
        print("Final scan statistics:")
        print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
        print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
        print(f" - Tasks processed: {len(processed_tasks)}")
def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
    """
    Query one provider for one target and fold the results into the graph.

    The target node is added (typed IP or DOMAIN) before querying. Attributes
    collected while processing the results are written back to every node
    that exists in the graph, unless a stop was requested.

    Returns:
        ``(new_targets, large_entity_members, success)``:
        - targets discovered through ordinary relationships;
        - members of a large entity created for an oversized result set;
        - whether the query succeeded. ``success`` is False when a stop was
          requested before querying, when the query returned ``None``, or
          when an exception was raised (that exception is logged).
    """
    if self._is_stop_requested():
        print(f"Stop requested before querying {provider.get_name()} for {target}")
        return set(), set(), False
    is_ip = _is_valid_ip(target)
    target_type = NodeType.IP if is_ip else NodeType.DOMAIN
    print(f"Querying {provider.get_name()} for {target_type.value}: {target} at depth {depth}")
    self.graph.add_node(target, target_type)
    self._initialize_provider_states(target)
    new_targets = set()
    large_entity_members = set()
    # node_id -> attribute dict; filled by _process_provider_results and
    # written back to the graph below.
    node_attributes = defaultdict(lambda: defaultdict(list))
    provider_successful = True
    try:
        provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth)
        if provider_results is None:
            provider_successful = False
        elif not self._is_stop_requested():
            discovered, is_large_entity = self._process_provider_results(
                target, provider, provider_results, node_attributes, depth
            )
            if is_large_entity:
                large_entity_members.update(discovered)
            else:
                new_targets.update(discovered)
        else:
            print(f"Stop requested after processing results from {provider.get_name()}")
    except Exception as e:
        provider_successful = False
        self._log_provider_error(target, provider.get_name(), str(e))
    if not self._is_stop_requested():
        for node_id, attributes in node_attributes.items():
            if not self.graph.graph.has_node(node_id):
                continue
            # Classify exactly as _process_provider_results does. Otherwise ASN
            # nodes, which receive attributes for asn_membership relationships,
            # would be re-added as NodeType.DOMAIN.
            if _is_valid_ip(node_id):
                node_type_to_add = NodeType.IP
            elif isinstance(node_id, str) and node_id.startswith('AS') and node_id[2:].isdigit():
                node_type_to_add = NodeType.ASN
            else:
                node_type_to_add = NodeType.DOMAIN
            self.graph.add_node(node_id, node_type_to_add, attributes=attributes)
    return new_targets, large_entity_members, provider_successful
def stop_scan(self) -> bool:
    """
    Request immediate scan termination with proper cleanup.

    Sets the local and session-wide stop signals, marks the scan STOPPED,
    clears the in-flight target set, discards every pending task, cancels
    the executor's queued futures and publishes the new state. The scan
    thread notices the stop at its next check.

    Returns:
        True if the termination signals were sent; False if an unexpected
        error occurred (it is printed and logged).
    """
    from queue import Empty

    try:
        print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
        self.logger.logger.info("Scan termination requested by user")
        self._set_stop_signal()
        self.status = ScanStatus.STOPPED
        # Clear processing state immediately
        with self.processing_lock:
            currently_processing_copy = self.currently_processing.copy()
            self.currently_processing.clear()
        print(f"Cleared {len(currently_processing_copy)} currently processing targets: {currently_processing_copy}")
        # Drain without blocking. The scan thread may take the last item between
        # an empty() check and get(); a blocking get() would then hang this caller.
        discarded_tasks = []
        pending = self.task_queue
        while True:
            try:
                discarded_tasks.append(pending.get_nowait())
            except Empty:
                break
        self.task_queue = PriorityQueue()
        print(f"Discarded {len(discarded_tasks)} pending tasks")
        # Snapshot the executor: the scan thread's finally block may set
        # self.executor to None concurrently.
        executor = self.executor
        if executor:
            print("Shutting down executor with immediate cancellation...")
            try:
                # Cancel all pending futures
                executor.shutdown(wait=False, cancel_futures=True)
                print("Executor shutdown completed")
            except Exception as e:
                print(f"Error during executor shutdown: {e}")
        # Immediately update GUI with stopped status
        self._update_session_state()
        print("Termination signals sent. The scan will stop as soon as possible.")
        return True
    except Exception as e:
        print(f"ERROR: Exception in stop_scan: {e}")
        self.logger.logger.error(f"Error during scan termination: {e}")
        traceback.print_exc()
        return False
def _update_session_state(self) -> None:
"""
Update the scanner state in Redis for GUI updates.
This ensures the web interface sees real-time updates.
"""
if self.session_id:
try:
from core.session_manager import session_manager
success = session_manager.update_session_scanner(self.session_id, self)
if not success:
print(f"WARNING: Failed to update session state for {self.session_id}")
except Exception as e:
print(f"ERROR: Failed to update session state: {e}")
def get_scan_status(self) -> Dict[str, Any]:
    """
    Return a snapshot of scan progress for the UI.

    The snapshot includes:
    - status, target and depth;
    - progress counters and queue size;
    - enabled providers and graph statistics;
    - up to five currently processing targets.

    If building the snapshot raises, the traceback is printed and a fallback
    dict with ``status='error'`` and zeroed/empty values is returned. The
    fallback has the same keys as the normal snapshot, so consumers can index
    either one without a KeyError.
    """
    try:
        with self.processing_lock:
            currently_processing_count = len(self.currently_processing)
            currently_processing_list = list(self.currently_processing)
        return {
            'status': self.status,
            'target_domain': self.current_target,
            'current_depth': self.current_depth,
            'max_depth': self.max_depth,
            'current_indicator': self.current_indicator,
            'indicators_processed': self.indicators_processed,
            'indicators_completed': self.indicators_completed,
            'tasks_re_enqueued': self.tasks_re_enqueued,
            'progress_percentage': self._calculate_progress(),
            'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
            'enabled_providers': [provider.get_name() for provider in self.providers],
            'graph_statistics': self.graph.get_statistics(),
            'task_queue_size': self.task_queue.qsize(),
            'currently_processing_count': currently_processing_count,
            'currently_processing': currently_processing_list[:5]
        }
    except Exception as e:
        print(f"ERROR: Exception in get_scan_status: {e}")
        traceback.print_exc()
        return {
            'status': 'error',
            'target_domain': None,
            'current_depth': 0,
            'max_depth': 0,
            'current_indicator': '',
            'indicators_processed': 0,
            'indicators_completed': 0,
            'tasks_re_enqueued': 0,
            'progress_percentage': 0.0,
            'total_tasks_ever_enqueued': 0,
            'enabled_providers': [],
            'graph_statistics': {},
            'task_queue_size': 0,
            'currently_processing_count': 0,
            'currently_processing': []
        }
def _initialize_provider_states(self, target: str) -> None:
"""Initialize provider states for forensic tracking."""
if not self.graph.graph.has_node(target):
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
"""Get providers eligible for querying this target."""
if dns_only:
return [p for p in self.providers if p.get_name() == 'dns']
eligible = []
target_key = 'ips' if is_ip else 'domains'
for provider in self.providers:
if provider.get_eligibility().get(target_key):
if not self._already_queried_provider(target, provider.get_name()):
eligible.append(provider)
else:
print(f"Skipping {provider.get_name()} for {target} - already queried")
return eligible
def _already_queried_provider(self, target: str, provider_name: str) -> bool:
"""Check if we already successfully queried a provider for a target."""
if not self.graph.graph.has_node(target):
return False
node_data = self.graph.graph.nodes[target]
provider_states = node_data.get('metadata', {}).get('provider_states', {})
# A provider has been successfully queried if a state exists and its status is 'success'
provider_state = provider_states.get(provider_name)
return provider_state is not None and provider_state.get('status') == 'success'
def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> Optional[List]:
"""Query a single provider with stop signal checking."""
provider_name = provider.get_name()
start_time = datetime.now(timezone.utc)
if self._is_stop_requested():
print(f"Stop requested before querying {provider_name} for {target}")
return None
print(f"Querying {provider_name} for {target}")
self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}")
try:
if is_ip:
results = provider.query_ip(target)
else:
results = provider.query_domain(target)
if self._is_stop_requested():
print(f"Stop requested after querying {provider_name} for {target}")
return None
self._update_provider_state(target, provider_name, 'success', len(results), None, start_time)
print(f"{provider_name} returned {len(results)} results for {target}")
return results
except Exception as e:
self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
print(f"{provider_name} failed for {target}: {e}")
return None
def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: Optional[str], start_time: datetime) -> None:
"""Update provider state in node metadata for forensic tracking."""
if not self.graph.graph.has_node(target):
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
node_data['metadata']['provider_states'][provider_name] = {
'status': status,
'timestamp': start_time.isoformat(),
'results_count': results_count,
'error': error,
'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
}
self.logger.logger.info(f"Provider state updated: {target} -> {provider_name} -> {status} ({results_count} results)")
def _process_provider_results(self, target: str, provider, results: List,
                              node_attributes: Dict, current_depth: int) -> Tuple[Set[str], bool]:
    """
    Turn one provider's relationship results into graph nodes and edges.

    Each result is a ``(source, rel_target, rel_type, confidence, raw_data)``
    tuple. For each result:
    - the relationship is logged;
    - attributes are gathered into ``node_attributes`` (node_id -> attribute
      dict, written to the graph by the caller);
    - nodes and edges are added according to the target kind: IP,
      ``AS<digits>`` ASN, or domain. A list target is expanded item by item.

    IP and domain targets are reported as discovered; ASN nodes are not.
    If ``results`` exceeds ``config.large_entity_threshold``, a single large
    entity is created instead and its members are returned.

    Returns:
        ``(discovered_targets, is_large_entity)``.
    """
    provider_name = provider.get_name()
    discovered_targets = set()
    if self._is_stop_requested():
        print(f"Stop requested before processing results from {provider_name} for {target}")
        return discovered_targets, False
    if len(results) > self.config.large_entity_threshold:
        print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}")
        members = self._create_large_entity(target, provider_name, results, current_depth)
        return members, True
    for i, (source, rel_target, rel_type, confidence, raw_data) in enumerate(results):
        # Poll the stop signal every 5 results, so the check (which may hit
        # Redis) stays cheap while the loop remains responsive.
        if i % 5 == 0 and self._is_stop_requested():
            print(f"Stop requested while processing results from {provider_name} for {target}")
            break
        self.logger.log_relationship_discovery(
            source_node=source,
            target_node=rel_target,
            relationship_type=rel_type,
            confidence_score=confidence,
            provider=provider_name,
            raw_data=raw_data,
            discovery_method=f"{provider_name}_query_depth_{current_depth}"
        )
        # Collect attributes for the source node. This also covers targets
        # that match none of the kinds below, so no second identical call is
        # needed for them.
        self._collect_node_attributes(source, provider_name, rel_type, rel_target, raw_data, node_attributes[source])
        # For asn_membership, the ASN (target) node also receives attributes.
        if rel_type == 'asn_membership':
            self._collect_node_attributes(rel_target, provider_name, rel_type, source, raw_data, node_attributes[rel_target])
        if isinstance(rel_target, list):
            # If the target is a list, iterate and process each item
            for single_target in rel_target:
                if _is_valid_ip(single_target):
                    self.graph.add_node(single_target, NodeType.IP)
                    if self.graph.add_edge(source, single_target, rel_type, confidence, provider_name, raw_data):
                        print(f"Added IP relationship: {source} -> {single_target} ({rel_type})")
                    discovered_targets.add(single_target)
                elif _is_valid_domain(single_target):
                    self.graph.add_node(single_target, NodeType.DOMAIN)
                    if self.graph.add_edge(source, single_target, rel_type, confidence, provider_name, raw_data):
                        print(f"Added domain relationship: {source} -> {single_target} ({rel_type})")
                    discovered_targets.add(single_target)
                    self._collect_node_attributes(single_target, provider_name, rel_type, source, raw_data, node_attributes[single_target])
        elif _is_valid_ip(rel_target):
            self.graph.add_node(rel_target, NodeType.IP)
            if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                print(f"Added IP relationship: {source} -> {rel_target} ({rel_type})")
            discovered_targets.add(rel_target)
        elif rel_target.startswith('AS') and rel_target[2:].isdigit():
            self.graph.add_node(rel_target, NodeType.ASN)
            if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                print(f"Added ASN relationship: {source} -> {rel_target} ({rel_type})")
        elif _is_valid_domain(rel_target):
            self.graph.add_node(rel_target, NodeType.DOMAIN)
            if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data):
                print(f"Added domain relationship: {source} -> {rel_target} ({rel_type})")
            discovered_targets.add(rel_target)
            self._collect_node_attributes(rel_target, provider_name, rel_type, source, raw_data, node_attributes[rel_target])
    return discovered_targets, False
def _create_large_entity(self, source: str, provider_name: str, results: List, current_depth: int) -> Set[str]:
    """
    Collapse an oversized result set into a single LARGE_ENTITY node.

    Every result target (``rel[1]``) is still added to the graph as its own
    node, typed DOMAIN when the first target is a valid domain and IP
    otherwise. Instead of per-target edges, one edge from ``source`` to the
    large-entity node is created, using the first result's relationship type.

    Args:
        source: Node that produced the results.
        provider_name: Provider that returned them.
        results: Relationship tuples ``(source, target, rel_type, confidence, raw_data)``.
        current_depth: Depth at which the results were discovered.

    Returns:
        The set of member targets, so the caller can enqueue them.
    """
    import zlib

    # The built-in hash() of a str is salted per interpreter process
    # (PYTHONHASHSEED), so the same source would get a different entity id in
    # another worker process or after a restart. crc32 is stable across
    # processes; the mask keeps the same non-negative 31-bit range as before.
    source_digest = zlib.crc32(source.encode('utf-8')) & 0x7FFFFFFF
    entity_id = f"large_entity_{provider_name}_{source_digest}"
    targets = [rel[1] for rel in results if len(rel) > 1]
    # The member type is inferred from the first target only.
    node_type = 'unknown'
    if targets:
        if _is_valid_domain(targets[0]):
            node_type = 'domain'
        elif _is_valid_ip(targets[0]):
            node_type = 'ip'
    # We still create the nodes so they exist in the graph; they are just not
    # processed for edges yet.
    for target in targets:
        self.graph.add_node(target, NodeType.DOMAIN if node_type == 'domain' else NodeType.IP)
    attributes = {
        'count': len(targets),
        'nodes': targets,
        'node_type': node_type,
        'source_provider': provider_name,
        'discovery_depth': current_depth,
        'threshold_exceeded': self.config.large_entity_threshold,
    }
    description = f'Large entity created due to {len(targets)} results from {provider_name}'
    self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes, description=description)
    if results:
        rel_type = results[0][2]
        self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
                            {'large_entity_info': f'Contains {len(targets)} {node_type}s'})
    self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")
    print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}")
    return set(targets)
def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
    """
    Extracts a node from a large entity, re-creates its original edge, and
    re-queues it for full scanning.

    Steps:
    1. Find the node that produced the large entity (its first predecessor).
    2. Remove ``node_id_to_extract`` from the entity through
       ``GraphManager.extract_node_from_large_entity``.
    3. Add a direct edge from that source, reusing the original edge's
       relationship type, confidence and provider.
    4. Enqueue the node for every eligible provider at the entity's discovery
       depth.
    5. If no scan is running, clear any leftover stop signals and start a
       mini-scan thread.

    Returns:
        True on success. False if the entity, its source node or the original
        edge data is missing, or if the graph-level extraction fails.
    """
    if not self.graph.graph.has_node(large_entity_id):
        print(f"ERROR: Large entity {large_entity_id} not found.")
        return False
    # 1. Get the original source node that discovered the large entity
    predecessors = list(self.graph.graph.predecessors(large_entity_id))
    if not predecessors:
        print(f"ERROR: No source node found for large entity {large_entity_id}.")
        return False
    source_node_id = predecessors[0]
    # Get the original edge data to replicate it for the extracted node
    original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id)
    if not original_edge_data:
        print(f"ERROR: Could not find original edge data from {source_node_id} to {large_entity_id}.")
        return False
    # 2. Modify the graph data structure first
    success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract)
    if not success:
        print(f"ERROR: Node {node_id_to_extract} could not be removed from {large_entity_id}'s attributes.")
        return False
    # 3. Create the direct edge from the original source to the newly extracted node
    print(f"Re-creating direct edge from {source_node_id} to extracted node {node_id_to_extract}")
    self.graph.add_edge(
        source_id=source_node_id,
        target_id=node_id_to_extract,
        relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'),
        confidence_score=original_edge_data.get('confidence_score', 0.85),  # 0.85 only if the original edge has none
        source_provider=original_edge_data.get('source_provider', 'unknown'),
        raw_data={'context': f'Extracted from large entity {large_entity_id}'}
    )
    # 4. Re-queue the extracted node for full processing by all eligible providers
    print(f"Re-queueing extracted node {node_id_to_extract} for full reconnaissance...")
    is_ip = _is_valid_ip(node_id_to_extract)
    current_depth = self.graph.graph.nodes[large_entity_id].get('attributes', {}).get('discovery_depth', 0)
    eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False)
    for provider in eligible_providers:
        provider_name = provider.get_name()
        self.task_queue.put((self._get_priority(provider_name), (provider_name, node_id_to_extract, current_depth)))
        self.total_tasks_ever_enqueued += 1
    # 5. If the scanner is not running, we need to kickstart it to process this one item.
    if self.status != ScanStatus.RUNNING:
        print("Scanner is idle. Starting a mini-scan to process the extracted node.")
        # A previous stop_scan() leaves the local event and the session's stop
        # signal set, and _execute_scan would then exit immediately without
        # processing the extracted node. Clear both, as start_scan() does.
        self.stop_event.clear()
        if self.session_id:
            from core.session_manager import session_manager
            session_manager.clear_stop_signal(self.session_id)
        self.status = ScanStatus.RUNNING
        self._update_session_state()
        if not self.scan_thread or not self.scan_thread.is_alive():
            self.scan_thread = threading.Thread(
                target=self._execute_scan,
                args=(self.current_target, self.max_depth),
                daemon=True
            )
            self.scan_thread.start()
    print(f"Successfully extracted and re-queued {node_id_to_extract} from {large_entity_id}.")
    return True
def _collect_node_attributes(self, node_id: str, provider_name: str, rel_type: str,
                             target: str, raw_data: Dict[str, Any], attributes: Dict[str, Any]) -> None:
    """
    Merge one provider result into a node's attribute dictionary, in place.

    Provider-specific details are stored first:
      - ``dns``: a ``"<query_type>: <value>"`` string appended to ``dns_records``.
      - ``crtsh``: for ``san_certificate`` results, the node's certificate summary
        (from ``raw_data['domain_certificates']``) and ``target`` as a SAN peer.
      - ``shodan``: ``raw_data`` merged into ``attributes['shodan']`` without
        overwriting truthy values, plus ``ports``/``os`` when ``node_id`` is an IP.
    Afterwards ``target`` is recorded in a list stored under ``rel_type``.

    Args:
        node_id: Identifier of the node whose attributes are updated.
        provider_name: Name of the provider that produced the result.
        rel_type: Relationship type; also used as the attribute key for targets.
        target: The related indicator, or a list of indicators.
        raw_data: Provider-specific payload for this relationship.
        attributes: The node's attribute dict; mutated in place.
    """
    self.logger.logger.debug(f"Collecting attributes for {node_id} from {provider_name}: {rel_type}")

    if provider_name == 'dns':
        record_type = raw_data.get('query_type', 'UNKNOWN')
        value = raw_data.get('value', target)
        dns_entry = f"{record_type}: {value}"
        dns_records = attributes.setdefault('dns_records', [])
        if dns_entry not in dns_records:
            dns_records.append(dns_entry)

    elif provider_name == 'crtsh':
        if rel_type == "san_certificate":
            domain_certs = raw_data.get('domain_certificates', {})
            if node_id in domain_certs:
                attributes['certificates'] = domain_certs[node_id]
            san_domains = attributes.setdefault('related_domains_san', [])
            if target not in san_domains:
                san_domains.append(target)

    elif provider_name == 'shodan':
        # Applies to whichever node this call targets (IP or ASN). Existing
        # truthy values are kept; missing or falsy ones are filled in.
        shodan_attributes = attributes.setdefault('shodan', {})
        for key, value in raw_data.items():
            if not shodan_attributes.get(key):
                shodan_attributes[key] = value

        if _is_valid_ip(node_id):
            if 'ports' in raw_data:
                attributes['ports'] = raw_data['ports']
            if 'os' in raw_data and raw_data['os']:
                attributes['os'] = raw_data['os']
        # NOTE: for 'asn_membership' no extra fields are copied here; per the
        # original design those attributes belong to the ASN (target) node.

    # Record the related target(s) under the relationship type. A list target is
    # de-duplicated element-wise, the same way a single target is, so repeated
    # results do not accumulate duplicate entries.
    related = attributes.setdefault(rel_type, [])
    new_items = target if isinstance(target, list) else [target]
    for item in new_items:
        if item not in related:
            related.append(item)
def _log_target_processing_error(self, target: str, error: str) -> None:
    """Write an ERROR entry to the forensic logger when ``target`` could not be processed."""
    forensic_logger = self.logger.logger
    forensic_logger.error(f"Target processing failed for {target}: {error}")
def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
    """Write an ERROR entry naming the provider whose query for ``target`` failed."""
    message = f"Provider {provider_name} failed for {target}: {error}"
    self.logger.logger.error(message)
def _log_no_eligible_providers(self, target: str, is_ip: bool) -> None:
    """Write a WARNING when no provider can handle ``target`` (labelled as IP or domain)."""
    if is_ip:
        kind_label = 'IP'
    else:
        kind_label = 'domain'
    self.logger.logger.warning(f"No eligible providers for {kind_label}: {target}")
def _calculate_progress(self) -> float:
    """
    Return scan progress as a percentage, capped at 100.

    Progress is ``indicators_completed`` relative to ``total_tasks_ever_enqueued``.
    Returns 0.0 while no task has been enqueued, which also avoids dividing by zero.
    """
    enqueued = self.total_tasks_ever_enqueued
    if enqueued == 0:
        return 0.0
    completed_ratio = self.indicators_completed / enqueued
    return min(100.0, completed_ratio * 100)
def get_graph_data(self) -> Dict[str, Any]:
    """Return the graph manager's visualization payload for the current scan graph."""
    graph_manager = self.graph
    return graph_manager.get_graph_data()
def export_results(self) -> Dict[str, Any]:
    """
    Assemble a complete export of the scan for download/archival.

    Returns:
        A dict with:
          - ``scan_metadata``: current target, max depth, final status,
            processed-indicator count, names of the providers in
            ``self.providers`` (as ``enabled_providers``) and the session id;
          - ``graph_data``: the graph manager's JSON export;
          - ``forensic_audit``: the logger's audit trail;
          - ``provider_statistics``: per-provider statistics keyed by name;
          - ``scan_summary``: the logger's forensic summary.
    """
    graph_data = self.graph.export_json()
    audit_trail = self.logger.export_audit_trail()
    provider_stats = {
        provider.get_name(): provider.get_statistics()
        for provider in self.providers
    }

    scan_metadata = {
        'target_domain': self.current_target,
        'max_depth': self.max_depth,
        'final_status': self.status,
        'total_indicators_processed': self.indicators_processed,
        'enabled_providers': list(provider_stats),
        'session_id': self.session_id,
    }

    return {
        'scan_metadata': scan_metadata,
        'graph_data': graph_data,
        'forensic_audit': audit_trail,
        'provider_statistics': provider_stats,
        'scan_summary': self.logger.get_forensic_summary(),
    }
def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
    """Return each active provider's statistics, keyed by provider name."""
    return {
        provider.get_name(): provider.get_statistics()
        for provider in self.providers
    }
def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
    """
    Describe every provider class found in the ``providers`` package.

    Each ``*_provider.py`` module (excluding names starting with ``base``) is
    imported, and every ``BaseProvider`` subclass in it is instantiated with
    this scanner's session config to read its metadata. When a provider with
    the same name is present in ``self.providers``, its live statistics are
    reported instead of the temporary instance's.

    Returns:
        Mapping of provider name to a dict with ``display_name``,
        ``requires_api_key``, ``statistics``, ``enabled`` and ``rate_limit``.
        Modules are visited in sorted filename order so the mapping is stable.

    Import failures and per-class failures are printed with a traceback and
    skipped, so one broken provider does not hide the others.
    """
    info = {}
    provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
    # os.listdir() returns entries in arbitrary order; sort for a deterministic result.
    for filename in sorted(os.listdir(provider_dir)):
        if not filename.endswith('_provider.py') or filename.startswith('base'):
            continue
        module_name = f"providers.{filename[:-3]}"
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            print(f"✗ Failed to get info for provider from {filename}: {e}")
            traceback.print_exc()
            continue

        for attribute_name in dir(module):
            # Guard each class separately so one failing constructor does not
            # skip the remaining provider classes in the same module.
            try:
                attribute = getattr(module, attribute_name)
                is_provider_class = (
                    isinstance(attribute, type)
                    and issubclass(attribute, BaseProvider)
                    and attribute is not BaseProvider
                )
                if not is_provider_class:
                    continue
                # Instantiate to get metadata, even if not fully configured
                temp_provider = attribute(name=attribute_name, session_config=self.config)
                provider_name = temp_provider.get_name()
                # Prefer the live instance (if any) so statistics reflect the running scan
                live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)
                stats_source = live_provider if live_provider else temp_provider
                info[provider_name] = {
                    'display_name': temp_provider.get_display_name(),
                    'requires_api_key': temp_provider.requires_api_key(),
                    'statistics': stats_source.get_statistics(),
                    'enabled': self.config.is_provider_enabled(provider_name),
                    'rate_limit': self.config.get_rate_limit(provider_name),
                }
            except Exception as e:
                print(f"✗ Failed to get info for provider from {filename}: {e}")
                traceback.print_exc()
    return info