# dnscope/core/scanner.py
import threading
import traceback
import os
import importlib
import redis
import time
import math
import random # Imported for jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from queue import PriorityQueue, Empty
from datetime import datetime, timezone
from core.graph_manager import GraphManager, NodeType
from core.logger import get_forensic_logger, new_session
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_ip, _is_valid_domain
from utils.export_manager import export_manager
from providers.base_provider import BaseProvider
from providers.correlation_provider import CorrelationProvider
from core.rate_limiter import GlobalRateLimiter
class ScanStatus:
"""Enumeration of scan statuses."""
IDLE = "idle"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
STOPPED = "stopped"
class Scanner:
"""
Main scanning orchestrator for DNSRecon passive reconnaissance.
UNIFIED: Combines comprehensive features with improved display formatting.
"""
def __init__(self, session_config=None):
"""Initialize scanner with session-specific configuration."""
try:
# Use provided session config or create default
if session_config is None:
from core.session_config import create_session_config
session_config = create_session_config()
self.config = session_config
self.graph = GraphManager()
self.providers = []
self.status = ScanStatus.IDLE
self.current_target = None
self.current_depth = 0
self.max_depth = 2
self.stop_event = threading.Event()
self.scan_thread = None
self.session_id: Optional[str] = None # Will be set by session manager
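            # Queue items are (run_at, priority, (provider_name, target, depth));
            # PriorityQueue orders on run_at first, so deferred and retried tasks
            # naturally wait behind tasks that are ready to run now.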
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
self.initial_targets = set()
            # Thread-safe tracking of in-flight (provider, target) pairs
self.currently_processing = set()
self.processing_lock = threading.Lock()
            # Snapshot of the in-flight list for display purposes
self.currently_processing_display = []
# Scanning progress tracking
self.total_indicators_found = 0
self.indicators_processed = 0
self.indicators_completed = 0
self.tasks_re_enqueued = 0
            self.tasks_skipped = 0
self.total_tasks_ever_enqueued = 0
self.current_indicator = ""
self.last_task_from_queue = None
# Concurrent processing configuration
self.max_workers = self.config.max_concurrent_requests
self.executor = None
# Status logger thread with improved formatting
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
# Initialize providers with session config
self._initialize_providers()
# Initialize logger
self.logger = get_forensic_logger()
# Initialize global rate limiter
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
except Exception as e:
print(f"ERROR: Scanner initialization failed: {e}")
traceback.print_exc()
raise
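    # Typical lifecycle (a sketch inferred from this module): the session
    # manager constructs a Scanner and assigns session_id; start_scan() spawns
    # the scan and status-logger threads; the GUI polls get_scan_status();
    # the scan ends via stop_scan() or queue exhaustion.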
def _is_stop_requested(self) -> bool:
"""
Check if stop is requested using both local and Redis-based signals.
This ensures reliable termination across process boundaries.
"""
if self.stop_event.is_set():
return True
if self.session_id:
try:
from core.session_manager import session_manager
return session_manager.is_stop_requested(self.session_id)
            except Exception:
                # Fall back to the local event if the session manager is unavailable
                return self.stop_event.is_set()
return self.stop_event.is_set()
def _set_stop_signal(self) -> None:
"""
Set stop signal both locally and in Redis.
"""
self.stop_event.set()
if self.session_id:
try:
from core.session_manager import session_manager
session_manager.set_stop_signal(self.session_id)
            except Exception:
                pass
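    # Scanner objects are handed to the session manager for persistence (see
    # _update_session_state), so they must survive pickling: threading
    # primitives, the executor, and the Redis-backed rate limiter are dropped
    # here and rebuilt in __setstate__.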
def __getstate__(self):
"""Prepare object for pickling by excluding unpicklable attributes."""
state = self.__dict__.copy()
unpicklable_attrs = [
'stop_event',
'scan_thread',
'executor',
'processing_lock',
'task_queue',
'rate_limiter',
'logger',
'status_logger_thread',
'status_logger_stop_event'
]
for attr in unpicklable_attrs:
if attr in state:
del state[attr]
if 'providers' in state:
for provider in state['providers']:
if hasattr(provider, '_stop_event'):
provider._stop_event = None
return state
def __setstate__(self, state):
"""Restore object after unpickling by reconstructing threading objects."""
self.__dict__.update(state)
self.stop_event = threading.Event()
self.scan_thread = None
self.executor = None
self.processing_lock = threading.Lock()
self.task_queue = PriorityQueue()
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
self.logger = get_forensic_logger()
self.status_logger_thread = None
self.status_logger_stop_event = threading.Event()
if not hasattr(self, 'providers') or not self.providers:
self._initialize_providers()
if not hasattr(self, 'currently_processing'):
self.currently_processing = set()
if not hasattr(self, 'currently_processing_display'):
self.currently_processing_display = []
if hasattr(self, 'providers'):
for provider in self.providers:
if hasattr(provider, 'set_stop_event'):
provider.set_stop_event(self.stop_event)
def _initialize_providers(self) -> None:
"""Initialize all available providers based on session configuration."""
self.providers = []
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}"
try:
module = importlib.import_module(module_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
provider_class = attribute
                            # Provider classes take their registry name and the session config at construction
provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = provider.get_name()
if self.config.is_provider_enabled(provider_name):
if provider.is_available():
provider.set_stop_event(self.stop_event)
if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph)
self.providers.append(provider)
                except Exception:
                    traceback.print_exc()
def _status_logger_thread(self):
"""Periodically prints a clean, formatted scan status to the terminal."""
HEADER = "\033[95m"
CYAN = "\033[96m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
ENDC = "\033[0m"
BOLD = "\033[1m"
last_status_str = ""
while not self.status_logger_stop_event.is_set():
try:
with self.processing_lock:
in_flight_tasks = list(self.currently_processing)
self.currently_processing_display = in_flight_tasks.copy()
status_str = (
f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
f"Skipped: {self.tasks_skipped} | "
f"Rescheduled: {self.tasks_re_enqueued}"
)
if status_str != last_status_str:
print(f"\n{'-'*80}")
print(status_str)
if self.last_task_from_queue:
                        # Unpack the time-based queue item: (run_at, priority, payload)
_, p, (pn, ti, d) = self.last_task_from_queue
print(f"{BLUE}Last task dequeued -> Prio:{p} | Provider:{pn} | Target:'{ti}' | Depth:{d}{ENDC}")
if in_flight_tasks:
print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}")
display_tasks = [f" - {p}: {t}" for p, t in in_flight_tasks[:3]]
print("\n".join(display_tasks))
if len(in_flight_tasks) > 3:
print(f" ... and {len(in_flight_tasks) - 3} more")
print(f"{'-'*80}")
last_status_str = status_str
except Exception:
pass
time.sleep(2)
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
if self.scan_thread and self.scan_thread.is_alive():
self.logger.logger.info("Stopping existing scan before starting new one")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# Clean up processing state
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
            # Drain the task queue
            while not self.task_queue.empty():
                try:
                    self.task_queue.get_nowait()
                except Empty:
                    break
            # Shutdown executor
            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass
                finally:
                    self.executor = None
# Wait for scan thread to finish (with timeout)
self.scan_thread.join(timeout=5.0)
if self.scan_thread.is_alive():
self.logger.logger.warning("Previous scan thread did not terminate cleanly")
self.status = ScanStatus.IDLE
self.stop_event.clear()
if self.session_id:
from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id)
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
self.target_retries.clear()
self.scan_failed_due_to_retries = False
self.tasks_skipped = 0
self.last_task_from_queue = None
self._update_session_state()
try:
if not hasattr(self, 'providers') or not self.providers:
self.logger.logger.error("No providers available for scanning")
return False
available_providers = [p for p in self.providers if p.is_available()]
if not available_providers:
self.logger.logger.error("No providers are currently available/configured")
return False
if clear_graph:
self.graph.clear()
self.initial_targets.clear()
if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
try:
node_data = self.graph.graph.nodes[force_rescan_target]
if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
self.logger.logger.info(f"Cleared provider states for forced rescan of {force_rescan_target}")
except Exception as e:
self.logger.logger.warning(f"Error clearing provider states for {force_rescan_target}: {e}")
target = target.lower().strip()
if not target:
self.logger.logger.error("Empty target provided")
return False
from utils.helpers import is_valid_target
if not is_valid_target(target):
self.logger.logger.error(f"Invalid target format: {target}")
return False
self.current_target = target
self.initial_targets.add(self.current_target)
self.max_depth = max(1, min(5, max_depth)) # Clamp depth between 1-5
self.current_depth = 0
self.total_indicators_found = 0
self.indicators_processed = 0
self.indicators_completed = 0
self.tasks_re_enqueued = 0
self.total_tasks_ever_enqueued = 0
self.current_indicator = self.current_target
self._update_session_state()
self.logger = new_session()
try:
self.scan_thread = threading.Thread(
target=self._execute_scan,
args=(self.current_target, self.max_depth),
daemon=True,
name=f"ScanThread-{self.session_id or 'default'}"
)
self.scan_thread.start()
self.status_logger_stop_event.clear()
self.status_logger_thread = threading.Thread(
target=self._status_logger_thread,
daemon=True,
name=f"StatusLogger-{self.session_id or 'default'}"
)
self.status_logger_thread.start()
self.logger.logger.info(f"Scan started successfully for {target} with depth {self.max_depth}")
return True
except Exception as e:
self.logger.logger.error(f"Error starting scan threads: {e}")
self.status = ScanStatus.FAILED
self._update_session_state()
return False
except Exception as e:
self.logger.logger.error(f"Error in scan startup: {e}")
traceback.print_exc()
self.status = ScanStatus.FAILED
self._update_session_state()
return False
def _get_priority(self, provider_name):
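        # Lower value = higher scheduling priority (run_at ties are broken by
        # this number). Built-in ordering: dns (1) -> shodan (3) -> crtsh (5)
        # -> correlation (100, always last); other providers are ranked by
        # their configured rate limit.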
if provider_name == 'correlation':
return 100 # Highest priority number = lowest priority (runs last)
rate_limit = self.config.get_rate_limit(provider_name)
# Handle edge cases
if rate_limit <= 0:
return 90 # Very low priority for invalid/disabled providers
if provider_name == 'dns':
return 1 # DNS is fastest, should run first
elif provider_name == 'shodan':
return 3 # Shodan is medium speed, good priority
elif provider_name == 'crtsh':
return 5 # crt.sh is slower, lower priority
else:
# For any other providers, use rate limit as a guide
if rate_limit >= 100:
return 2 # High rate limit = high priority
elif rate_limit >= 50:
return 4 # Medium-high rate limit = medium-high priority
elif rate_limit >= 20:
return 6 # Medium rate limit = medium priority
elif rate_limit >= 5:
return 8 # Low rate limit = low priority
else:
return 10 # Very low rate limit = very low priority
def _execute_scan(self, target: str, max_depth: int) -> None:
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        processed_tasks = set()  # Keyed by (provider, target, depth); the same target may be revisited at other depths
is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False)
for provider in initial_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
self.task_queue.put((time.time(), priority, (provider_name, target, 0)))
self.total_tasks_ever_enqueued += 1
try:
self.status = ScanStatus.RUNNING
self._update_session_state()
enabled_providers = [provider.get_name() for provider in self.providers]
self.logger.log_scan_start(target, max_depth, enabled_providers)
node_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, node_type)
self._initialize_provider_states(target)
consecutive_empty_iterations = 0
max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion
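            # Main dispatch loop: pop the earliest-ready task, push back any
            # task whose run_at lies in the future, and declare completion once
            # the queue stays empty with no in-flight work for ~5 seconds.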
while not self._is_stop_requested():
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
break # Scan is complete
time.sleep(0.1)
continue
else:
consecutive_empty_iterations = 0
                # Pull the next task; the short timeout keeps the loop responsive to stop requests
                try:
                    run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
                    # Deferred tasks carry a future run_at timestamp
                    current_time = time.time()
                    if run_at > current_time:
                        # Not ready yet: push it back and sleep briefly
                        self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
                        time.sleep(min(0.5, run_at - current_time))
                        continue
                except Empty:
                    time.sleep(0.1)
                    continue
self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
                # Processed tasks are keyed by (provider, target, depth) so a
                # target can still be re-queried at a different depth.
                task_tuple = (provider_name, target_item, depth)
                if task_tuple in processed_tasks:
                    self.tasks_skipped += 1
                    self.indicators_completed += 1
                    continue
                # Enforce the recursion limit
                if depth > max_depth:
                    self.tasks_skipped += 1
                    self.indicators_completed += 1
                    continue
                # Rate-limited tasks are deferred, not dropped: re-queue with a
                # run_at timestamp 60 seconds in the future.
                if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
                    defer_until = time.time() + 60
                    self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
                    self.tasks_re_enqueued += 1
                    continue
                # Claim the (provider, target) pair under the lock so two
                # iterations never process the same combination concurrently.
                with self.processing_lock:
                    if self._is_stop_requested():
                        break
                    # Depth is deliberately excluded from the in-flight key
                    processing_key = (provider_name, target_item)
                    if processing_key in self.currently_processing:
                        self.tasks_skipped += 1
                        self.indicators_completed += 1
                        continue
                    self.currently_processing.add(processing_key)
try:
self.current_depth = depth
self.current_indicator = target_item
self._update_session_state()
if self._is_stop_requested():
break
provider = next((p for p in self.providers if p.get_name() == provider_name), None)
if provider:
new_targets, _, success = self._process_provider_task(provider, target_item, depth)
if self._is_stop_requested():
break
                        if not success:
                            # Retries are keyed by (provider, target, depth)
                            retry_key = (provider_name, target_item, depth)
                            self.target_retries[retry_key] += 1
                            if self.target_retries[retry_key] <= self.config.max_retries_per_target:
                                # Exponential backoff with jitter, capped at 5 minutes
                                retry_count = self.target_retries[retry_key]
                                backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
                                retry_at = time.time() + backoff_delay
                                self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
                                self.tasks_re_enqueued += 1
                                self.logger.logger.debug(f"Retrying {provider_name}:{target_item} in {backoff_delay:.1f}s (attempt {retry_count})")
                            else:
                                self.scan_failed_due_to_retries = True
                                self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded")
else:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
                            # Fan out: enqueue each newly discovered target for
                            # every eligible provider at depth + 1.
                            if not self._is_stop_requested():
                                for new_target in new_targets:
                                    is_ip_new = _is_valid_ip(new_target)
                                    eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
                                    for p_new in eligible_providers_new:
                                        p_name_new = p_new.get_name()
                                        new_depth = depth + 1
                                        new_task_tuple = (p_name_new, new_target, new_depth)
                                        # Skip anything already processed or beyond max_depth
                                        if new_task_tuple not in processed_tasks and new_depth <= max_depth:
                                            new_priority = self._get_priority(p_name_new)
                                            # run_at = now, so discoveries are eligible immediately
                                            self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
                                            self.total_tasks_ever_enqueued += 1
else:
self.logger.logger.warning(f"Provider {provider_name} not found in active providers")
self.tasks_skipped += 1
self.indicators_completed += 1
                finally:
                    # Always release the in-flight claim, even on error
                    with self.processing_lock:
                        processing_key = (provider_name, target_item)
                        self.currently_processing.discard(processing_key)
except Exception as e:
traceback.print_exc()
self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}")
        finally:
            # Final cleanup: clear in-flight state and drain the queue so a
            # stopped scan does not leak pending tasks.
            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []
            while not self.task_queue.empty():
                try:
                    self.task_queue.get_nowait()
                except Empty:
                    break
if self._is_stop_requested():
self.status = ScanStatus.STOPPED
elif self.scan_failed_due_to_retries:
self.status = ScanStatus.FAILED
else:
self.status = ScanStatus.COMPLETED
self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0) # Don't wait forever
self._update_session_state()
self.logger.log_scan_complete()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception as e:
self.logger.logger.warning(f"Error shutting down executor: {e}")
finally:
self.executor = None
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
"""
Manages the entire process for a given target and provider.
It uses the "worker" function to get the data and then manages the consequences.
"""
if self._is_stop_requested():
return set(), set(), False
is_ip = _is_valid_ip(target)
target_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, target_type)
self._initialize_provider_states(target)
new_targets = set()
large_entity_members = set()
provider_successful = True
try:
provider_result = self._execute_provider_query(provider, target, is_ip)
if provider_result is None:
provider_successful = False
elif not self._is_stop_requested():
discovered, is_large_entity = self._process_provider_result_unified(
target, provider, provider_result, depth
)
if is_large_entity:
large_entity_members.update(discovered)
else:
new_targets.update(discovered)
# After processing a provider, queue a correlation task for the target
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
if correlation_provider and not isinstance(provider, CorrelationProvider):
priority = self._get_priority(correlation_provider.get_name())
self.task_queue.put((time.time(), priority, (correlation_provider.get_name(), target, depth)))
                # Count the correlation task toward total work for progress reporting
                self.total_tasks_ever_enqueued += 1
except Exception as e:
provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e))
return new_targets, large_entity_members, provider_successful
def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
"""
The "worker" function that directly communicates with the provider to fetch data.
"""
provider_name = provider.get_name()
start_time = datetime.now(timezone.utc)
if self._is_stop_requested():
return None
try:
if is_ip:
result = provider.query_ip(target)
else:
result = provider.query_domain(target)
if self._is_stop_requested():
return None
relationship_count = result.get_relationship_count() if result else 0
self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time)
return result
except Exception as e:
self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
return None
def _process_provider_result_unified(self, target: str, provider: BaseProvider,
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
"""
Process a unified ProviderResult object to update the graph.
VERIFIED: Proper ISP and CA node type assignment.
"""
provider_name = provider.get_name()
discovered_targets = set()
if self._is_stop_requested():
return discovered_targets, False
# Check if this should be a large entity
if provider_result.get_relationship_count() > self.config.large_entity_threshold:
members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth)
return members, True
# Process relationships and create nodes with proper types
for i, relationship in enumerate(provider_result.relationships):
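            # Re-check the stop signal every 5 relationships so large result
            # sets can still be interrupted promptly.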
if i % 5 == 0 and self._is_stop_requested():
break
source_node = relationship.source_node
target_node = relationship.target_node
            # Determine source node type
            source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN
            # Determine target node type from the provider and relationship
            if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
                target_type = NodeType.ISP  # ISP node for Shodan organization data
            elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
                target_type = NodeType.CA  # CA node for certificate issuers
            elif provider_name == 'correlation':
                target_type = NodeType.CORRELATION_OBJECT
            elif _is_valid_ip(target_node):
                target_type = NodeType.IP
            else:
                target_type = NodeType.DOMAIN
# Add max_depth_reached flag
max_depth_reached = current_depth >= self.max_depth
# Create or update nodes with proper types
self.graph.add_node(source_node, source_type)
self.graph.add_node(target_node, target_type, metadata={'max_depth_reached': max_depth_reached})
            # Add the relationship edge
            self.graph.add_edge(
                source_node, target_node,
                relationship.relationship_type,
                relationship.confidence,
                provider_name,
                relationship.raw_data
            )
# Add target to discovered nodes for further processing
if (_is_valid_domain(target_node) or _is_valid_ip(target_node)) and not max_depth_reached:
discovered_targets.add(target_node)
# Process all attributes, grouping by target node
attributes_by_node = defaultdict(list)
for attribute in provider_result.attributes:
attr_dict = {
"name": attribute.name,
"value": attribute.value,
"type": attribute.type,
"provider": attribute.provider,
"confidence": attribute.confidence,
"metadata": attribute.metadata
}
attributes_by_node[attribute.target_node].append(attr_dict)
# Add attributes to existing nodes OR create new nodes if they don't exist
for node_id, node_attributes_list in attributes_by_node.items():
if not self.graph.graph.has_node(node_id):
# If the node doesn't exist, create it with a default type
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
else:
# If the node already exists, just add the attributes
node_type_val = self.graph.graph.nodes[node_id].get('type', 'domain')
self.graph.add_node(node_id, NodeType(node_type_val), attributes=node_attributes_list)
return discovered_targets, False
def _create_large_entity_from_provider_result(self, source: str, provider_name: str,
provider_result: ProviderResult, current_depth: int) -> Set[str]:
"""
Create a large entity node from a ProviderResult.
"""
entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}"
targets = [rel.target_node for rel in provider_result.relationships]
node_type = 'unknown'
if targets:
if _is_valid_domain(targets[0]):
node_type = 'domain'
elif _is_valid_ip(targets[0]):
node_type = 'ip'
for target in targets:
target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP
self.graph.add_node(target, target_node_type)
attributes_dict = {
'count': len(targets),
'nodes': targets,
'node_type': node_type,
'source_provider': provider_name,
'discovery_depth': current_depth,
'threshold_exceeded': self.config.large_entity_threshold,
}
attributes_list = []
for key, value in attributes_dict.items():
attributes_list.append({
"name": key,
"value": value,
"type": "large_entity_info",
"provider": provider_name,
"confidence": 0.9,
"metadata": {}
})
description = f'Large entity created due to {len(targets)} relationships from {provider_name}'
self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description)
if provider_result.relationships:
rel_type = provider_result.relationships[0].relationship_type
self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
{'large_entity_info': f'Contains {len(targets)} {node_type}s'})
self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")
return set(targets)
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
try:
self.logger.logger.info("Scan termination requested by user")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
self._update_session_state()
return True
except Exception as e:
self.logger.logger.error(f"Error during scan termination: {e}")
traceback.print_exc()
return False
def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
"""
Extracts a node from a large entity and re-queues it for scanning.
"""
if not self.graph.graph.has_node(large_entity_id):
return False
predecessors = list(self.graph.graph.predecessors(large_entity_id))
if not predecessors:
return False
source_node_id = predecessors[0]
original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id)
if not original_edge_data:
return False
success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract)
if not success:
return False
self.graph.add_edge(
source_id=source_node_id,
target_id=node_id_to_extract,
relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'),
confidence_score=original_edge_data.get('confidence_score', 0.85),
source_provider=original_edge_data.get('source_provider', 'unknown'),
raw_data={'context': f'Extracted from large entity {large_entity_id}'}
)
is_ip = _is_valid_ip(node_id_to_extract)
large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', [])
discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None)
current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0
eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False)
for provider in eligible_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
self.task_queue.put((time.time(), priority, (provider_name, node_id_to_extract, current_depth)))
self.total_tasks_ever_enqueued += 1
if self.status != ScanStatus.RUNNING:
self.status = ScanStatus.RUNNING
self._update_session_state()
if not self.scan_thread or not self.scan_thread.is_alive():
self.scan_thread = threading.Thread(
target=self._execute_scan,
args=(self.current_target, self.max_depth),
daemon=True
)
self.scan_thread.start()
return True
def _update_session_state(self) -> None:
"""
Update the scanner state in Redis for GUI updates.
"""
if self.session_id:
try:
from core.session_manager import session_manager
session_manager.update_session_scanner(self.session_id, self)
except Exception:
pass
def get_scan_status(self) -> Dict[str, Any]:
"""Get current scan status with comprehensive processing information."""
try:
with self.processing_lock:
currently_processing_count = len(self.currently_processing)
currently_processing_list = list(self.currently_processing)
return {
'status': self.status,
'target_domain': self.current_target,
'current_depth': self.current_depth,
'max_depth': self.max_depth,
'current_indicator': self.current_indicator,
'indicators_processed': self.indicators_processed,
'indicators_completed': self.indicators_completed,
'tasks_re_enqueued': self.tasks_re_enqueued,
'progress_percentage': self._calculate_progress(),
'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph_statistics': self.graph.get_statistics(),
'task_queue_size': self.task_queue.qsize(),
'currently_processing_count': currently_processing_count,
'currently_processing': currently_processing_list[:5],
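                # The task_* aliases below duplicate fields above, presumably
                # for backward compatibility with older GUI field names.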
'tasks_in_queue': self.task_queue.qsize(),
'tasks_completed': self.indicators_completed,
'tasks_skipped': self.tasks_skipped,
'tasks_rescheduled': self.tasks_re_enqueued,
}
except Exception:
traceback.print_exc()
return { 'status': 'error', 'message': 'Failed to get status' }
def _initialize_provider_states(self, target: str) -> None:
"""
FIXED: Safer provider state initialization with error handling.
"""
try:
if not self.graph.graph.has_node(target):
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
except Exception as e:
self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")
def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
"""
FIXED: Improved provider eligibility checking with better filtering.
"""
if dns_only:
return [p for p in self.providers if p.get_name() == 'dns']
eligible = []
target_key = 'ips' if is_ip else 'domains'
for provider in self.providers:
try:
# Check if provider supports this target type
if not provider.get_eligibility().get(target_key, False):
continue
# Check if provider is available/configured
if not provider.is_available():
continue
# Check if we already successfully queried this provider
if not self._already_queried_provider(target, provider.get_name()):
eligible.append(provider)
except Exception as e:
self.logger.logger.warning(f"Error checking provider eligibility {provider.get_name()}: {e}")
continue
return eligible
def _already_queried_provider(self, target: str, provider_name: str) -> bool:
"""
FIXED: More robust check for already queried providers with proper error handling.
"""
try:
if not self.graph.graph.has_node(target):
return False
node_data = self.graph.graph.nodes[target]
provider_states = node_data.get('metadata', {}).get('provider_states', {})
provider_state = provider_states.get(provider_name)
# Only consider it already queried if it was successful
return (provider_state is not None and
provider_state.get('status') == 'success' and
provider_state.get('results_count', 0) > 0)
except Exception as e:
self.logger.logger.warning(f"Error checking provider state for {target}:{provider_name}: {e}")
return False
def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: Optional[str], start_time: datetime) -> None:
"""
FIXED: More robust provider state updates with validation.
"""
try:
if not self.graph.graph.has_node(target):
self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
# Calculate duration safely
try:
duration_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
except Exception:
duration_ms = 0
node_data['metadata']['provider_states'][provider_name] = {
'status': status,
'timestamp': start_time.isoformat(),
'results_count': max(0, results_count), # Ensure non-negative
'error': str(error) if error else None,
'duration_ms': duration_ms
}
# Update last modified time for forensic integrity
self.last_modified = datetime.now(timezone.utc).isoformat()
except Exception as e:
self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
def _log_target_processing_error(self, target: str, error: str) -> None:
self.logger.logger.error(f"Target processing failed for {target}: {error}")
def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
def _calculate_progress(self) -> float:
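        # Worked example: 10 tasks ever enqueued, 6 completed, 2 still queued,
        # 1 in flight -> adjusted_total = max(10, 6 + 2 + 1) = 10 -> 60.0%.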
try:
if self.total_tasks_ever_enqueued == 0:
return 0.0
# Add small buffer for tasks still in queue to avoid showing 100% too early
queue_size = max(0, self.task_queue.qsize())
with self.processing_lock:
active_tasks = len(self.currently_processing)
# Adjust total to account for remaining work
adjusted_total = max(self.total_tasks_ever_enqueued,
self.indicators_completed + queue_size + active_tasks)
if adjusted_total == 0:
return 100.0
progress = (self.indicators_completed / adjusted_total) * 100
return max(0.0, min(100.0, progress)) # Clamp between 0 and 100
except Exception as e:
self.logger.logger.warning(f"Error calculating progress: {e}")
return 0.0
def get_graph_data(self) -> Dict[str, Any]:
graph_data = self.graph.get_graph_data()
graph_data['initial_targets'] = list(self.initial_targets)
return graph_data
def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
info = {}
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}"
try:
module = importlib.import_module(module_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
provider_class = attribute
temp_provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = temp_provider.get_name()
live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)
info[provider_name] = {
'display_name': temp_provider.get_display_name(),
'requires_api_key': temp_provider.requires_api_key(),
'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(),
'enabled': self.config.is_provider_enabled(provider_name),
'rate_limit': self.config.get_rate_limit(provider_name),
}
except Exception:
traceback.print_exc()
return info