742 lines
33 KiB
Python
742 lines
33 KiB
Python
# dnsrecon/core/scanner.py
|
|
|
|
import threading
|
|
import traceback
|
|
import time
|
|
import os
|
|
import importlib
|
|
from typing import List, Set, Dict, Any, Tuple
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future
|
|
from collections import defaultdict, deque
|
|
from datetime import datetime, timezone
|
|
|
|
from core.graph_manager import GraphManager, NodeType
|
|
from core.logger import get_forensic_logger, new_session
|
|
from core.task_manager import TaskManager, TaskType, ReconTask
|
|
from utils.helpers import _is_valid_ip, _is_valid_domain
|
|
from providers.base_provider import BaseProvider
|
|
|
|
|
|
class ScanStatus:
    """String constants naming the lifecycle states a scan can be in.

    This is a plain namespace of ``str`` values (not an ``enum.Enum``), so
    the values can be compared against and emitted as ordinary strings.
    """

    # No scan has been started yet, or the scanner was reset for a new scan.
    IDLE = "idle"
    # A scan thread is currently executing.
    RUNNING = "running"
    # The scan ran to the end and was judged successful.
    COMPLETED = "completed"
    # The scan raised an error or was judged unsuccessful.
    FAILED = "failed"
    # The scan was terminated on request.
    STOPPED = "stopped"
|
|
|
|
|
|
class Scanner:
    """
    Enhanced scanning orchestrator for DNSRecon passive reconnaissance.

    Uses a task-based completion model: for every target a set of provider
    tasks is created through a ``TaskManager``; targets reported by succeeded
    tasks are queued breadth-first until ``max_depth`` is reached.

    Instances are picklable: threading primitives and the task manager are
    dropped in ``__getstate__`` and rebuilt in ``__setstate__``.
    """

    def __init__(self, session_config=None):
        """Initialize scanner with session-specific configuration and task management.

        Args:
            session_config: Session configuration object. When ``None`` a
                default one is created via
                ``core.session_config.create_session_config()``.

        Raises:
            Exception: Any error raised during initialization is printed
                (with traceback) and re-raised unchanged.
        """
        print("Initializing Enhanced Scanner instance...")

        try:
            # Use provided session config or create default
            if session_config is None:
                from core.session_config import create_session_config
                session_config = create_session_config()

            self.config = session_config
            self.graph = GraphManager()
            self.providers = []
            self.status = ScanStatus.IDLE
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()
            self.scan_thread = None
            self.session_id = None  # Will be set by session manager
            self.current_scan_id = None  # Track current scan ID

            # Task-based execution components
            self.task_manager = None  # Will be initialized when needed
            self.max_workers = self.config.max_concurrent_requests

            # Enhanced progress tracking
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = ""
            self.scan_start_time = None
            self.scan_end_time = None

            # Initialize providers with session config
            print("Calling _initialize_providers with session config...")
            self._initialize_providers()

            # Initialize logger
            print("Initializing forensic logger...")
            self.logger = get_forensic_logger()

            print("Enhanced Scanner initialization complete")

        except Exception as e:
            print(f"ERROR: Enhanced Scanner initialization failed: {e}")
            traceback.print_exc()
            raise

    def __getstate__(self):
        """Prepare object for pickling by excluding unpicklable attributes.

        Returns:
            A copy of ``__dict__`` without ``stop_event``, ``scan_thread`` and
            ``task_manager``. Providers carrying a ``_stop_event`` attribute
            are replaced by shallow copies whose ``_stop_event`` is ``None``.
        """
        import copy

        state = self.__dict__.copy()

        # Remove unpicklable threading objects
        for attr in ('stop_event', 'scan_thread', 'task_manager'):
            state.pop(attr, None)

        # ``state`` is only a shallow copy, so its providers list holds the
        # very objects a running scan is using. Blanking ``_stop_event`` on
        # them directly would disconnect the live providers from the stop
        # event every time the scanner is serialized; blank copies instead.
        if 'providers' in state:
            detached = []
            for provider in state['providers']:
                if hasattr(provider, '_stop_event'):
                    provider = copy.copy(provider)
                    provider._stop_event = None
                detached.append(provider)
            state['providers'] = detached

        return state

    def __setstate__(self, state):
        """Restore object after unpickling by reconstructing threading objects.

        A fresh ``threading.Event`` is created and handed to every provider
        exposing ``set_stop_event``; ``scan_thread`` and ``task_manager`` are
        reset to ``None``.
        """
        self.__dict__.update(state)

        # Reconstruct threading objects
        self.stop_event = threading.Event()
        self.scan_thread = None
        self.task_manager = None

        # Re-set stop events for providers
        if hasattr(self, 'providers'):
            for provider in self.providers:
                if hasattr(provider, 'set_stop_event'):
                    provider.set_stop_event(self.stop_event)

    def _is_stop_requested(self) -> bool:
        """
        Enhanced stop signal checking that handles both local and
        session-manager based signals (Redis-backed according to the
        original comments).

        Returns:
            ``True`` if the local event is set, otherwise whatever
            ``session_manager.is_stop_requested`` reports for this session;
            ``False`` when no session ID is set.
        """
        # Check local threading event first (fastest)
        if self.stop_event.is_set():
            return True

        # Check session-level stop signal if session ID is available
        if self.session_id:
            try:
                from core.session_manager import session_manager
                return session_manager.is_stop_requested(self.session_id)
            except Exception as e:
                print(f"Error checking Redis stop signal: {e}")
                # Fall back to local event
                return self.stop_event.is_set()

        return False

    def _set_stop_signal(self) -> None:
        """
        Set stop signal both locally and through the session manager.

        Errors from the session manager are printed and otherwise ignored so
        the local event is always set.
        """
        # Set local event
        self.stop_event.set()

        # Set session-level signal if session ID is available
        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.set_stop_signal(self.session_id)
            except Exception as e:
                print(f"Error setting Redis stop signal: {e}")

    @staticmethod
    def _provider_directory() -> str:
        """Return the path of the ``providers`` directory next to this package."""
        return os.path.join(os.path.dirname(__file__), '..', 'providers')

    @staticmethod
    def _list_provider_files(provider_dir: str) -> List[str]:
        """Return file names in *provider_dir* ending in ``_provider.py``, excluding ``base*``."""
        return [
            f for f in os.listdir(provider_dir)
            if f.endswith('_provider.py') and not f.startswith('base')
        ]

    @staticmethod
    def _find_provider_classes(module) -> List[Tuple[str, type]]:
        """Return ``(name, class)`` pairs for every ``BaseProvider`` subclass visible in *module*."""
        found = []
        for attribute_name in dir(module):
            attribute = getattr(module, attribute_name)
            if (isinstance(attribute, type)
                    and issubclass(attribute, BaseProvider)
                    and attribute is not BaseProvider):
                found.append((attribute_name, attribute))
        return found

    def _try_register_provider(self, class_name: str, provider_class: type) -> None:
        """Instantiate *provider_class* and append it to ``self.providers`` if enabled and available.

        Any exception is printed (with traceback) and swallowed so that one
        broken provider does not prevent the others from loading.
        """
        try:
            # Create temporary instance to get provider name
            temp_provider = provider_class(session_config=self.config)
            provider_name = temp_provider.get_name()
            print(f" Provider {class_name} -> name: {provider_name}")

            # Check if enabled in config
            is_enabled = self.config.is_provider_enabled(provider_name)
            print(f" Provider {provider_name} enabled: {is_enabled}")

            if not is_enabled:
                print(f" - {temp_provider.get_display_name()} provider is disabled in config")
                return

            # Check if available (has API keys, etc.)
            is_available = temp_provider.is_available()
            print(f" Provider {provider_name} available: {is_available}")

            if not is_available:
                print(f" - {temp_provider.get_display_name()} provider is not available (missing API key or other requirement)")
                return

            # Set stop event and add to providers list
            temp_provider.set_stop_event(self.stop_event)
            self.providers.append(temp_provider)
            print(f" ✓ {temp_provider.get_display_name()} provider initialized successfully")

        except Exception as e:
            print(f" ✗ Failed to initialize provider class {class_name}: {e}")
            traceback.print_exc()

    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration.

        Resets ``self.providers`` and fills it with one instance per provider
        class found in ``providers/*_provider.py`` that is enabled in the
        session config and reports itself available. Returns early (leaving
        the list empty) when the provider directory is missing.
        """
        self.providers = []
        print("Initializing providers with session config...")

        provider_dir = self._provider_directory()
        print(f"Looking for providers in: {provider_dir}")

        if not os.path.isdir(provider_dir):
            print(f"ERROR: Provider directory does not exist: {provider_dir}")
            return

        provider_files = self._list_provider_files(provider_dir)
        print(f"Found provider files: {provider_files}")

        for filename in provider_files:
            module_name = f"providers.{filename[:-3]}"
            print(f"Attempting to load module: {module_name}")

            try:
                module = importlib.import_module(module_name)
                print(f" ✓ Module {module_name} loaded successfully")

                # Find provider classes in the module
                provider_classes_found = self._find_provider_classes(module)
                print(f" Found provider classes: {[name for name, _ in provider_classes_found]}")

                for class_name, provider_class in provider_classes_found:
                    self._try_register_provider(class_name, provider_class)

            except Exception as e:
                print(f" ✗ Failed to load module {module_name}: {e}")
                traceback.print_exc()

        print(f"Total providers initialized: {len(self.providers)}")
        for provider in self.providers:
            print(f" - {provider.get_display_name()} ({provider.get_name()})")

        if not self.providers:
            print("WARNING: No providers were initialized!")
        elif len(self.providers) == 1 and self.providers[0].get_name() == 'dns':
            print("WARNING: Only DNS provider initialized - other providers may have failed to load")

    def start_scan(self, target_domain: str, max_depth: int = 2, clear_graph: bool = True) -> bool:
        """Start a new reconnaissance scan with task-based completion model.

        Args:
            target_domain: Scan target; lower-cased and stripped before use.
            max_depth: Maximum expansion depth passed to the scan thread.
            clear_graph: ``True`` starts a fresh scan (new scan ID, previous
                scan thread asked to stop, graph cleared, new forensic
                session). ``False`` adds to the existing graph and keeps the
                scan ID when a scan is currently running.

        Returns:
            ``True`` when the scan thread was started, ``False`` when no
            providers are available or an exception occurred (status is then
            set to ``FAILED``).
        """
        print(f"=== STARTING ENHANCED SCAN IN SCANNER {id(self)} ===")
        print(f"Session ID: {self.session_id}")
        print(f"Initial scanner status: {self.status}")
        print(f"Clear graph: {clear_graph}")

        # Generate scan ID based on clear_graph behavior
        import uuid

        if clear_graph:
            # NEW SCAN: Generate new ID and terminate existing scan
            print("NEW SCAN: Generating new scan ID and terminating existing scan")
            self.current_scan_id = str(uuid.uuid4())[:8]

            # Aggressive cleanup of previous scan
            if self.scan_thread and self.scan_thread.is_alive():
                print("Terminating previous scan thread...")
                self._set_stop_signal()

                if self.task_manager:
                    self.task_manager.stop_execution()

                self.scan_thread.join(timeout=8.0)
                if self.scan_thread.is_alive():
                    print("WARNING: Previous scan thread did not terminate cleanly")

        else:
            # ADD TO GRAPH: Keep existing scan ID if scan is running, or generate new one
            if self.status == ScanStatus.RUNNING and self.current_scan_id:
                print(f"ADD TO GRAPH: Keeping existing scan ID {self.current_scan_id}")
                # Don't terminate existing scan - we're adding to it
            else:
                print("ADD TO GRAPH: No active scan, generating new scan ID")
                self.current_scan_id = str(uuid.uuid4())[:8]

        print(f"Using scan ID: {self.current_scan_id}")

        # Reset state for new scan (but preserve graph if clear_graph=False)
        if clear_graph or self.status != ScanStatus.RUNNING:
            self.status = ScanStatus.IDLE
            self._update_session_state()

        try:
            if not hasattr(self, 'providers') or not self.providers:
                print(f"ERROR: No providers available in scanner {id(self)}, cannot start scan")
                return False

            print(f"Scanner {id(self)} validation passed, providers available: {[p.get_name() for p in self.providers]}")

            if clear_graph:
                self.graph.clear()

            self.current_target = target_domain.lower().strip()
            self.max_depth = max_depth
            self.current_depth = 0

            # Clear stop signals only if starting new scan
            if clear_graph or self.status != ScanStatus.RUNNING:
                self.stop_event.clear()
                if self.session_id:
                    from core.session_manager import session_manager
                    session_manager.clear_stop_signal(self.session_id)

            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.current_indicator = self.current_target
            self.scan_start_time = datetime.now(timezone.utc)
            self.scan_end_time = None

            self._update_session_state()

            # Initialize forensic session only for new scans
            if clear_graph:
                self.logger = new_session()

            # Start task-based scan thread
            print(f"Starting task-based scan thread with scan ID {self.current_scan_id}...")
            self.scan_thread = threading.Thread(
                target=self._execute_task_based_scan,
                args=(self.current_target, max_depth, self.current_scan_id),
                daemon=True
            )
            self.scan_thread.start()

            print(f"=== ENHANCED SCAN STARTED SUCCESSFULLY IN SCANNER {id(self)} ===")
            return True

        except Exception as e:
            print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self.scan_end_time = datetime.now(timezone.utc)
            self._update_session_state()
            return False

    def _wait_for_batch(self, task_manager, target: str, scan_id: str,
                        timeout_per_batch: float = 60) -> None:
        """Poll *task_manager* once per second until no task is pending, running or retrying.

        Gives up after *timeout_per_batch* seconds, or immediately when a stop
        was requested or the scanner's current scan ID no longer equals
        *scan_id*. ``indicators_processed`` is refreshed and the session
        state pushed on every poll.
        """
        batch_start = time.time()

        while time.time() - batch_start < timeout_per_batch:
            if self._is_stop_requested() or self.current_scan_id != scan_id:
                break

            stats = task_manager.get_progress_report()['statistics']

            # Update progress tracking
            self.indicators_processed = stats['succeeded'] + stats['failed_permanent']
            self._update_session_state()

            # Check if current batch is complete
            if (stats['pending'] == 0
                    and stats['running'] == 0
                    and stats['failed_retrying'] == 0):
                print(f"Batch complete for {target}: {stats['succeeded']} succeeded, {stats['failed_permanent']} failed")
                break

            time.sleep(1.0)  # Check every second

    def _execute_task_based_scan(self, target_domain: str, max_depth: int, scan_id: str) -> None:
        """Execute the reconnaissance scan using the task-based completion model.

        Runs in the scan thread. Targets are processed breadth-first starting
        with *target_domain* at depth 0; for each target provider tasks are
        created and awaited, then new targets from succeeded tasks are queued
        for ``depth + 1`` while ``depth < max_depth``.

        Args:
            target_domain: Initial target, added to the graph as a domain node.
            max_depth: Targets deeper than this are skipped.
            scan_id: ID this thread was started for. The loop aborts and the
                final status is left untouched once ``self.current_scan_id``
                differs from it.
        """
        print(f"_execute_task_based_scan started for {target_domain} with depth {max_depth}, scan ID {scan_id}")

        # Bound before ``try`` because ``finally`` reports on them even when
        # setup fails early.
        processed_targets = set()
        # Local handle on the manager created by THIS thread: a newer scan may
        # replace ``self.task_manager`` while this thread is still winding
        # down, and that newer manager must not be waited on or stopped here.
        task_manager = None

        try:
            self.status = ScanStatus.RUNNING
            self._update_session_state()

            enabled_providers = [provider.get_name() for provider in self.providers]
            self.logger.log_scan_start(target_domain, max_depth, enabled_providers)

            # Initialize task manager
            task_manager = TaskManager(
                self.providers,
                self.graph,
                self.logger,
                max_concurrent_tasks=self.max_workers
            )
            self.task_manager = task_manager

            # Add initial target to graph
            self.graph.add_node(target_domain, NodeType.DOMAIN)

            # Start task execution
            task_manager.start_execution(max_workers=self.max_workers)

            # Task queue for breadth-first processing
            target_queue = deque([(target_domain, 0)])  # (target, depth)

            while target_queue:
                # Abort if scan ID changed (new scan started)
                if self.current_scan_id != scan_id:
                    print(f"Scan aborted - ID mismatch (current: {self.current_scan_id}, expected: {scan_id})")
                    break

                if self._is_stop_requested():
                    print("Stop requested, terminating task-based scan.")
                    break

                target, depth = target_queue.popleft()

                if target in processed_targets or depth > max_depth:
                    continue

                self.current_depth = depth
                self.current_indicator = target
                self._update_session_state()

                print(f"Processing target: {target} at depth {depth}")

                # Create tasks for all eligible providers
                task_ids = task_manager.create_provider_tasks(target, depth, self.providers)

                if task_ids:
                    print(f"Created {len(task_ids)} tasks for target {target}")
                    self.total_indicators_found += len(task_ids)
                    self._update_session_state()

                processed_targets.add(target)

                # Wait for current batch of tasks to complete before processing next depth
                # This ensures we get all relationships before expanding further
                self._wait_for_batch(task_manager, target, scan_id, timeout_per_batch=60)

                # Collect new targets from completed successful tasks
                if depth < max_depth:
                    new_targets = self._collect_new_targets_from_completed_tasks()
                    for new_target in new_targets:
                        if new_target not in processed_targets:
                            target_queue.append((new_target, depth + 1))
                            print(f"Added new target for next depth: {new_target}")

            # Wait for all remaining tasks to complete
            print("Waiting for all tasks to complete...")
            final_completion = task_manager.wait_for_completion(timeout_seconds=300)

            if not final_completion:
                print("WARNING: Some tasks did not complete within timeout")

            # Final progress update
            final_stats = task_manager.get_progress_report()['statistics']

            print("Final task statistics:")
            print(f" - Total tasks: {final_stats['total_tasks']}")
            print(f" - Succeeded: {final_stats['succeeded']}")
            print(f" - Failed permanently: {final_stats['failed_permanent']}")
            print(f" - Completion rate: {final_stats['completion_rate']:.1f}%")

            # Determine final scan status
            if self.current_scan_id == scan_id:
                if self._is_stop_requested():
                    self.status = ScanStatus.STOPPED
                elif final_stats['failed_permanent'] > 0 and final_stats['succeeded'] == 0:
                    self.status = ScanStatus.FAILED
                elif final_stats['completion_rate'] < 50.0:  # Less than 50% success rate
                    self.status = ScanStatus.FAILED
                else:
                    self.status = ScanStatus.COMPLETED

                self.scan_end_time = datetime.now(timezone.utc)
                self._update_session_state()
                self.logger.log_scan_complete()
            else:
                print("Scan completed but ID mismatch - not updating final status")

        except Exception as e:
            print(f"ERROR: Task-based scan execution failed: {e}")
            traceback.print_exc()
            # Only record the failure if this thread still owns the scan;
            # otherwise a stopped or superseded scan's status would be
            # overwritten by a stale thread.
            if self.current_scan_id == scan_id:
                self.status = ScanStatus.FAILED
                self.scan_end_time = datetime.now(timezone.utc)
                self._update_session_state()
            self.logger.logger.error(f"Task-based scan failed: {e}")
        finally:
            # Clean up the task manager this thread created
            if task_manager is not None:
                task_manager.stop_execution()

            # Final statistics
            graph_stats = self.graph.get_statistics()
            print("Final scan statistics:")
            print(f" - Total nodes: {graph_stats['basic_metrics']['total_nodes']}")
            print(f" - Total edges: {graph_stats['basic_metrics']['total_edges']}")
            print(f" - Targets processed: {len(processed_targets)}")

    def _collect_new_targets_from_completed_tasks(self) -> Set[str]:
        """Collect new targets from successfully completed tasks.

        Returns:
            The set of ``new_targets`` entries (from each succeeded task's
            ``result.data``) that pass ``_is_valid_domain`` or
            ``_is_valid_ip``. Empty when no task manager exists.
        """
        new_targets = set()

        task_manager = self.task_manager
        if not task_manager:
            return new_targets

        # Get task summaries to find successful tasks
        task_queue = task_manager.task_queue

        for task_summary in task_queue.get_task_summaries():
            if task_summary['status'] != 'succeeded':
                continue

            task = task_queue.tasks.get(task_summary['task_id'])
            if not (task and task.result and task.result.data):
                continue

            for target in task.result.data.get('new_targets', []):
                if _is_valid_domain(target) or _is_valid_ip(target):
                    new_targets.add(target)

        return new_targets

    def _update_session_state(self) -> None:
        """
        Push the scanner state to the session manager for GUI updates.

        No-op without a session ID. Failures are printed, never raised.
        """
        if self.session_id:
            try:
                from core.session_manager import session_manager
                success = session_manager.update_session_scanner(self.session_id, self)
                if not success:
                    print(f"WARNING: Failed to update session state for {self.session_id}")
            except Exception as e:
                print(f"ERROR: Failed to update session state: {e}")

    def stop_scan(self) -> bool:
        """Request immediate scan termination with task manager cleanup.

        Invalidates the current scan ID (so the scan thread stops writing
        final status), sets the stop signals, marks the scan ``STOPPED`` and
        stops the task manager if one exists.

        Returns:
            ``True`` when the signals were sent, ``False`` if an exception
            occurred while doing so.
        """
        try:
            print("=== INITIATING ENHANCED SCAN TERMINATION ===")
            self.logger.logger.info("Enhanced scan termination requested by user")

            # Invalidate current scan ID to prevent stale updates
            old_scan_id = self.current_scan_id
            self.current_scan_id = None
            print(f"Invalidated scan ID {old_scan_id}")

            # Set stop signals
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED
            self.scan_end_time = datetime.now(timezone.utc)

            # Immediately update GUI with stopped status
            self._update_session_state()

            # Stop task manager if running
            if self.task_manager:
                print("Stopping task manager...")
                self.task_manager.stop_execution()

            print("Enhanced termination signals sent. The scan will stop as soon as possible.")
            return True

        except Exception as e:
            print(f"ERROR: Exception in enhanced stop_scan: {e}")
            self.logger.logger.error(f"Error during enhanced scan termination: {e}")
            traceback.print_exc()
            return False

    def get_scan_status(self) -> Dict[str, Any]:
        """Get current scan status with enhanced task-based information.

        Returns:
            A dict with scan metadata, progress, graph statistics and timing.
            When a task manager exists, ``task_statistics`` and
            ``task_details`` are added and ``indicators_processed`` /
            ``progress_percentage`` are recomputed from the task statistics.
            On any exception a dict with ``'status': 'error'``, neutral
            values and an ``'error'`` message is returned instead.
        """
        try:
            status = {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'total_indicators_found': self.total_indicators_found,
                'indicators_processed': self.indicators_processed,
                'progress_percentage': self._calculate_progress(),
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics(),
                'scan_duration_seconds': self._calculate_scan_duration(),
                'scan_start_time': self.scan_start_time.isoformat() if self.scan_start_time else None,
                'scan_end_time': self.scan_end_time.isoformat() if self.scan_end_time else None
            }

            # Add task manager statistics if available
            task_manager = self.task_manager
            if task_manager:
                progress_report = task_manager.get_progress_report()
                task_stats = progress_report['statistics']

                status['task_statistics'] = task_stats
                status['task_details'] = {
                    'is_running': progress_report['is_running'],
                    'worker_count': progress_report['worker_count'],
                    'failed_tasks_count': len(progress_report['failed_tasks'])
                }

                # Update indicators processed from task statistics
                finished = task_stats['succeeded'] + task_stats['failed_permanent']
                status['indicators_processed'] = finished

                # Recalculate progress based on task completion
                if task_stats['total_tasks'] > 0:
                    task_completion_rate = finished / task_stats['total_tasks']
                    status['progress_percentage'] = min(100.0, task_completion_rate * 100.0)

            return status

        except Exception as e:
            print(f"ERROR: Exception in get_scan_status: {e}")
            traceback.print_exc()
            return {
                'status': 'error',
                'target_domain': None,
                'current_depth': 0,
                'max_depth': 0,
                'current_indicator': '',
                'total_indicators_found': 0,
                'indicators_processed': 0,
                'progress_percentage': 0.0,
                'enabled_providers': [],
                'graph_statistics': {},
                'scan_duration_seconds': 0,
                # Same keys as the success payload so consumers need no special case
                'scan_start_time': None,
                'scan_end_time': None,
                'error': str(e)
            }

    def _calculate_progress(self) -> float:
        """Calculate scan progress percentage.

        Returns:
            ``indicators_processed / total_indicators_found`` as a percentage
            capped at 100.0, or 0.0 when nothing has been found yet.
        """
        if self.total_indicators_found == 0:
            return 0.0
        return min(100.0, (self.indicators_processed / self.total_indicators_found) * 100)

    def _calculate_scan_duration(self) -> float:
        """Calculate scan duration in seconds.

        Returns:
            Seconds between ``scan_start_time`` and ``scan_end_time`` (or the
            current UTC time while no end time is set), rounded to two
            decimals; 0.0 when the scan has no start time.
        """
        if not self.scan_start_time:
            return 0.0

        end_time = self.scan_end_time or datetime.now(timezone.utc)
        duration = (end_time - self.scan_start_time).total_seconds()
        return round(duration, 2)

    def get_graph_data(self) -> Dict[str, Any]:
        """Get current graph data for visualization."""
        return self.graph.get_graph_data()

    def export_results(self) -> Dict[str, Any]:
        """Export complete scan results with enhanced task-based audit trail.

        Returns:
            A dict with ``scan_metadata``, ``graph_data``, ``forensic_audit``,
            ``provider_statistics`` and ``scan_summary``; ``task_execution``
            is added when a task manager exists.
        """
        graph_data = self.graph.export_json()
        audit_trail = self.logger.export_audit_trail()
        provider_stats = {
            provider.get_name(): provider.get_statistics()
            for provider in self.providers
        }

        export_data = {
            'scan_metadata': {
                'target_domain': self.current_target,
                'max_depth': self.max_depth,
                'final_status': self.status,
                'total_indicators_processed': self.indicators_processed,
                'enabled_providers': list(provider_stats.keys()),
                'session_id': self.session_id,
                'scan_id': self.current_scan_id,
                'scan_duration_seconds': self._calculate_scan_duration(),
                'scan_start_time': self.scan_start_time.isoformat() if self.scan_start_time else None,
                'scan_end_time': self.scan_end_time.isoformat() if self.scan_end_time else None
            },
            'graph_data': graph_data,
            'forensic_audit': audit_trail,
            'provider_statistics': provider_stats,
            'scan_summary': self.logger.get_forensic_summary()
        }

        # Add task execution details if available
        task_manager = self.task_manager
        if task_manager:
            progress_report = task_manager.get_progress_report()
            statistics = progress_report['statistics']
            export_data['task_execution'] = {
                'statistics': statistics,
                'failed_tasks': progress_report['failed_tasks'],
                'execution_summary': {
                    'total_tasks_created': statistics['total_tasks'],
                    'success_rate': statistics['completion_rate'],
                    'average_retries': self._calculate_average_retries(progress_report)
                }
            }

        return export_data

    def _calculate_average_retries(self, progress_report: Dict[str, Any]) -> float:
        """Calculate the average number of execution attempts per task.

        The value is the mean length of ``execution_history`` over all tasks
        that have that attribute, rounded to two decimals.

        Args:
            progress_report: Accepted for interface compatibility; not read.

        Returns:
            The average, or 0.0 when there is no task manager, no ``tasks``
            mapping, or no task with an execution history.
        """
        if not self.task_manager or not hasattr(self.task_manager.task_queue, 'tasks'):
            return 0.0

        attempt_counts = [
            len(task.execution_history)
            for task in self.task_manager.task_queue.tasks.values()
            if hasattr(task, 'execution_history')
        ]

        if not attempt_counts:
            return 0.0
        return round(sum(attempt_counts) / len(attempt_counts), 2)

    def get_provider_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Get statistics for all providers with enhanced cache information.

        Returns:
            Mapping of provider name to its statistics; providers with a
            ``cache`` attribute additionally get ``cache_enabled``,
            ``cache_directory`` and ``cache_expiry_hours``.
        """
        stats = {}
        for provider in self.providers:
            # Copy so the cache keys are not written into the provider's own dict
            provider_stats = dict(provider.get_statistics())

            # Add cache performance metrics
            if hasattr(provider, 'cache'):
                provider_stats.update({
                    'cache_enabled': True,
                    'cache_directory': provider.cache.cache_dir,
                    'cache_expiry_hours': provider.cache.cache_expiry / 3600
                })

            stats[provider.get_name()] = provider_stats
        return stats

    def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
        """Get information about all available providers with enhanced details.

        Every provider class found in ``providers/*_provider.py`` is
        instantiated to read its metadata, whether or not it is enabled.
        Statistics come from the live instance in ``self.providers`` when one
        with the same name exists.

        Returns:
            Mapping of provider name to an info dict. Empty when the provider
            directory is missing; files that fail to load are reported on
            stdout and skipped.
        """
        info = {}
        provider_dir = self._provider_directory()

        # Mirror _initialize_providers: a missing directory is reported, not raised
        if not os.path.isdir(provider_dir):
            print(f"ERROR: Provider directory does not exist: {provider_dir}")
            return info

        for filename in self._list_provider_files(provider_dir):
            module_name = f"providers.{filename[:-3]}"
            try:
                module = importlib.import_module(module_name)
                for _, provider_class in self._find_provider_classes(module):
                    # Instantiate to get metadata, even if not fully configured
                    temp_provider = provider_class(session_config=self.config)
                    provider_name = temp_provider.get_name()

                    # Find the actual provider instance if it exists, to get live stats
                    live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)

                    provider_info = {
                        'display_name': temp_provider.get_display_name(),
                        'requires_api_key': temp_provider.requires_api_key(),
                        'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(),
                        'enabled': self.config.is_provider_enabled(provider_name),
                        'rate_limit': self.config.get_rate_limit(provider_name),
                        'eligibility': temp_provider.get_eligibility()
                    }

                    # Add cache information if provider has caching
                    if live_provider and hasattr(live_provider, 'cache'):
                        provider_info['cache_info'] = {
                            'cache_enabled': True,
                            'cache_directory': live_provider.cache.cache_dir,
                            'cache_expiry_hours': live_provider.cache.cache_expiry / 3600
                        }

                    info[provider_name] = provider_info

            except Exception as e:
                print(f"✗ Failed to get info for provider from {filename}: {e}")
                traceback.print_exc()

        return info