# DNScope-reduced/core/scanner.py
import threading
import traceback
import os
import importlib
import redis
import time
import math
import random  # Used for retry jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from queue import PriorityQueue, Empty
from datetime import datetime, timezone

from core.graph_manager import GraphManager, NodeType
from core.logger import get_forensic_logger, new_session
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_ip, _is_valid_domain
from utils.export_manager import export_manager
from providers.base_provider import BaseProvider
from providers.correlation_provider import CorrelationProvider
from core.rate_limiter import GlobalRateLimiter

class ScanStatus:
    """Enumeration of scan statuses."""
    IDLE = "idle"
    RUNNING = "running"
    FINALIZING = "finalizing"  # Post-scan correlation/analysis state
    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"

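# Typical lifecycle: IDLE -> RUNNING -> FINALIZING -> (COMPLETED | FAILED |
# STOPPED). FAILED is reached when retries are exhausted and STOPPED when a
# stop signal is observed; see Scanner._execute_scan below.
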
class Scanner:
    """
    Main scanning orchestrator for DNScope passive reconnaissance.

    Combines the full provider/queue feature set with the cleaner terminal
    status display.
    """

    def __init__(self, session_config=None):
        """Initialize scanner with session-specific configuration."""
        try:
            # Use the provided session config or create a default one
            if session_config is None:
                from core.session_config import create_session_config
                session_config = create_session_config()

            self.config = session_config
            self.graph = GraphManager()
            self.providers = []
            self.status = ScanStatus.IDLE
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()
            self.scan_thread = None
            self.session_id: Optional[str] = None  # Set by the session manager
            self.task_queue = PriorityQueue()
            self.target_retries = defaultdict(int)
            self.scan_failed_due_to_retries = False
            self.initial_targets = set()

            # Thread-safe processing tracking
            self.currently_processing = set()
            self.processing_lock = threading.Lock()
            # Display-friendly snapshot of the processing set
            self.currently_processing_display = []

            # Scanning progress tracking
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.indicators_completed = 0
            self.tasks_re_enqueued = 0
            self.tasks_skipped = 0
            self.total_tasks_ever_enqueued = 0
            self.current_indicator = ""
            self.last_task_from_queue = None

            # Concurrent processing configuration
            self.max_workers = self.config.max_concurrent_requests
            self.executor = None

            # Status logger thread with improved formatting
            self.status_logger_thread = None
            self.status_logger_stop_event = threading.Event()

            # Initialize providers with session config
            self._initialize_providers()

            # Initialize logger
            self.logger = get_forensic_logger()

            # Initialize global rate limiter
            self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))

        except Exception as e:
            print(f"ERROR: Scanner initialization failed: {e}")
            traceback.print_exc()
            raise

    def _is_stop_requested(self) -> bool:
        """
        Check if a stop was requested, using both the local event and the
        Redis-based signal so termination is reliable across process
        boundaries.
        """
        if self.stop_event.is_set():
            return True

        if self.session_id:
            try:
                from core.session_manager import session_manager
                return session_manager.is_stop_requested(self.session_id)
            except Exception:
                # Fall back to the local event
                return self.stop_event.is_set()

        return self.stop_event.is_set()

    def _set_stop_signal(self) -> None:
        """Set the stop signal both locally and in Redis."""
        self.stop_event.set()

        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.set_stop_signal(self.session_id)
            except Exception:
                pass

    def __getstate__(self):
        """Prepare object for pickling by excluding unpicklable attributes."""
        state = self.__dict__.copy()

        unpicklable_attrs = [
            'stop_event',
            'scan_thread',
            'executor',
            'processing_lock',
            'task_queue',
            'rate_limiter',
            'logger',
            'status_logger_thread',
            'status_logger_stop_event',
        ]

        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]

        if 'providers' in state:
            for provider in state['providers']:
                if hasattr(provider, '_stop_event'):
                    provider._stop_event = None

        return state

    def __setstate__(self, state):
        """Restore object after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)

        self.stop_event = threading.Event()
        self.scan_thread = None
        self.executor = None
        self.processing_lock = threading.Lock()
        self.task_queue = PriorityQueue()
        self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
        self.logger = get_forensic_logger()
        self.status_logger_thread = None
        self.status_logger_stop_event = threading.Event()

        if not hasattr(self, 'providers') or not self.providers:
            self._initialize_providers()

        if not hasattr(self, 'currently_processing'):
            self.currently_processing = set()

        if not hasattr(self, 'currently_processing_display'):
            self.currently_processing_display = []

        for provider in self.providers:
            if hasattr(provider, 'set_stop_event'):
                provider.set_stop_event(self.stop_event)

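    # Pickle round-trip sketch (illustrative, not part of the API): threading
    # primitives are dropped by __getstate__ and rebuilt by __setstate__, so
    #     restored = pickle.loads(pickle.dumps(scanner))
    # yields a scanner with a fresh stop event, lock, executor slot, and an
    # empty task queue, while the graph and progress counters survive the trip.
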
    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration."""
        self.providers = []
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')

        print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")

        correlation_provider_instance = None

        for filename in os.listdir(provider_dir):
            if filename.endswith('_provider.py') and not filename.startswith('base'):
                module_name = f"providers.{filename[:-3]}"
                try:
                    print(f"Loading provider module: {module_name}")
                    module = importlib.import_module(module_name)
                    for attribute_name in dir(module):
                        attribute = getattr(module, attribute_name)
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                            provider_class = attribute
                            provider = provider_class(name=attribute_name, session_config=self.config)
                            provider_name = provider.get_name()

                            print(f"  Provider: {provider_name}")
                            print(f"    Class: {provider_class.__name__}")
                            print(f"    Config enabled: {self.config.is_provider_enabled(provider_name)}")
                            print(f"    Requires API key: {provider.requires_api_key()}")

                            if provider.requires_api_key():
                                api_key = self.config.get_api_key(provider_name)
                                print(f"    API key present: {'Yes' if api_key else 'No'}")
                                if api_key:
                                    print(f"    API key preview: {api_key[:8]}...")

                            if self.config.is_provider_enabled(provider_name):
                                is_available = provider.is_available()
                                print(f"    Available: {is_available}")

                                if is_available:
                                    provider.set_stop_event(self.stop_event)

                                    # Special handling for the correlation provider
                                    if isinstance(provider, CorrelationProvider):
                                        provider.set_graph_manager(self.graph)
                                        correlation_provider_instance = provider
                                        print("    ✓ Correlation provider configured with graph manager")

                                    self.providers.append(provider)
                                    print("    ✓ Added to scanner")
                                else:
                                    print("    ✗ Not available - skipped")
                            else:
                                print("    ✗ Disabled in config - skipped")

                except Exception as e:
                    print(f"  ERROR loading {module_name}: {e}")
                    traceback.print_exc()

        print("=== PROVIDER INITIALIZATION COMPLETE ===")
        print(f"Active providers: {[p.get_name() for p in self.providers]}")
        print(f"Provider count: {len(self.providers)}")

        # Verify the correlation provider is properly configured
        if correlation_provider_instance:
            print(f"Correlation provider configured: {correlation_provider_instance.graph is not None}")

        print("=" * 50)

    def _status_logger_thread(self):
        """Periodically print a clean, formatted scan status to the terminal."""
        HEADER = "\033[95m"
        CYAN = "\033[96m"
        GREEN = "\033[92m"
        YELLOW = "\033[93m"
        BLUE = "\033[94m"
        ENDC = "\033[0m"
        BOLD = "\033[1m"

        last_status_str = ""
        while not self.status_logger_stop_event.is_set():
            try:
                with self.processing_lock:
                    in_flight_tasks = list(self.currently_processing)
                    self.currently_processing_display = in_flight_tasks.copy()

                status_str = (
                    f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
                    f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
                    f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
                    f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
                    f"Skipped: {self.tasks_skipped} | "
                    f"Rescheduled: {self.tasks_re_enqueued}"
                )

                if status_str != last_status_str:
                    print(f"\n{'-' * 80}")
                    print(status_str)
                    if self.last_task_from_queue:
                        # Unpack the time-based queue item: (run_at, priority, payload)
                        _, p, (pn, ti, d) = self.last_task_from_queue
                        print(f"{BLUE}Last task dequeued -> Prio:{p} | Provider:{pn} | Target:'{ti}' | Depth:{d}{ENDC}")
                    if in_flight_tasks:
                        print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}")
                        display_tasks = [f"  - {p}: {t}" for p, t in in_flight_tasks[:3]]
                        print("\n".join(display_tasks))
                        if len(in_flight_tasks) > 3:
                            print(f"  ... and {len(in_flight_tasks) - 3} more")
                    print(f"{'-' * 80}")
                    last_status_str = status_str
            except Exception:
                pass

            time.sleep(2)

    def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
        """Validate the target, reset scan state, and launch the scan threads."""
        if self.scan_thread and self.scan_thread.is_alive():
            self.logger.logger.info("Stopping existing scan before starting new one")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            # Clean up processing state
            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []

            # Clear the task queue
            while not self.task_queue.empty():
                try:
                    self.task_queue.get_nowait()
                except Empty:
                    break

            # Shut down the executor
            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass
                finally:
                    self.executor = None

            # Wait for the scan thread to finish (with timeout)
            self.scan_thread.join(timeout=5.0)
            if self.scan_thread.is_alive():
                self.logger.logger.warning("Previous scan thread did not terminate cleanly")

        self.status = ScanStatus.IDLE
        self.stop_event.clear()

        if self.session_id:
            from core.session_manager import session_manager
            session_manager.clear_stop_signal(self.session_id)

        with self.processing_lock:
            self.currently_processing.clear()
            self.currently_processing_display = []

        self.task_queue = PriorityQueue()
        self.target_retries.clear()
        self.scan_failed_due_to_retries = False
        self.tasks_skipped = 0
        self.last_task_from_queue = None

        self._update_session_state()

        try:
            if not hasattr(self, 'providers') or not self.providers:
                self.logger.logger.error("No providers available for scanning")
                return False

            available_providers = [p for p in self.providers if p.is_available()]
            if not available_providers:
                self.logger.logger.error("No providers are currently available/configured")
                return False

            if clear_graph:
                self.graph.clear()
                self.initial_targets.clear()

            if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
                try:
                    node_data = self.graph.graph.nodes[force_rescan_target]
                    if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
                        node_data['metadata']['provider_states'] = {}
                        self.logger.logger.info(f"Cleared provider states for forced rescan of {force_rescan_target}")
                except Exception as e:
                    self.logger.logger.warning(f"Error clearing provider states for {force_rescan_target}: {e}")

            target = target.lower().strip()
            if not target:
                self.logger.logger.error("Empty target provided")
                return False

            from utils.helpers import is_valid_target
            if not is_valid_target(target):
                self.logger.logger.error(f"Invalid target format: {target}")
                return False

            self.current_target = target
            self.initial_targets.add(self.current_target)
            self.max_depth = max(1, min(5, max_depth))  # Clamp depth to 1-5
            self.current_depth = 0

            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.indicators_completed = 0
            self.tasks_re_enqueued = 0
            self.total_tasks_ever_enqueued = 0
            self.current_indicator = self.current_target

            self._update_session_state()
            self.logger = new_session()

            try:
                self.scan_thread = threading.Thread(
                    target=self._execute_scan,
                    args=(self.current_target, self.max_depth),
                    daemon=True,
                    name=f"ScanThread-{self.session_id or 'default'}"
                )
                self.scan_thread.start()

                self.status_logger_stop_event.clear()
                self.status_logger_thread = threading.Thread(
                    target=self._status_logger_thread,
                    daemon=True,
                    name=f"StatusLogger-{self.session_id or 'default'}"
                )
                self.status_logger_thread.start()

                self.logger.logger.info(f"Scan started successfully for {target} with depth {self.max_depth}")
                return True

            except Exception as e:
                self.logger.logger.error(f"Error starting scan threads: {e}")
                self.status = ScanStatus.FAILED
                self._update_session_state()
                return False

        except Exception as e:
            self.logger.logger.error(f"Error in scan startup: {e}")
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self._update_session_state()
            return False

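    # Task-queue item contract used throughout the scanner: each entry is
    #     (run_at, priority, (provider_name, target, depth))
    # PriorityQueue orders tuples lexicographically, so run_at (epoch seconds)
    # dominates and the provider priority breaks ties among tasks due at the
    # same moment. A minimal sketch of deferring a task by 60 seconds (values
    # are illustrative):
    #     self.task_queue.put((time.time() + 60, 5, ('crtsh', 'example.com', 1)))
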
    def _get_priority(self, provider_name):
        if provider_name == 'correlation':
            return 100  # Highest priority number = lowest priority (runs last)

        rate_limit = self.config.get_rate_limit(provider_name)

        # Handle edge cases
        if rate_limit <= 0:
            return 90  # Very low priority for invalid/disabled providers

        if provider_name == 'dns':
            return 1  # DNS is fastest, should run first
        elif provider_name == 'shodan':
            return 3  # Shodan is medium speed, good priority
        elif provider_name == 'crtsh':
            return 5  # crt.sh is slower, lower priority
        else:
            # For any other provider, use its rate limit as a guide
            if rate_limit >= 100:
                return 2  # High rate limit = high priority
            elif rate_limit >= 50:
                return 4  # Medium-high rate limit
            elif rate_limit >= 20:
                return 6  # Medium rate limit
            elif rate_limit >= 5:
                return 8  # Low rate limit
            else:
                return 10  # Very low rate limit = very low priority

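    # Illustrative ordering (rate limits assumed, not read from a real
    # session_config): same-timestamp tasks queued for 'dns' (1), 'shodan' (3),
    # 'crtsh' (5), and 'correlation' (100) drain in exactly that order, which
    # is why Phase 1 never has to sort providers explicitly.
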
    def _execute_scan(self, target: str, max_depth: int) -> None:
        """Phase 1: drain the provider task queue; Phase 2: correlation analysis."""
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        processed_tasks = set()

        is_ip = _is_valid_ip(target)
        initial_providers = [p for p in self._get_eligible_providers(target, is_ip, False) if not isinstance(p, CorrelationProvider)]

        for provider in initial_providers:
            provider_name = provider.get_name()
            priority = self._get_priority(provider_name)
            self.task_queue.put((time.time(), priority, (provider_name, target, 0)))
            self.total_tasks_ever_enqueued += 1

        try:
            self.status = ScanStatus.RUNNING
            self._update_session_state()

            enabled_providers = [provider.get_name() for provider in self.providers]
            self.logger.log_scan_start(target, max_depth, enabled_providers)

            node_type = NodeType.IP if is_ip else NodeType.DOMAIN
            self.graph.add_node(target, node_type)
            self._initialize_provider_states(target)
            consecutive_empty_iterations = 0
            max_empty_iterations = 50

            print("\n=== PHASE 1: Running non-correlation providers ===")
            while not self._is_stop_requested():
                queue_empty = self.task_queue.empty()
                with self.processing_lock:
                    no_active_processing = len(self.currently_processing) == 0

                if queue_empty and no_active_processing:
                    consecutive_empty_iterations += 1
                    if consecutive_empty_iterations >= max_empty_iterations:
                        break
                    time.sleep(0.1)
                    continue
                else:
                    consecutive_empty_iterations = 0

                try:
                    run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
                    if provider_name == 'correlation':
                        continue  # Correlation tasks are deferred to Phase 2
                    current_time = time.time()
                    if run_at > current_time:
                        # Not due yet: put it back and wait briefly
                        self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
                        time.sleep(min(0.5, run_at - current_time))
                        continue
                except Empty:
                    time.sleep(0.1)
                    continue

                self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
                task_tuple = (provider_name, target_item, depth)
                if task_tuple in processed_tasks or depth > max_depth:
                    self.tasks_skipped += 1
                    self.indicators_completed += 1
                    continue

                if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
                    # Defer rate-limited tasks by a minute
                    defer_until = time.time() + 60
                    self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
                    self.tasks_re_enqueued += 1
                    continue

                with self.processing_lock:
                    if self._is_stop_requested():
                        break
                    processing_key = (provider_name, target_item)
                    if processing_key in self.currently_processing:
                        self.tasks_skipped += 1
                        self.indicators_completed += 1
                        continue
                    self.currently_processing.add(processing_key)

                try:
                    self.current_depth = depth
                    self.current_indicator = target_item
                    self._update_session_state()
                    if self._is_stop_requested():
                        break

                    provider = next((p for p in self.providers if p.get_name() == provider_name), None)
                    if provider and not isinstance(provider, CorrelationProvider):
                        new_targets, _, success = self._process_provider_task(provider, target_item, depth)
                        if self._is_stop_requested():
                            break

                        if not success:
                            retry_key = (provider_name, target_item, depth)
                            self.target_retries[retry_key] += 1
                            if self.target_retries[retry_key] <= self.config.max_retries_per_target:
                                # Exponential backoff with jitter, capped at 5 minutes
                                retry_count = self.target_retries[retry_key]
                                backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
                                self.task_queue.put((time.time() + backoff_delay, priority, (provider_name, target_item, depth)))
                                self.tasks_re_enqueued += 1
                            else:
                                self.scan_failed_due_to_retries = True
                                self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded")
                        else:
                            processed_tasks.add(task_tuple)
                            self.indicators_completed += 1

                            if not self._is_stop_requested():
                                for new_target in new_targets:
                                    is_ip_new = _is_valid_ip(new_target)
                                    eligible_providers_new = [p for p in self._get_eligible_providers(new_target, is_ip_new, False) if not isinstance(p, CorrelationProvider)]
                                    for p_new in eligible_providers_new:
                                        p_name_new = p_new.get_name()
                                        new_depth = depth + 1
                                        if (p_name_new, new_target, new_depth) not in processed_tasks and new_depth <= max_depth:
                                            self.task_queue.put((time.time(), self._get_priority(p_name_new), (p_name_new, new_target, new_depth)))
                                            self.total_tasks_ever_enqueued += 1
                    else:
                        self.tasks_skipped += 1
                        self.indicators_completed += 1
                finally:
                    with self.processing_lock:
                        self.currently_processing.discard((provider_name, target_item))

            # This code runs after the main loop finishes or is stopped.
            self.status = ScanStatus.FINALIZING
            self._update_session_state()
            self.logger.logger.info("Scan stopped or completed. Entering finalization phase.")

            if self.status in [ScanStatus.FINALIZING, ScanStatus.COMPLETED, ScanStatus.STOPPED]:
                print("\n=== PHASE 2: Running correlation analysis ===")
                self._run_correlation_phase(max_depth, processed_tasks)
                self._update_session_state()

            # Determine the final status *after* finalization.
            if self._is_stop_requested():
                self.status = ScanStatus.STOPPED
            elif self.scan_failed_due_to_retries:
                self.status = ScanStatus.FAILED
            else:
                self.status = ScanStatus.COMPLETED

        except Exception as e:
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self.logger.logger.error(f"Scan failed: {e}")
        finally:
            # The 'finally' block is only for guaranteed cleanup.
            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []

            while not self.task_queue.empty():
                try:
                    self.task_queue.get_nowait()
                except Empty:
                    break

            self.status_logger_stop_event.set()
            if self.status_logger_thread and self.status_logger_thread.is_alive():
                self.status_logger_thread.join(timeout=2.0)

            # The executor shutdown happens *after* the correlation phase has run.
            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception as e:
                    self.logger.logger.warning(f"Error shutting down executor: {e}")
                finally:
                    self.executor = None

            self._update_session_state()
            self.logger.log_scan_complete()

    def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
        """
        PHASE 2: Run correlation analysis on all discovered nodes, with
        error handling and progress tracking.
        """
        correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
        if not correlation_provider:
            print("No correlation provider found - skipping correlation phase")
            return

        # Ensure the correlation provider sees the current graph state
        correlation_provider.set_graph_manager(self.graph)
        print(f"Correlation provider configured with graph containing {self.graph.get_node_count()} nodes")

        # Get all nodes from the graph for correlation analysis
        all_nodes = list(self.graph.graph.nodes())
        correlation_tasks = []
        correlation_tasks_enqueued = 0

        print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")

        for node_id in all_nodes:
            # Correlation runs at a fixed depth of 0 for simplicity
            correlation_depth = 0
            task_tuple = ('correlation', node_id, correlation_depth)

            # Don't re-process already processed correlation tasks
            if task_tuple not in processed_tasks:
                priority = self._get_priority('correlation')
                self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
                correlation_tasks.append(task_tuple)
                correlation_tasks_enqueued += 1
                self.total_tasks_ever_enqueued += 1

        print(f"Enqueued {correlation_tasks_enqueued} new correlation tasks")

        # Force a session state update to reflect the new task count
        self._update_session_state()

        # Process correlation tasks with progress tracking
        consecutive_empty_iterations = 0
        max_empty_iterations = 20
        correlation_completed = 0
        correlation_errors = 0

        while correlation_tasks:
            # Check if we should continue processing
            queue_empty = self.task_queue.empty()
            with self.processing_lock:
                no_active_processing = len(self.currently_processing) == 0

            if queue_empty and no_active_processing:
                consecutive_empty_iterations += 1
                if consecutive_empty_iterations >= max_empty_iterations:
                    print(f"Correlation phase timeout - {len(correlation_tasks)} tasks remaining")
                    break
                time.sleep(0.1)
                continue
            else:
                consecutive_empty_iterations = 0

            try:
                run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)

                # Only process correlation tasks in this phase
                if provider_name != 'correlation':
                    continue
            except Empty:
                time.sleep(0.1)
                continue

            task_tuple = (provider_name, target_item, depth)

            # Skip if already processed
            if task_tuple in processed_tasks:
                self.tasks_skipped += 1
                self.indicators_completed += 1
                if task_tuple in correlation_tasks:
                    correlation_tasks.remove(task_tuple)
                continue

            with self.processing_lock:
                processing_key = (provider_name, target_item)
                if processing_key in self.currently_processing:
                    self.tasks_skipped += 1
                    self.indicators_completed += 1
                    continue
                self.currently_processing.add(processing_key)

            try:
                self.current_indicator = target_item
                self._update_session_state()

                # Process the correlation task with error handling
                try:
                    new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)

                    if success:
                        processed_tasks.add(task_tuple)
                        correlation_completed += 1
                        self.indicators_completed += 1
                        if task_tuple in correlation_tasks:
                            correlation_tasks.remove(task_tuple)
                    else:
                        # Correlations are not retried - just mark as completed
                        correlation_errors += 1
                        self.indicators_completed += 1
                        if task_tuple in correlation_tasks:
                            correlation_tasks.remove(task_tuple)

                except Exception as e:
                    correlation_errors += 1
                    self.indicators_completed += 1
                    if task_tuple in correlation_tasks:
                        correlation_tasks.remove(task_tuple)
                    self.logger.logger.warning(f"Correlation task failed for {target_item}: {e}")

            finally:
                with self.processing_lock:
                    processing_key = (provider_name, target_item)
                    self.currently_processing.discard(processing_key)

            # Periodic progress update during the correlation phase
            if correlation_completed % 10 == 0 and correlation_completed > 0:
                remaining = len(correlation_tasks)
                print(f"Correlation progress: {correlation_completed} completed, {remaining} remaining")

        print("Correlation phase complete:")
        print(f"  - Successfully processed: {correlation_completed}")
        print(f"  - Errors encountered: {correlation_errors}")
        print(f"  - Tasks remaining: {len(correlation_tasks)}")

    def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
        """
        Manage the entire query-and-ingest process for a given target and
        provider. Generalized to handle all relationship types dynamically.
        """
        if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
            return set(), set(), False

        is_ip = _is_valid_ip(target)
        target_type = NodeType.IP if is_ip else NodeType.DOMAIN

        self.graph.add_node(target, target_type)
        self._initialize_provider_states(target)

        new_targets = set()
        provider_successful = True

        try:
            provider_result = self._execute_provider_query(provider, target, is_ip)

            if provider_result is None:
                provider_successful = False
            # Allow the correlation provider to process results even after a stop
            elif not self._is_stop_requested() or isinstance(provider, CorrelationProvider):
                # Pass all relationships to be processed
                discovered, is_large_entity = self._process_provider_result_unified(
                    target, provider, provider_result, depth
                )
                new_targets.update(discovered)

        except Exception as e:
            provider_successful = False
            self._log_provider_error(target, provider.get_name(), str(e))

        return new_targets, set(), provider_successful

    def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
        """The worker function that queries the provider directly for data."""
        provider_name = provider.get_name()
        start_time = datetime.now(timezone.utc)

        if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
            return None

        try:
            if is_ip:
                result = provider.query_ip(target)
            else:
                result = provider.query_domain(target)

            if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
                return None

            relationship_count = result.get_relationship_count() if result else 0
            self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time)

            return result

        except Exception as e:
            self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
            return None

    def _create_large_entity_from_result(self, source_node: str, provider_name: str,
                                         provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
        """
        Create a large entity node, tag all member nodes, and store the
        original relationships so they can be restored during extraction.
        """
        members = {rel.target_node for rel in provider_result.relationships
                   if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}

        if not members:
            return "", set()

        large_entity_id = f"le_{provider_name}_{source_node}"

        # Store the original relationships for each member
        member_relationships = {}
        for rel in provider_result.relationships:
            if rel.target_node in members:
                if rel.target_node not in member_relationships:
                    member_relationships[rel.target_node] = []
                member_relationships[rel.target_node].append({
                    'source_node': rel.source_node,
                    'target_node': rel.target_node,
                    'relationship_type': rel.relationship_type,
                    'provider': rel.provider,
                    'raw_data': rel.raw_data
                })

        self.graph.add_node(
            node_id=large_entity_id,
            node_type=NodeType.LARGE_ENTITY,
            attributes=[
                {"name": "count", "value": len(members), "type": "statistic"},
                {"name": "source_provider", "value": provider_name, "type": "metadata"},
                {"name": "discovery_depth", "value": depth, "type": "metadata"},
                {"name": "nodes", "value": list(members), "type": "metadata"},
                {"name": "original_relationships", "value": member_relationships, "type": "metadata"},
            ],
            description=f"A collection of {len(members)} nodes discovered from {source_node} via {provider_name}."
        )

        for member_id in members:
            node_type = NodeType.IP if _is_valid_ip(member_id) else NodeType.DOMAIN
            self.graph.add_node(
                node_id=member_id,
                node_type=node_type,
                metadata={'large_entity_id': large_entity_id}
            )

        return large_entity_id, members

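    # Shape of the 'original_relationships' attribute stored above (values,
    # including the relationship type, are illustrative):
    #     {'a.example.com': [{'source_node': 'example.com',
    #                         'target_node': 'a.example.com',
    #                         'relationship_type': 'crtsh_san',
    #                         'provider': 'crtsh',
    #                         'raw_data': {...}}]}
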
    def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
        """
        Remove a node from a large entity and restore its original
        relationships so the node becomes reachable again.
        """
        if not self.graph.graph.has_node(node_id):
            return False

        node_data = self.graph.graph.nodes[node_id]
        metadata = node_data.get('metadata', {})

        if metadata.get('large_entity_id') != large_entity_id:
            return False

        # Remove the large entity tag
        del metadata['large_entity_id']
        self.graph.add_node(node_id, NodeType(node_data['type']), metadata=metadata)

        # Restore the original relationships if they exist
        if self.graph.graph.has_node(large_entity_id):
            le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
            original_relationships_attr = next((a for a in le_attrs if a['name'] == 'original_relationships'), None)

            if original_relationships_attr and node_id in original_relationships_attr['value']:
                # Restore all original relationships for this node
                for rel_data in original_relationships_attr['value'][node_id]:
                    self.graph.add_edge(
                        source_id=rel_data['source_node'],
                        target_id=rel_data['target_node'],
                        relationship_type=rel_data['relationship_type'],
                        source_provider=rel_data['provider'],
                        raw_data=rel_data['raw_data']
                    )

                    # Ensure both endpoint nodes exist in the graph
                    source_type = NodeType.IP if _is_valid_ip(rel_data['source_node']) else NodeType.DOMAIN
                    target_type = NodeType.IP if _is_valid_ip(rel_data['target_node']) else NodeType.DOMAIN
                    self.graph.add_node(rel_data['source_node'], source_type)
                    self.graph.add_node(rel_data['target_node'], target_type)

                # Update the large entity to remove this node from its list
                nodes_attr = next((a for a in le_attrs if a['name'] == 'nodes'), None)
                if nodes_attr and node_id in nodes_attr['value']:
                    nodes_attr['value'].remove(node_id)

                count_attr = next((a for a in le_attrs if a['name'] == 'count'), None)
                if count_attr:
                    count_attr['value'] = max(0, count_attr['value'] - 1)

                # Remove from the original-relationships tracking
                if node_id in original_relationships_attr['value']:
                    del original_relationships_attr['value'][node_id]

        # Re-enqueue the node for full processing
        is_ip = _is_valid_ip(node_id)
        eligible_providers = self._get_eligible_providers(node_id, is_ip, False, is_extracted=True)
        for provider in eligible_providers:
            provider_name = provider.get_name()
            priority = self._get_priority(provider_name)
            # Use the discovery depth of the large entity if available, else 0
            depth = 0
            if self.graph.graph.has_node(large_entity_id):
                le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
                depth_attr = next((a for a in le_attrs if a['name'] == 'discovery_depth'), None)
                if depth_attr:
                    depth = depth_attr['value']

            self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
            self.total_tasks_ever_enqueued += 1

        return True

    def _process_provider_result_unified(self, target: str, provider: BaseProvider,
                                         provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
        """
        Process a unified ProviderResult object to update the graph,
        dynamically re-routing edges to a large entity container when the
        result exceeds the large-entity threshold.
        """
        provider_name = provider.get_name()
        discovered_targets = set()
        large_entity_id = ""
        large_entity_members = set()

        # Stop processing for non-correlation providers if requested
        if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
            return discovered_targets, False

        eligible_rel_count = sum(
            1 for rel in provider_result.relationships
            if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)
        )
        is_large_entity = eligible_rel_count > self.config.large_entity_threshold

        if is_large_entity:
            large_entity_id, large_entity_members = self._create_large_entity_from_result(
                target, provider_name, provider_result, current_depth
            )

        for i, relationship in enumerate(provider_result.relationships):
            # Re-check the stop signal every few relationships
            if i % 5 == 0 and self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
                break

            source_node_id = relationship.source_node
            target_node_id = relationship.target_node

            # Substitute the large entity ID for member nodes when drawing edges
            visual_source = large_entity_id if source_node_id in large_entity_members else source_node_id
            visual_target = large_entity_id if target_node_id in large_entity_members else target_node_id

            # Prevent self-loops on the large entity node
            if visual_source == visual_target:
                continue

            # Determine node types for the actual nodes
            source_type = NodeType.IP if _is_valid_ip(source_node_id) else NodeType.DOMAIN
            if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
                target_type = NodeType.ISP
            elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
                target_type = NodeType.CA
            elif provider_name == 'correlation':
                target_type = NodeType.CORRELATION_OBJECT
            elif _is_valid_ip(target_node_id):
                target_type = NodeType.IP
            else:
                target_type = NodeType.DOMAIN

            max_depth_reached = current_depth >= self.max_depth

            # Add the actual nodes to the graph (the UI may hide members)
            self.graph.add_node(source_node_id, source_type)
            self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached})

            # Add the visual edge to the graph
            self.graph.add_edge(
                visual_source, visual_target,
                relationship.relationship_type,
                provider_name,
                relationship.raw_data
            )

            if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
                if target_node_id not in large_entity_members:
                    discovered_targets.add(target_node_id)

        if large_entity_members:
            self.logger.logger.info(f"Enqueuing DNS and Correlation for {len(large_entity_members)} members of {large_entity_id}")
            for member in large_entity_members:
                for provider_name_to_run in ['dns', 'correlation']:
                    p_instance = next((p for p in self.providers if p.get_name() == provider_name_to_run), None)
                    if p_instance and p_instance.get_eligibility().get('domains' if _is_valid_domain(member) else 'ips'):
                        priority = self._get_priority(provider_name_to_run)
                        self.task_queue.put((time.time(), priority, (provider_name_to_run, member, current_depth)))
                        self.total_tasks_ever_enqueued += 1

        attributes_by_node = defaultdict(list)
        for attribute in provider_result.attributes:
            attr_dict = {
                "name": attribute.name, "value": attribute.value, "type": attribute.type,
                "provider": attribute.provider, "metadata": attribute.metadata
            }
            attributes_by_node[attribute.target_node].append(attr_dict)

        for node_id, node_attributes_list in attributes_by_node.items():
            if not self.graph.graph.has_node(node_id):
                node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
                self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
            else:
                existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
                self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list

        return discovered_targets, is_large_entity

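    # Re-routing example (illustrative numbers): if 'example.com' yields 600
    # valid related domains and config.large_entity_threshold is 100, the
    # graph shows a single edge example.com -> le_crtsh_example.com instead of
    # 600 edges, while the member nodes still exist and are tagged with
    # metadata['large_entity_id'] for later extraction.
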
    def stop_scan(self) -> bool:
        """Request immediate scan termination with proper cleanup."""
        try:
            self.logger.logger.info("Scan termination requested by user")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []

            self.task_queue = PriorityQueue()

            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass

            self._update_session_state()
            return True

        except Exception as e:
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False

    def _update_session_state(self) -> None:
        """Update the scanner state in Redis for GUI updates."""
        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.update_session_scanner(self.session_id, self)
            except Exception:
                pass

    def get_scan_status(self) -> Dict[str, Any]:
        """Get the current scan status with comprehensive processing information."""
        try:
            with self.processing_lock:
                currently_processing_count = len(self.currently_processing)
                currently_processing_list = list(self.currently_processing)

            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'indicators_processed': self.indicators_processed,
                'indicators_completed': self.indicators_completed,
                'tasks_re_enqueued': self.tasks_re_enqueued,
                'progress_percentage': self._calculate_progress(),
                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics(),
                'task_queue_size': self.task_queue.qsize(),
                'currently_processing_count': currently_processing_count,
                'currently_processing': currently_processing_list[:5],
                'tasks_in_queue': self.task_queue.qsize(),
                'tasks_completed': self.indicators_completed,
                'tasks_skipped': self.tasks_skipped,
                'tasks_rescheduled': self.tasks_re_enqueued,
            }
        except Exception:
            traceback.print_exc()
            return {'status': 'error', 'message': 'Failed to get status'}

    def _initialize_provider_states(self, target: str) -> None:
        """Initialize provider-state metadata on a node, tolerating errors."""
        try:
            if not self.graph.graph.has_node(target):
                return

            node_data = self.graph.graph.nodes[target]
            if 'metadata' not in node_data:
                node_data['metadata'] = {}
            if 'provider_states' not in node_data['metadata']:
                node_data['metadata']['provider_states'] = {}
        except Exception as e:
            self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")

    def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool, is_extracted: bool = False) -> List:
        """Filter providers by target type, availability, and query history."""
        if dns_only:
            return [p for p in self.providers if p.get_name() == 'dns']

        eligible = []
        target_key = 'ips' if is_ip else 'domains'

        # Check whether the target is part of a large entity
        is_in_large_entity = False
        if self.graph.graph.has_node(target) and not is_extracted:
            metadata = self.graph.graph.nodes[target].get('metadata', {})
            if 'large_entity_id' in metadata:
                is_in_large_entity = True

        for provider in self.providers:
            try:
                # Large-entity members only get the dns and correlation providers
                if is_in_large_entity and provider.get_name() not in ['dns', 'correlation']:
                    continue

                # Check if the provider supports this target type
                if not provider.get_eligibility().get(target_key, False):
                    continue

                # Check if the provider is available/configured
                if not provider.is_available():
                    continue

                # Skip providers we already queried successfully
                if not self._already_queried_provider(target, provider.get_name()):
                    eligible.append(provider)

            except Exception as e:
                self.logger.logger.warning(f"Error checking provider eligibility {provider.get_name()}: {e}")
                continue

        return eligible

    def _already_queried_provider(self, target: str, provider_name: str) -> bool:
        """Check whether a provider already queried this target successfully."""
        try:
            if not self.graph.graph.has_node(target):
                return False

            node_data = self.graph.graph.nodes[target]
            provider_states = node_data.get('metadata', {}).get('provider_states', {})
            provider_state = provider_states.get(provider_name)

            # Only count it as queried if the query succeeded with results
            return (provider_state is not None and
                    provider_state.get('status') == 'success' and
                    provider_state.get('results_count', 0) > 0)
        except Exception as e:
            self.logger.logger.warning(f"Error checking provider state for {target}:{provider_name}: {e}")
            return False

    def _update_provider_state(self, target: str, provider_name: str, status: str,
                               results_count: int, error: Optional[str], start_time: datetime) -> None:
        """Record the outcome of a provider query on the target node."""
        try:
            if not self.graph.graph.has_node(target):
                self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
                return

            node_data = self.graph.graph.nodes[target]
            if 'metadata' not in node_data:
                node_data['metadata'] = {}
            if 'provider_states' not in node_data['metadata']:
                node_data['metadata']['provider_states'] = {}

            # Calculate the duration safely
            try:
                duration_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
            except Exception:
                duration_ms = 0

            node_data['metadata']['provider_states'][provider_name] = {
                'status': status,
                'timestamp': start_time.isoformat(),
                'results_count': max(0, results_count),  # Ensure non-negative
                'error': str(error) if error else None,
                'duration_ms': duration_ms
            }

            # Update the last-modified time for forensic integrity
            self.last_modified = datetime.now(timezone.utc).isoformat()

        except Exception as e:
            self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")

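    # Resulting provider_states entry shape (values are illustrative):
    #     node['metadata']['provider_states']['dns'] = {
    #         'status': 'success',
    #         'timestamp': '2024-01-01T00:00:00+00:00',
    #         'results_count': 12,
    #         'error': None,
    #         'duration_ms': 85.3,
    #     }
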
    def _log_target_processing_error(self, target: str, error: str) -> None:
        self.logger.logger.error(f"Target processing failed for {target}: {error}")

    def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
        self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")

    def _calculate_progress(self) -> float:
        """
        Calculate scan progress, accounting for correlation tasks that are
        added during the correlation phase.
        """
        try:
            if self.total_tasks_ever_enqueued == 0:
                return 0.0

            queue_size = max(0, self.task_queue.qsize())
            with self.processing_lock:
                active_tasks = len(self.currently_processing)

            # During the correlation phase, report progress conservatively
            if self.status == ScanStatus.FINALIZING:
                base_progress = (self.indicators_completed / max(self.total_tasks_ever_enqueued, 1)) * 100

                # Cap at 95% while correlation tasks are still outstanding
                if queue_size > 0 or active_tasks > 0:
                    return min(95.0, base_progress)
                else:
                    return min(100.0, base_progress)

            # Normal-phase progress: pad the denominator with queued and
            # in-flight tasks so 100% is not shown too early
            adjusted_total = max(self.total_tasks_ever_enqueued,
                                 self.indicators_completed + queue_size + active_tasks)

            if adjusted_total == 0:
                return 100.0

            progress = (self.indicators_completed / adjusted_total) * 100
            return max(0.0, min(100.0, progress))  # Clamp to [0, 100]

        except Exception as e:
            self.logger.logger.warning(f"Error calculating progress: {e}")
            return 0.0

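    # Worked example (normal phase, illustrative numbers): 40 completed,
    # 10 queued, 5 in flight, 50 ever enqueued ->
    #     adjusted_total = max(50, 40 + 10 + 5) = 55
    #     progress = 40 / 55 * 100 ~= 72.7%
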
    def get_graph_data(self) -> Dict[str, Any]:
        graph_data = self.graph.get_graph_data()
        graph_data['initial_targets'] = list(self.initial_targets)
        return graph_data

    def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
        """Describe every discoverable provider, whether enabled or not."""
        info = {}
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
        for filename in os.listdir(provider_dir):
            if filename.endswith('_provider.py') and not filename.startswith('base'):
                module_name = f"providers.{filename[:-3]}"
                try:
                    module = importlib.import_module(module_name)
                    for attribute_name in dir(module):
                        attribute = getattr(module, attribute_name)
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                            provider_class = attribute
                            temp_provider = provider_class(name=attribute_name, session_config=self.config)
                            provider_name = temp_provider.get_name()
                            live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)
                            info[provider_name] = {
                                'display_name': temp_provider.get_display_name(),
                                'requires_api_key': temp_provider.requires_api_key(),
                                'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(),
                                'enabled': self.config.is_provider_enabled(provider_name),
                                'rate_limit': self.config.get_rate_limit(provider_name),
                            }
                except Exception:
                    traceback.print_exc()
        return info
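

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not the production entry point):
    # assumes Redis is reachable on db 0 and at least one provider (e.g. the
    # DNS provider) is enabled in the default session config.
    scanner = Scanner()
    if scanner.start_scan("example.com", max_depth=2):
        scanner.scan_thread.join()
        print(f"Final status: {scanner.get_scan_status()['status']}")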