# dnsrecon-reduced/core/scanner.py

import threading
import traceback
import os
import importlib
import redis
import time
import random  # Imported for jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from queue import PriorityQueue
from datetime import datetime, timezone

from core.graph_manager import GraphManager, NodeType
from core.logger import get_forensic_logger, new_session
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.base_provider import BaseProvider
from core.rate_limiter import GlobalRateLimiter

class ScanStatus:
    """Enumeration of scan statuses."""
    IDLE = "idle"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"

class Scanner:
    """
    Main scanning orchestrator for DNSRecon passive reconnaissance.

    UNIFIED: Combines comprehensive features with improved display formatting.
    """

    def __init__(self, session_config=None):
        """Initialize scanner with session-specific configuration."""
        try:
            # Use provided session config or create default
            if session_config is None:
                from core.session_config import create_session_config
                session_config = create_session_config()

            self.config = session_config
            self.graph = GraphManager()
            self.providers = []
            self.status = ScanStatus.IDLE
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()
            self.scan_thread = None
            self.session_id: Optional[str] = None  # Will be set by session manager
            self.task_queue = PriorityQueue()
            self.target_retries = defaultdict(int)
            self.scan_failed_due_to_retries = False
            self.initial_targets = set()

            # Thread-safe tracking of in-flight (provider, target) tasks
            self.currently_processing = set()
            self.processing_lock = threading.Lock()
            # Display-friendly snapshot of in-flight tasks (refreshed by the status logger thread)
            self.currently_processing_display = []

            # Scanning progress tracking
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.indicators_completed = 0
            self.tasks_re_enqueued = 0
            self.tasks_skipped = 0  # BUGFIX: Initialize tasks_skipped
            self.total_tasks_ever_enqueued = 0
            self.current_indicator = ""
            self.last_task_from_queue = None

            # Concurrent processing configuration
            self.max_workers = self.config.max_concurrent_requests
            self.executor = None

            # Status logger thread with improved formatting
            self.status_logger_thread = None
            self.status_logger_stop_event = threading.Event()

            # Initialize providers with session config
            self._initialize_providers()

            # Initialize logger
            self.logger = get_forensic_logger()

            # Initialize global rate limiter
            self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))

        except Exception as e:
            print(f"ERROR: Scanner initialization failed: {e}")
            traceback.print_exc()
            raise

    def _is_stop_requested(self) -> bool:
        """
        Check if stop is requested using both local and Redis-based signals.

        This ensures reliable termination across process boundaries.
        """
        if self.stop_event.is_set():
            return True

        if self.session_id:
            try:
                from core.session_manager import session_manager
                return session_manager.is_stop_requested(self.session_id)
            except Exception:
                # Fall back to local event
                return self.stop_event.is_set()

        return False

    def _set_stop_signal(self) -> None:
        """
        Set stop signal both locally and in Redis.
        """
        self.stop_event.set()

        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.set_stop_signal(self.session_id)
            except Exception:
                pass
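
    # Pickling support: the scanner object appears to be persisted through the session manager
    # (see _update_session_state), so thread primitives, the executor, the task queue, the logger,
    # and Redis handles are dropped on pickle and rebuilt on unpickle.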
    def __getstate__(self):
        """Prepare object for pickling by excluding unpicklable attributes."""
        state = self.__dict__.copy()

        unpicklable_attrs = [
            'stop_event',
            'scan_thread',
            'executor',
            'processing_lock',
            'task_queue',
            'rate_limiter',
            'logger',
            'status_logger_thread',
            'status_logger_stop_event'
        ]

        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]

        if 'providers' in state:
            for provider in state['providers']:
                if hasattr(provider, '_stop_event'):
                    provider._stop_event = None

        return state

    def __setstate__(self, state):
        """Restore object after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)

        self.stop_event = threading.Event()
        self.scan_thread = None
        self.executor = None
        self.processing_lock = threading.Lock()
        self.task_queue = PriorityQueue()
        self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
        self.logger = get_forensic_logger()
        self.status_logger_thread = None
        self.status_logger_stop_event = threading.Event()

        if not hasattr(self, 'providers') or not self.providers:
            self._initialize_providers()

        if not hasattr(self, 'currently_processing'):
            self.currently_processing = set()

        if not hasattr(self, 'currently_processing_display'):
            self.currently_processing_display = []

        if hasattr(self, 'providers'):
            for provider in self.providers:
                if hasattr(provider, 'set_stop_event'):
                    provider.set_stop_event(self.stop_event)

    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration."""
        self.providers = []
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
        # Discover provider modules by filename convention: providers/*_provider.py
        for filename in os.listdir(provider_dir):
            if filename.endswith('_provider.py') and not filename.startswith('base'):
                module_name = f"providers.{filename[:-3]}"
                try:
                    module = importlib.import_module(module_name)
                    for attribute_name in dir(module):
                        attribute = getattr(module, attribute_name)
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                            provider_class = attribute
                            provider = provider_class(name=attribute_name, session_config=self.config)
                            provider_name = provider.get_name()

                            if self.config.is_provider_enabled(provider_name):
                                if provider.is_available():
                                    provider.set_stop_event(self.stop_event)
                                    self.providers.append(provider)
                except Exception:
                    traceback.print_exc()

    def _status_logger_thread(self):
        """Periodically prints a clean, formatted scan status to the terminal."""
        # ANSI escape codes for colored terminal output
        HEADER = "\033[95m"
        CYAN = "\033[96m"
        GREEN = "\033[92m"
        YELLOW = "\033[93m"
        BLUE = "\033[94m"
        ENDC = "\033[0m"
        BOLD = "\033[1m"

        last_status_str = ""
        while not self.status_logger_stop_event.is_set():
            try:
                with self.processing_lock:
                    in_flight_tasks = list(self.currently_processing)
                    self.currently_processing_display = in_flight_tasks.copy()

                status_str = (
                    f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
                    f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
                    f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
                    f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
                    f"Skipped: {self.tasks_skipped} | "
                    f"Rescheduled: {self.tasks_re_enqueued}"
                )

                if status_str != last_status_str:
                    print(f"\n{'-'*80}")
                    print(status_str)
                    if self.last_task_from_queue:
                        # Unpack the time-based queue item: (run_at, priority, (provider, target, depth))
                        _, p, (pn, ti, d) = self.last_task_from_queue
                        print(f"{BLUE}Last task dequeued -> Prio:{p} | Provider:{pn} | Target:'{ti}' | Depth:{d}{ENDC}")
                    if in_flight_tasks:
                        print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}")
                        display_tasks = [f" - {p}: {t}" for p, t in in_flight_tasks[:3]]
                        print("\n".join(display_tasks))
                        if len(in_flight_tasks) > 3:
                            print(f" ... and {len(in_flight_tasks) - 3} more")
                    print(f"{'-'*80}")
                    last_status_str = status_str
            except Exception:
                pass

            time.sleep(2)

    def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
        """
        Starts a new reconnaissance scan.
        """
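        # If a previous scan thread is still alive, signal it to stop, discard its queue and
        # executor, and give it a few seconds to exit before resetting state for the new scan.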
        if self.scan_thread and self.scan_thread.is_alive():
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED
            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []
            self.task_queue = PriorityQueue()
            if self.executor:
                self.executor.shutdown(wait=False, cancel_futures=True)
                self.executor = None
            self.scan_thread.join(5.0)

        self.status = ScanStatus.IDLE
        self.stop_event.clear()

        if self.session_id:
            from core.session_manager import session_manager
            session_manager.clear_stop_signal(self.session_id)

        with self.processing_lock:
            self.currently_processing.clear()
            self.currently_processing_display = []

        self.task_queue = PriorityQueue()
        self.target_retries.clear()
        self.scan_failed_due_to_retries = False
        self.tasks_skipped = 0
        self.last_task_from_queue = None

        self._update_session_state()

        try:
            if not hasattr(self, 'providers') or not self.providers:
                return False

            if clear_graph:
                self.graph.clear()
                self.initial_targets.clear()

            if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
                node_data = self.graph.graph.nodes[force_rescan_target]
                if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
                    node_data['metadata']['provider_states'] = {}

            self.current_target = target.lower().strip()
            self.initial_targets.add(self.current_target)
            self.max_depth = max_depth
            self.current_depth = 0

            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.indicators_completed = 0
            self.tasks_re_enqueued = 0
            self.total_tasks_ever_enqueued = 0
            self.current_indicator = self.current_target

            self._update_session_state()
            self.logger = new_session()

            self.scan_thread = threading.Thread(
                target=self._execute_scan,
                args=(self.current_target, max_depth),
                daemon=True
            )
            self.scan_thread.start()

            self.status_logger_stop_event.clear()
            self.status_logger_thread = threading.Thread(target=self._status_logger_thread, daemon=True)
            self.status_logger_thread.start()

            return True

        except Exception:
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self._update_session_state()
            return False

    def _get_priority(self, provider_name):
        """Map a provider's configured rate limit to a queue priority (lower value runs first)."""
        rate_limit = self.config.get_rate_limit(provider_name)
        if rate_limit > 90:
            return 1  # Highest priority
        elif rate_limit > 50:
            return 2
        else:
            return 3  # Lowest priority

    def _execute_scan(self, target: str, max_depth: int) -> None:
        """
        Execute the reconnaissance scan with a time-based, robust scheduler.

        Handles rate-limiting via deferral and failures via exponential backoff.
        """
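        # Scheduler summary (describes the loop below, no additional behavior):
        #   - Queue items are (run_at, priority, (provider_name, target, depth)); PriorityQueue
        #     compares tuples element-wise, so the earliest run_at is served first and ties are
        #     broken by provider priority (1 = highest, see _get_priority).
        #   - Rate-limited tasks are re-enqueued with run_at pushed 60 seconds into the future.
        #   - Failed tasks retry with exponential backoff (~2s, 4s, 8s, ...) plus up to 1s of jitter.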
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        processed_tasks = set()

        is_ip = _is_valid_ip(target)
        initial_providers = self._get_eligible_providers(target, is_ip, False)
        for provider in initial_providers:
            provider_name = provider.get_name()
            priority = self._get_priority(provider_name)
            # OVERHAUL: Enqueue with current timestamp to run immediately
            self.task_queue.put((time.time(), priority, (provider_name, target, 0)))
            self.total_tasks_ever_enqueued += 1

        try:
            self.status = ScanStatus.RUNNING
            self._update_session_state()

            enabled_providers = [provider.get_name() for provider in self.providers]
            self.logger.log_scan_start(target, max_depth, enabled_providers)

            node_type = NodeType.IP if is_ip else NodeType.DOMAIN
            self.graph.add_node(target, node_type)
            self._initialize_provider_states(target)

            while not self._is_stop_requested():
                if self.task_queue.empty() and not self.currently_processing:
                    break  # Scan is complete

                try:
                    # OVERHAUL: Peek at the next task to see if it's ready to run
                    next_run_at, _, _ = self.task_queue.queue[0]
                    if next_run_at > time.time():
                        time.sleep(0.1)  # Sleep to prevent busy-waiting for future tasks
                        continue

                    # Task is ready, so get it from the queue
                    run_at, priority, (provider_name, target_item, depth) = self.task_queue.get()
                    self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))

                except IndexError:
                    time.sleep(0.1)  # Queue is empty, but tasks might still be processing
                    continue

                task_tuple = (provider_name, target_item)
                if task_tuple in processed_tasks:
                    self.tasks_skipped += 1
                    self.indicators_completed += 1
                    continue

                if depth > max_depth:
                    continue

                # OVERHAUL: Handle rate limiting with time-based deferral
                if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
                    defer_until = time.time() + 60  # Defer for 60 seconds
                    self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
                    self.tasks_re_enqueued += 1
                    continue

                with self.processing_lock:
                    if self._is_stop_requested():
                        break
                    self.currently_processing.add(task_tuple)

                try:
                    self.current_depth = depth
                    self.current_indicator = target_item
                    self._update_session_state()

                    if self._is_stop_requested():
                        break

                    provider = next((p for p in self.providers if p.get_name() == provider_name), None)

                    if provider:
                        new_targets, _, success = self._query_single_provider_for_target(provider, target_item, depth)

                        if self._is_stop_requested():
                            break

                        if not success:
                            self.target_retries[task_tuple] += 1
                            if self.target_retries[task_tuple] <= self.config.max_retries_per_target:
                                # OVERHAUL: Exponential backoff for retries
                                retry_count = self.target_retries[task_tuple]
                                backoff_delay = (2 ** retry_count) + random.uniform(0, 1)  # Add jitter
                                retry_at = time.time() + backoff_delay
                                self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
                                self.tasks_re_enqueued += 1
                            else:
                                self.scan_failed_due_to_retries = True
                                self._log_target_processing_error(str(task_tuple), "Max retries exceeded")
                        else:
                            processed_tasks.add(task_tuple)
                            self.indicators_completed += 1

                        if not self._is_stop_requested():
                            for new_target in new_targets:
                                is_ip_new = _is_valid_ip(new_target)
                                eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
                                for p_new in eligible_providers_new:
                                    p_name_new = p_new.get_name()
                                    if (p_name_new, new_target) not in processed_tasks:
                                        new_depth = depth + 1  # Newly discovered targets sit one hop deeper
                                        new_priority = self._get_priority(p_name_new)
                                        # OVERHAUL: Enqueue new tasks to run immediately
                                        self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
                                        self.total_tasks_ever_enqueued += 1
                finally:
                    with self.processing_lock:
                        self.currently_processing.discard(task_tuple)

        except Exception as e:
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self.logger.logger.error(f"Scan failed: {e}")
        finally:
            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []

            if self._is_stop_requested():
                self.status = ScanStatus.STOPPED
            elif self.scan_failed_due_to_retries:
                self.status = ScanStatus.FAILED
            else:
                self.status = ScanStatus.COMPLETED

            self.status_logger_stop_event.set()
            if self.status_logger_thread:
                self.status_logger_thread.join()

            self._update_session_state()
            self.logger.log_scan_complete()
            if self.executor:
                self.executor.shutdown(wait=False, cancel_futures=True)
                self.executor = None

    def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
        """
        Query a single provider and process the unified ProviderResult.
        """
        if self._is_stop_requested():
            return set(), set(), False

        is_ip = _is_valid_ip(target)
        target_type = NodeType.IP if is_ip else NodeType.DOMAIN

        self.graph.add_node(target, target_type)
        self._initialize_provider_states(target)

        new_targets = set()
        large_entity_members = set()
        provider_successful = True

        try:
            provider_result = self._query_single_provider_unified(provider, target, is_ip, depth)

            if provider_result is None:
                provider_successful = False
            elif not self._is_stop_requested():
                discovered, is_large_entity = self._process_provider_result_unified(
                    target, provider, provider_result, depth
                )
                if is_large_entity:
                    large_entity_members.update(discovered)
                else:
                    new_targets.update(discovered)
                self.graph.process_correlations_for_node(target)
        except Exception as e:
            provider_successful = False
            self._log_provider_error(target, provider.get_name(), str(e))

        return new_targets, large_entity_members, provider_successful

    def _query_single_provider_unified(self, provider: BaseProvider, target: str, is_ip: bool, current_depth: int) -> Optional[ProviderResult]:
        """
        Query a single provider with stop signal checking.
        """
        provider_name = provider.get_name()
        start_time = datetime.now(timezone.utc)

        if self._is_stop_requested():
            return None

        try:
            if is_ip:
                result = provider.query_ip(target)
            else:
                result = provider.query_domain(target)

            if self._is_stop_requested():
                return None

            relationship_count = result.get_relationship_count() if result else 0
            self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time)

            return result

        except Exception as e:
            self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
            return None

    def _process_provider_result_unified(self, target: str, provider: BaseProvider,
                                         provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
        """
        Process a unified ProviderResult object to update the graph.

        VERIFIED: Proper ISP and CA node type assignment.
        """
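        # Processing order: attach attributes to nodes that already exist in the graph, then
        # short-circuit into a single large-entity node if the relationship count exceeds the
        # configured threshold, and otherwise materialize individual nodes and edges.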
        provider_name = provider.get_name()
        discovered_targets = set()

        if self._is_stop_requested():
            return discovered_targets, False

        # Process all attributes first, grouping by target node
        attributes_by_node = defaultdict(list)
        for attribute in provider_result.attributes:
            attr_dict = {
                "name": attribute.name,
                "value": attribute.value,
                "type": attribute.type,
                "provider": attribute.provider,
                "confidence": attribute.confidence,
                "metadata": attribute.metadata
            }
            attributes_by_node[attribute.target_node].append(attr_dict)

        # Add attributes to existing nodes (important for ISP nodes to get ASN attributes)
        for node_id, node_attributes_list in attributes_by_node.items():
            if self.graph.graph.has_node(node_id):
                # Node already exists, just add attributes
                if _is_valid_ip(node_id):
                    node_type = NodeType.IP
                else:
                    node_type = NodeType.DOMAIN

                self.graph.add_node(node_id, node_type, attributes=node_attributes_list)

        # Check if this should be a large entity
        if provider_result.get_relationship_count() > self.config.large_entity_threshold:
            members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth)
            return members, True

        # Process relationships and create nodes with proper types
        for i, relationship in enumerate(provider_result.relationships):
            if i % 5 == 0 and self._is_stop_requested():
                break

            source_node = relationship.source_node
            target_node = relationship.target_node

            # VERIFIED: Determine source node type
            source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN

            # VERIFIED: Determine target node type based on provider and relationship
            if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
                target_type = NodeType.ISP  # ISP node for Shodan organization data
            elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
                target_type = NodeType.CA  # CA node for certificate issuers
            elif _is_valid_ip(target_node):
                target_type = NodeType.IP
            else:
                target_type = NodeType.DOMAIN

            # Create or update nodes with proper types
            self.graph.add_node(source_node, source_type)
            self.graph.add_node(target_node, target_type)

            # Add the relationship edge
            self.graph.add_edge(
                source_node, target_node,
                relationship.relationship_type,
                relationship.confidence,
                provider_name,
                relationship.raw_data
            )

            # Add target to discovered nodes for further processing
            if _is_valid_domain(target_node) or _is_valid_ip(target_node):
                discovered_targets.add(target_node)

        return discovered_targets, False

    def _create_large_entity_from_provider_result(self, source: str, provider_name: str,
                                                  provider_result: ProviderResult, current_depth: int) -> Set[str]:
        """
        Create a large entity node from a ProviderResult.
        """
        entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}"

        targets = [rel.target_node for rel in provider_result.relationships]
        node_type = 'unknown'

        if targets:
            if _is_valid_domain(targets[0]):
                node_type = 'domain'
            elif _is_valid_ip(targets[0]):
                node_type = 'ip'

        for target in targets:
            target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP
            self.graph.add_node(target, target_node_type)

        attributes_dict = {
            'count': len(targets),
            'nodes': targets,
            'node_type': node_type,
            'source_provider': provider_name,
            'discovery_depth': current_depth,
            'threshold_exceeded': self.config.large_entity_threshold,
        }

        attributes_list = []
        for key, value in attributes_dict.items():
            attributes_list.append({
                "name": key,
                "value": value,
                "type": "large_entity_info",
                "provider": provider_name,
                "confidence": 0.9,
                "metadata": {}
            })

        description = f'Large entity created due to {len(targets)} relationships from {provider_name}'

        self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description)

        if provider_result.relationships:
            rel_type = provider_result.relationships[0].relationship_type
            self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
                                {'large_entity_info': f'Contains {len(targets)} {node_type}s'})

        self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")

        return set(targets)

    def stop_scan(self) -> bool:
        """Request immediate scan termination with proper cleanup."""
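        # Stop propagation: set both the local event and the session-level stop signal, clear
        # in-flight bookkeeping and the queue, and cancel executor futures; the loop in
        # _execute_scan observes the signal via _is_stop_requested().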
        try:
            self.logger.logger.info("Scan termination requested by user")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            with self.processing_lock:
                self.currently_processing.clear()
                self.currently_processing_display = []

            self.task_queue = PriorityQueue()

            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass

            self._update_session_state()
            return True

        except Exception as e:
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False

    def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
        """
        Extracts a node from a large entity and re-queues it for scanning.
        """
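        # The extracted node is re-linked to the large entity's predecessor, re-queued for every
        # eligible provider at the depth recorded on the large entity, and the scan thread is
        # restarted if it has already finished.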
        if not self.graph.graph.has_node(large_entity_id):
            return False

        predecessors = list(self.graph.graph.predecessors(large_entity_id))
        if not predecessors:
            return False
        source_node_id = predecessors[0]

        original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id)
        if not original_edge_data:
            return False

        success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract)
        if not success:
            return False

        self.graph.add_edge(
            source_id=source_node_id,
            target_id=node_id_to_extract,
            relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'),
            confidence_score=original_edge_data.get('confidence_score', 0.85),
            source_provider=original_edge_data.get('source_provider', 'unknown'),
            raw_data={'context': f'Extracted from large entity {large_entity_id}'}
        )

        is_ip = _is_valid_ip(node_id_to_extract)

        large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', [])
        discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None)
        current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0

        eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False)
        for provider in eligible_providers:
            provider_name = provider.get_name()
            priority = self._get_priority(provider_name)
            self.task_queue.put((time.time(), priority, (provider_name, node_id_to_extract, current_depth)))
            self.total_tasks_ever_enqueued += 1

        if self.status != ScanStatus.RUNNING:
            self.status = ScanStatus.RUNNING
            self._update_session_state()

        if not self.scan_thread or not self.scan_thread.is_alive():
            self.scan_thread = threading.Thread(
                target=self._execute_scan,
                args=(self.current_target, self.max_depth),
                daemon=True
            )
            self.scan_thread.start()

        return True

    def _update_session_state(self) -> None:
        """
        Update the scanner state in Redis for GUI updates.
        """
        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.update_session_scanner(self.session_id, self)
            except Exception:
                pass

    def get_scan_status(self) -> Dict[str, Any]:
        """Get current scan status with comprehensive processing information."""
        try:
            with self.processing_lock:
                currently_processing_count = len(self.currently_processing)
                currently_processing_list = list(self.currently_processing)

            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'indicators_processed': self.indicators_processed,
                'indicators_completed': self.indicators_completed,
                'tasks_re_enqueued': self.tasks_re_enqueued,
                'progress_percentage': self._calculate_progress(),
                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics(),
                'task_queue_size': self.task_queue.qsize(),
                'currently_processing_count': currently_processing_count,
                'currently_processing': currently_processing_list[:5],
                'tasks_in_queue': self.task_queue.qsize(),
                'tasks_completed': self.indicators_completed,
                'tasks_skipped': self.tasks_skipped,
                'tasks_rescheduled': self.tasks_re_enqueued,
            }
        except Exception:
            traceback.print_exc()
            return {'status': 'error', 'message': 'Failed to get status'}

    def _initialize_provider_states(self, target: str) -> None:
        """Initialize provider states for forensic tracking."""
        if not self.graph.graph.has_node(target): return
        node_data = self.graph.graph.nodes[target]
        if 'metadata' not in node_data: node_data['metadata'] = {}
        if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {}

    def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
        """Get providers eligible for querying this target."""
        if dns_only:
            return [p for p in self.providers if p.get_name() == 'dns']
        eligible = []
        target_key = 'ips' if is_ip else 'domains'
        for provider in self.providers:
            if provider.get_eligibility().get(target_key):
                if not self._already_queried_provider(target, provider.get_name()):
                    eligible.append(provider)
        return eligible

    def _already_queried_provider(self, target: str, provider_name: str) -> bool:
        """Check if we already successfully queried a provider for a target."""
        if not self.graph.graph.has_node(target): return False
        node_data = self.graph.graph.nodes[target]
        provider_states = node_data.get('metadata', {}).get('provider_states', {})
        provider_state = provider_states.get(provider_name)
        return provider_state is not None and provider_state.get('status') == 'success'

    def _update_provider_state(self, target: str, provider_name: str, status: str,
                               results_count: int, error: Optional[str], start_time: datetime) -> None:
        """Update provider state in node metadata for forensic tracking."""
        if not self.graph.graph.has_node(target): return
        node_data = self.graph.graph.nodes[target]
        if 'metadata' not in node_data: node_data['metadata'] = {}
        if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {}
        node_data['metadata']['provider_states'][provider_name] = {
            'status': status,
            'timestamp': start_time.isoformat(),
            'results_count': results_count,
            'error': error,
            'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
        }

    def _log_target_processing_error(self, target: str, error: str) -> None:
        self.logger.logger.error(f"Target processing failed for {target}: {error}")

    def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
        self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")

    def _calculate_progress(self) -> float:
        if self.total_tasks_ever_enqueued == 0: return 0.0
        return min(100.0, (self.indicators_completed / self.total_tasks_ever_enqueued) * 100)

    def get_graph_data(self) -> Dict[str, Any]:
        graph_data = self.graph.get_graph_data()
        graph_data['initial_targets'] = list(self.initial_targets)
        return graph_data

    def export_results(self) -> Dict[str, Any]:
        graph_data = self.graph.export_json()
        audit_trail = self.logger.export_audit_trail()
        provider_stats = {}
        for provider in self.providers:
            provider_stats[provider.get_name()] = provider.get_statistics()

        return {
            'scan_metadata': {
                'target_domain': self.current_target, 'max_depth': self.max_depth,
                'final_status': self.status, 'total_indicators_processed': self.indicators_processed,
                'enabled_providers': list(provider_stats.keys()), 'session_id': self.session_id
            },
            'graph_data': graph_data,
            'forensic_audit': audit_trail,
            'provider_statistics': provider_stats,
            'scan_summary': self.logger.get_forensic_summary()
        }

    def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
        info = {}
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
        for filename in os.listdir(provider_dir):
            if filename.endswith('_provider.py') and not filename.startswith('base'):
                module_name = f"providers.{filename[:-3]}"
                try:
                    module = importlib.import_module(module_name)
                    for attribute_name in dir(module):
                        attribute = getattr(module, attribute_name)
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                            provider_class = attribute
                            temp_provider = provider_class(name=attribute_name, session_config=self.config)
                            provider_name = temp_provider.get_name()
                            live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)
                            info[provider_name] = {
                                'display_name': temp_provider.get_display_name(),
                                'requires_api_key': temp_provider.requires_api_key(),
                                'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(),
                                'enabled': self.config.is_provider_enabled(provider_name),
                                'rate_limit': self.config.get_rate_limit(provider_name),
                            }
                except Exception:
                    traceback.print_exc()
        return info