# DNScope-reduced/core/scanner.py
import threading
import traceback
import os
import importlib
import redis
import time
import math
import random  # Imported for jitter
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from queue import PriorityQueue, Empty
from datetime import datetime, timezone

from core.graph_manager import GraphManager, NodeType
from core.logger import get_forensic_logger, new_session
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_ip, _is_valid_domain
from utils.export_manager import export_manager
from providers.base_provider import BaseProvider
from providers.correlation_provider import CorrelationProvider
from core.rate_limiter import GlobalRateLimiter


class ScanStatus:
    """Enumeration of scan statuses."""
    IDLE = "idle"
    RUNNING = "running"
    FINALIZING = "finalizing"  # New state for post-scan analysis
    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"


class Scanner:
    """
    Main scanning orchestrator for DNScope passive reconnaissance.
    UNIFIED: Combines comprehensive features with improved display formatting.
    """

    def __init__(self, session_config=None):
        """Initialize scanner with session-specific configuration."""
        try:
            # Use provided session config or create default
            if session_config is None:
                from core.session_config import create_session_config
                session_config = create_session_config()

            self.config = session_config
            self.graph = GraphManager()
            self.providers = []
            self.status = ScanStatus.IDLE
            self.current_target = None
            self.current_depth = 0
            self.max_depth = 2
            self.stop_event = threading.Event()
            self.scan_thread = None
            self.session_id: Optional[str] = None  # Will be set by session manager
            self.task_queue = PriorityQueue()
            self.target_retries = defaultdict(int)
            self.scan_failed_due_to_retries = False
            self.initial_targets = set()

            # Thread-safe processing tracking (from Document 1)
            self.currently_processing = set()
            self.processing_lock = threading.Lock()
            # Display-friendly processing list (from Document 2)
            self.currently_processing_display = []

            # Scanning progress tracking
            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.indicators_completed = 0
            self.tasks_re_enqueued = 0
            self.tasks_skipped = 0  # BUGFIX: Initialize tasks_skipped
            self.total_tasks_ever_enqueued = 0
            self.current_indicator = ""
            self.last_task_from_queue = None

            # Concurrent processing configuration
            self.max_workers = self.config.max_concurrent_requests
            self.executor = None

            # Status logger thread with improved formatting
            self.status_logger_thread = None
            self.status_logger_stop_event = threading.Event()

            # Initialize providers with session config
            self._initialize_providers()

            # Initialize logger
            self.logger = get_forensic_logger()

            # Initialize global rate limiter
            self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))

        except Exception as e:
            print(f"ERROR: Scanner initialization failed: {e}")
            traceback.print_exc()
            raise

    def _is_stop_requested(self) -> bool:
        """
        Check if stop is requested using both the local event and the
        Redis-backed session signal. This ensures reliable termination
        across process boundaries.
        """
        if self.stop_event.is_set():
            return True

        if self.session_id:
            try:
                from core.session_manager import session_manager
                return session_manager.is_stop_requested(self.session_id)
            except Exception:
                # Redis unavailable: fall back to the local event
                return self.stop_event.is_set()

        return self.stop_event.is_set()

    def _set_stop_signal(self) -> None:
        """Set the stop signal both locally and in Redis."""
        self.stop_event.set()

        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.set_stop_signal(self.session_id)
            except Exception:
                pass
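
    # Illustrative sketch only: a session_manager can back these signals with
    # a Redis key so a stop request issued in one process is visible in
    # another. The key name and TTL below are assumptions for illustration,
    # not DNScope's actual schema:
    #
    #   def set_stop_signal(self, session_id):
    #       self.redis.set(f"scan:stop:{session_id}", "1", ex=3600)
    #
    #   def is_stop_requested(self, session_id):
    #       return self.redis.exists(f"scan:stop:{session_id}") == 1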

    def __getstate__(self):
        """Prepare object for pickling by excluding unpicklable attributes."""
        state = self.__dict__.copy()

        unpicklable_attrs = [
            'stop_event',
            'scan_thread',
            'executor',
            'processing_lock',
            'task_queue',
            'rate_limiter',
            'logger',
            'status_logger_thread',
            'status_logger_stop_event'
        ]

        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]

        if 'providers' in state:
            for provider in state['providers']:
                if hasattr(provider, '_stop_event'):
                    provider._stop_event = None

        return state

    def __setstate__(self, state):
        """Restore object after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)

        self.stop_event = threading.Event()
        self.scan_thread = None
        self.executor = None
        self.processing_lock = threading.Lock()
        self.task_queue = PriorityQueue()
        self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
        self.logger = get_forensic_logger()
        self.status_logger_thread = None
        self.status_logger_stop_event = threading.Event()

        if not hasattr(self, 'providers') or not self.providers:
            self._initialize_providers()

        if not hasattr(self, 'currently_processing'):
            self.currently_processing = set()

        if not hasattr(self, 'currently_processing_display'):
            self.currently_processing_display = []

        if hasattr(self, 'providers'):
            for provider in self.providers:
                if hasattr(provider, 'set_stop_event'):
                    provider.set_stop_event(self.stop_event)
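
    # Minimal round-trip sketch for the pickling hooks above (illustrative;
    # assumes an already-constructed scanner). Threading primitives and queue
    # contents are intentionally not preserved, only the picklable scan state:
    #
    #   import pickle
    #   blob = pickle.dumps(scanner)        # __getstate__ drops locks/threads
    #   restored = pickle.loads(blob)       # __setstate__ rebuilds them
    #   assert restored.task_queue.empty()  # queue state starts fresh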

    def _initialize_providers(self) -> None:
        """Initialize all available providers based on session configuration."""
        self.providers = []
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')

        print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")

        correlation_provider_instance = None

        for filename in os.listdir(provider_dir):
            if filename.endswith('_provider.py') and not filename.startswith('base'):
                module_name = f"providers.{filename[:-3]}"
                try:
                    print(f"Loading provider module: {module_name}")
                    module = importlib.import_module(module_name)
                    for attribute_name in dir(module):
                        attribute = getattr(module, attribute_name)
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                            provider_class = attribute
                            provider = provider_class(name=attribute_name, session_config=self.config)
                            provider_name = provider.get_name()

                            print(f"  Provider: {provider_name}")
                            print(f"    Class: {provider_class.__name__}")
                            print(f"    Config enabled: {self.config.is_provider_enabled(provider_name)}")
                            print(f"    Requires API key: {provider.requires_api_key()}")

                            if provider.requires_api_key():
                                api_key = self.config.get_api_key(provider_name)
                                print(f"    API key present: {'Yes' if api_key else 'No'}")
                                if api_key:
                                    print(f"    API key preview: {api_key[:8]}...")

                            if self.config.is_provider_enabled(provider_name):
                                is_available = provider.is_available()
                                print(f"    Available: {is_available}")

                                if is_available:
                                    provider.set_stop_event(self.stop_event)

                                    # Special handling for correlation provider
                                    if isinstance(provider, CorrelationProvider):
                                        provider.set_graph_manager(self.graph)
                                        correlation_provider_instance = provider
                                        print("    ✓ Correlation provider configured with graph manager")

                                    self.providers.append(provider)
                                    print("    ✓ Added to scanner")
                                else:
                                    print("    ✗ Not available - skipped")
                            else:
                                print("    ✗ Disabled in config - skipped")

                except Exception as e:
                    print(f"  ERROR loading {module_name}: {e}")
                    traceback.print_exc()

        print("=== PROVIDER INITIALIZATION COMPLETE ===")
        print(f"Active providers: {[p.get_name() for p in self.providers]}")
        print(f"Provider count: {len(self.providers)}")

        # Verify correlation provider is properly configured
        if correlation_provider_instance:
            print(f"Correlation provider configured: {correlation_provider_instance.graph is not None}")

        print("=" * 50)

    def _status_logger_thread(self):
        """Periodically prints a clean, formatted scan status to the terminal."""
        HEADER = "\033[95m"
        CYAN = "\033[96m"
        GREEN = "\033[92m"
        YELLOW = "\033[93m"
        BLUE = "\033[94m"
        ENDC = "\033[0m"
        BOLD = "\033[1m"

        last_status_str = ""
        while not self.status_logger_stop_event.is_set():
            try:
                with self.processing_lock:
                    in_flight_tasks = list(self.currently_processing)
                    self.currently_processing_display = in_flight_tasks.copy()

                status_str = (
                    f"{BOLD}{HEADER}Scan Status: {self.status.upper()}{ENDC} | "
                    f"{CYAN}Queued: {self.task_queue.qsize()}{ENDC} | "
                    f"{YELLOW}In-Flight: {len(in_flight_tasks)}{ENDC} | "
                    f"{GREEN}Completed: {self.indicators_completed}{ENDC} | "
                    f"Skipped: {self.tasks_skipped} | "
                    f"Rescheduled: {self.tasks_re_enqueued}"
                )

                if status_str != last_status_str:
                    print(f"\n{'-'*80}")
                    print(status_str)
                    if self.last_task_from_queue:
                        # Unpack the time-based queue item: (run_at, priority, payload)
                        _, p, (pn, ti, d) = self.last_task_from_queue
                        print(f"{BLUE}Last task dequeued -> Prio:{p} | Provider:{pn} | Target:'{ti}' | Depth:{d}{ENDC}")
                    if in_flight_tasks:
                        print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}")
                        display_tasks = [f"  - {p}: {t}" for p, t in in_flight_tasks[:3]]
                        print("\n".join(display_tasks))
                        if len(in_flight_tasks) > 3:
                            print(f"  ... and {len(in_flight_tasks) - 3} more")
                    print(f"{'-'*80}")
                    last_status_str = status_str
            except Exception:
                pass

            time.sleep(2)
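
    # Example of the status line printed above, with ANSI colors omitted
    # (counter values are illustrative):
    #
    #   Scan Status: RUNNING | Queued: 42 | In-Flight: 5 | Completed: 17 | Skipped: 2 | Rescheduled: 1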

    def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
        """Start a new scan, stopping and cleaning up any scan already in progress."""
        if self.scan_thread and self.scan_thread.is_alive():
            self.logger.logger.info("Stopping existing scan before starting new one")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            # Clean up processing state
            with self.processing_lock:
                self.currently_processing.clear()
            self.currently_processing_display = []

            # Clear task queue
            while not self.task_queue.empty():
                try:
                    self.task_queue.get_nowait()
                except Empty:
                    break

            # Shutdown executor
            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass
                finally:
                    self.executor = None

            # Wait for scan thread to finish (with timeout)
            self.scan_thread.join(timeout=5.0)
            if self.scan_thread.is_alive():
                self.logger.logger.warning("Previous scan thread did not terminate cleanly")

        self.status = ScanStatus.IDLE
        self.stop_event.clear()

        if self.session_id:
            from core.session_manager import session_manager
            session_manager.clear_stop_signal(self.session_id)

        with self.processing_lock:
            self.currently_processing.clear()
        self.currently_processing_display = []

        self.task_queue = PriorityQueue()
        self.target_retries.clear()
        self.scan_failed_due_to_retries = False
        self.tasks_skipped = 0
        self.last_task_from_queue = None

        self._update_session_state()

        try:
            if not hasattr(self, 'providers') or not self.providers:
                self.logger.logger.error("No providers available for scanning")
                return False

            available_providers = [p for p in self.providers if p.is_available()]
            if not available_providers:
                self.logger.logger.error("No providers are currently available/configured")
                return False

            if clear_graph:
                self.graph.clear()
                self.initial_targets.clear()

            if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
                try:
                    node_data = self.graph.graph.nodes[force_rescan_target]
                    if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
                        node_data['metadata']['provider_states'] = {}
                        self.logger.logger.info(f"Cleared provider states for forced rescan of {force_rescan_target}")
                except Exception as e:
                    self.logger.logger.warning(f"Error clearing provider states for {force_rescan_target}: {e}")

            target = target.lower().strip()
            if not target:
                self.logger.logger.error("Empty target provided")
                return False

            from utils.helpers import is_valid_target
            if not is_valid_target(target):
                self.logger.logger.error(f"Invalid target format: {target}")
                return False

            self.current_target = target
            self.initial_targets.add(self.current_target)
            self.max_depth = max(1, min(5, max_depth))  # Clamp depth between 1-5
            self.current_depth = 0

            self.total_indicators_found = 0
            self.indicators_processed = 0
            self.indicators_completed = 0
            self.tasks_re_enqueued = 0
            self.total_tasks_ever_enqueued = 0
            self.current_indicator = self.current_target

            self._update_session_state()
            self.logger = new_session()

            try:
                self.scan_thread = threading.Thread(
                    target=self._execute_scan,
                    args=(self.current_target, self.max_depth),
                    daemon=True,
                    name=f"ScanThread-{self.session_id or 'default'}"
                )
                self.scan_thread.start()

                self.status_logger_stop_event.clear()
                self.status_logger_thread = threading.Thread(
                    target=self._status_logger_thread,
                    daemon=True,
                    name=f"StatusLogger-{self.session_id or 'default'}"
                )
                self.status_logger_thread.start()

                self.logger.logger.info(f"Scan started successfully for {target} with depth {self.max_depth}")
                return True

            except Exception as e:
                self.logger.logger.error(f"Error starting scan threads: {e}")
                self.status = ScanStatus.FAILED
                self._update_session_state()
                return False

        except Exception as e:
            self.logger.logger.error(f"Error in scan startup: {e}")
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self._update_session_state()
            return False
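
    # Minimal usage sketch (illustrative; assumes Redis is reachable and at
    # least one provider is enabled in the session config):
    #
    #   scanner = Scanner()
    #   scanner.start_scan("example.com", max_depth=2)
    #   while scanner.status in (ScanStatus.RUNNING, ScanStatus.FINALIZING):
    #       time.sleep(1)
    #   print(scanner.graph.get_node_count(), "nodes discovered")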

    def _get_priority(self, provider_name):
        """Map a provider to a queue priority (lower number = runs sooner)."""
        if provider_name == 'correlation':
            return 100  # Highest priority number = lowest priority (runs last)

        rate_limit = self.config.get_rate_limit(provider_name)

        # Handle edge cases
        if rate_limit <= 0:
            return 90  # Very low priority for invalid/disabled providers

        if provider_name == 'dns':
            return 1  # DNS is fastest, should run first
        elif provider_name == 'shodan':
            return 3  # Shodan is medium speed, good priority
        elif provider_name == 'crtsh':
            return 5  # crt.sh is slower, lower priority
        else:
            # For any other providers, use rate limit as a guide
            if rate_limit >= 100:
                return 2  # High rate limit = high priority
            elif rate_limit >= 50:
                return 4  # Medium-high rate limit = medium-high priority
            elif rate_limit >= 20:
                return 6  # Medium rate limit = medium priority
            elif rate_limit >= 5:
                return 8  # Low rate limit = low priority
            else:
                return 10  # Very low rate limit = very low priority
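
    # How the queue orders work: items are (run_at, priority, payload) tuples,
    # and Python compares tuples element by element, so the earliest run_at
    # wins and, for equal timestamps, the smaller priority number runs first.
    # Illustrative only:
    #
    #   q = PriorityQueue()
    #   q.put((1000.0, 5, ('crtsh', 'example.com', 0)))
    #   q.put((1000.0, 1, ('dns', 'example.com', 0)))
    #   q.get()  # -> (1000.0, 1, ('dns', 'example.com', 0))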

    def _execute_scan(self, target: str, max_depth: int) -> None:
        """Run the main scan loop: Phase 1 (providers), then Phase 2 (correlation)."""
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        processed_tasks = set()

        is_ip = _is_valid_ip(target)
        initial_providers = [p for p in self._get_eligible_providers(target, is_ip, False) if not isinstance(p, CorrelationProvider)]

        for provider in initial_providers:
            provider_name = provider.get_name()
            priority = self._get_priority(provider_name)
            self.task_queue.put((time.time(), priority, (provider_name, target, 0)))
            self.total_tasks_ever_enqueued += 1

        try:
            self.status = ScanStatus.RUNNING
            self._update_session_state()

            enabled_providers = [provider.get_name() for provider in self.providers]
            self.logger.log_scan_start(target, max_depth, enabled_providers)

            node_type = NodeType.IP if is_ip else NodeType.DOMAIN
            self.graph.add_node(target, node_type)
            self._initialize_provider_states(target)
            consecutive_empty_iterations = 0
            max_empty_iterations = 50

            print("\n=== PHASE 1: Running non-correlation providers ===")
            while not self._is_stop_requested():
                queue_empty = self.task_queue.empty()
                with self.processing_lock:
                    no_active_processing = len(self.currently_processing) == 0

                if queue_empty and no_active_processing:
                    consecutive_empty_iterations += 1
                    if consecutive_empty_iterations >= max_empty_iterations:
                        break
                    time.sleep(0.1)
                    continue
                else:
                    consecutive_empty_iterations = 0

                try:
                    run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
                    if provider_name == 'correlation':
                        continue  # Correlation tasks are deferred to Phase 2
                    current_time = time.time()
                    if run_at > current_time:
                        # Not due yet: re-enqueue and wait briefly
                        self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
                        time.sleep(min(0.5, run_at - current_time))
                        continue
                except Empty:
                    time.sleep(0.1)
                    continue

                self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
                task_tuple = (provider_name, target_item, depth)
                if task_tuple in processed_tasks or depth > max_depth:
                    self.tasks_skipped += 1
                    self.indicators_completed += 1
                    continue

                if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
                    defer_until = time.time() + 60
                    self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
                    self.tasks_re_enqueued += 1
                    continue

                with self.processing_lock:
                    if self._is_stop_requested():
                        break
                    processing_key = (provider_name, target_item)
                    if processing_key in self.currently_processing:
                        self.tasks_skipped += 1
                        self.indicators_completed += 1
                        continue
                    self.currently_processing.add(processing_key)

                try:
                    self.current_depth = depth
                    self.current_indicator = target_item
                    self._update_session_state()
                    if self._is_stop_requested():
                        break

                    provider = next((p for p in self.providers if p.get_name() == provider_name), None)
                    if provider and not isinstance(provider, CorrelationProvider):
                        new_targets, _, success = self._process_provider_task(provider, target_item, depth)
                        if self._is_stop_requested():
                            break

                        if not success:
                            retry_key = (provider_name, target_item, depth)
                            self.target_retries[retry_key] += 1
                            if self.target_retries[retry_key] <= self.config.max_retries_per_target:
                                retry_count = self.target_retries[retry_key]
                                # Exponential backoff, capped at 300s, with jitter
                                backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
                                self.task_queue.put((time.time() + backoff_delay, priority, (provider_name, target_item, depth)))
                                self.tasks_re_enqueued += 1
                            else:
                                self.scan_failed_due_to_retries = True
                                self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded")
                        else:
                            processed_tasks.add(task_tuple)
                            self.indicators_completed += 1

                        if not self._is_stop_requested():
                            for new_target in new_targets:
                                is_ip_new = _is_valid_ip(new_target)
                                eligible_providers_new = [p for p in self._get_eligible_providers(new_target, is_ip_new, False) if not isinstance(p, CorrelationProvider)]
                                for p_new in eligible_providers_new:
                                    p_name_new = p_new.get_name()
                                    new_depth = depth + 1
                                    if (p_name_new, new_target, new_depth) not in processed_tasks and new_depth <= max_depth:
                                        self.task_queue.put((time.time(), self._get_priority(p_name_new), (p_name_new, new_target, new_depth)))
                                        self.total_tasks_ever_enqueued += 1
                    else:
                        self.tasks_skipped += 1
                        self.indicators_completed += 1
                finally:
                    with self.processing_lock:
                        self.currently_processing.discard((provider_name, target_item))

            # This code runs after the main loop finishes or is stopped.
            self.status = ScanStatus.FINALIZING
            self._update_session_state()
            self.logger.logger.info("Scan stopped or completed. Entering finalization phase.")

            if self.status in [ScanStatus.FINALIZING, ScanStatus.COMPLETED, ScanStatus.STOPPED]:
                print("\n=== PHASE 2: Running correlation analysis ===")
                self._run_correlation_phase(max_depth, processed_tasks)
                self._update_session_state()

            # Determine the final status *after* finalization.
            if self._is_stop_requested():
                self.status = ScanStatus.STOPPED
            elif self.scan_failed_due_to_retries:
                self.status = ScanStatus.FAILED
            else:
                self.status = ScanStatus.COMPLETED

        except Exception as e:
            traceback.print_exc()
            self.status = ScanStatus.FAILED
            self.logger.logger.error(f"Scan failed: {e}")
        finally:
            # The 'finally' block is only for guaranteed cleanup.
            with self.processing_lock:
                self.currently_processing.clear()
            self.currently_processing_display = []

            while not self.task_queue.empty():
                try:
                    self.task_queue.get_nowait()
                except Empty:
                    break

            self.status_logger_stop_event.set()
            if self.status_logger_thread and self.status_logger_thread.is_alive():
                self.status_logger_thread.join(timeout=2.0)

            # The executor shutdown happens *after* the correlation phase has run.
            if self.executor:
                try:
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception as e:
                    self.logger.logger.warning(f"Error shutting down executor: {e}")
                finally:
                    self.executor = None

            self._update_session_state()
            self.logger.log_scan_complete()
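
    # Backoff math used for failed tasks above: exponential in the retry
    # count, capped at 300 seconds, plus up to 1 second of random jitter so
    # retries for many targets do not fire in lockstep. Illustrative values:
    #
    #   retry 1 -> ~2-3s, retry 2 -> ~4-5s, retry 5 -> ~32-33s, retry 9+ -> 300s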

    def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
        """
        PHASE 2: Run correlation analysis on all discovered nodes.
        Enhanced with better error handling and progress tracking.
        """
        correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
        if not correlation_provider:
            print("No correlation provider found - skipping correlation phase")
            return

        # Ensure correlation provider has access to current graph state
        correlation_provider.set_graph_manager(self.graph)
        print(f"Correlation provider configured with graph containing {self.graph.get_node_count()} nodes")

        # Get all nodes from the graph for correlation analysis
        all_nodes = list(self.graph.graph.nodes())
        correlation_tasks = []
        correlation_tasks_enqueued = 0

        print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")

        for node_id in all_nodes:
            # Determine appropriate depth for correlation (use 0 for simplicity)
            correlation_depth = 0
            task_tuple = ('correlation', node_id, correlation_depth)

            # Don't re-process already processed correlation tasks
            if task_tuple not in processed_tasks:
                priority = self._get_priority('correlation')
                self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
                correlation_tasks.append(task_tuple)
                correlation_tasks_enqueued += 1
                self.total_tasks_ever_enqueued += 1

        print(f"Enqueued {correlation_tasks_enqueued} new correlation tasks")

        # Force session state update to reflect new task count
        self._update_session_state()

        # Process correlation tasks with enhanced tracking
        consecutive_empty_iterations = 0
        max_empty_iterations = 20
        correlation_completed = 0
        correlation_errors = 0

        while correlation_tasks:
            # Check if we should continue processing
            queue_empty = self.task_queue.empty()
            with self.processing_lock:
                no_active_processing = len(self.currently_processing) == 0

            if queue_empty and no_active_processing:
                consecutive_empty_iterations += 1
                if consecutive_empty_iterations >= max_empty_iterations:
                    print(f"Correlation phase timeout - {len(correlation_tasks)} tasks remaining")
                    break
                time.sleep(0.1)
                continue
            else:
                consecutive_empty_iterations = 0

            try:
                run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)

                # Only process correlation tasks in this phase
                if provider_name != 'correlation':
                    continue

            except Empty:
                time.sleep(0.1)
                continue

            task_tuple = (provider_name, target_item, depth)

            # Skip if already processed
            if task_tuple in processed_tasks:
                self.tasks_skipped += 1
                self.indicators_completed += 1
                if task_tuple in correlation_tasks:
                    correlation_tasks.remove(task_tuple)
                continue

            with self.processing_lock:
                processing_key = (provider_name, target_item)
                if processing_key in self.currently_processing:
                    self.tasks_skipped += 1
                    self.indicators_completed += 1
                    continue
                self.currently_processing.add(processing_key)

            try:
                self.current_indicator = target_item
                self._update_session_state()

                # Process correlation task with enhanced error handling
                try:
                    new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)

                    if success:
                        processed_tasks.add(task_tuple)
                        correlation_completed += 1
                        self.indicators_completed += 1
                        if task_tuple in correlation_tasks:
                            correlation_tasks.remove(task_tuple)
                    else:
                        # For correlations, don't retry - just mark as completed
                        correlation_errors += 1
                        self.indicators_completed += 1
                        if task_tuple in correlation_tasks:
                            correlation_tasks.remove(task_tuple)

                except Exception as e:
                    correlation_errors += 1
                    self.indicators_completed += 1
                    if task_tuple in correlation_tasks:
                        correlation_tasks.remove(task_tuple)
                    self.logger.logger.warning(f"Correlation task failed for {target_item}: {e}")

            finally:
                with self.processing_lock:
                    processing_key = (provider_name, target_item)
                    self.currently_processing.discard(processing_key)

            # Periodic progress update during correlation phase
            if correlation_completed % 10 == 0 and correlation_completed > 0:
                remaining = len(correlation_tasks)
                print(f"Correlation progress: {correlation_completed} completed, {remaining} remaining")

        print("Correlation phase complete:")
        print(f"  - Successfully processed: {correlation_completed}")
        print(f"  - Errors encountered: {correlation_errors}")
        print(f"  - Tasks remaining: {len(correlation_tasks)}")

    def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
        """
        Manages the entire process for a given target and provider.
        This version is generalized to handle all relationships dynamically.
        """
        if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
            return set(), set(), False

        is_ip = _is_valid_ip(target)
        target_type = NodeType.IP if is_ip else NodeType.DOMAIN

        self.graph.add_node(target, target_type)
        self._initialize_provider_states(target)

        new_targets = set()
        provider_successful = True

        try:
            provider_result = self._execute_provider_query(provider, target, is_ip)

            if provider_result is None:
                provider_successful = False
            # Allow correlation provider to process results even if scan is stopped
            elif not self._is_stop_requested() or isinstance(provider, CorrelationProvider):
                # Pass all relationships to be processed
                discovered, is_large_entity = self._process_provider_result_unified(
                    target, provider, provider_result, depth
                )
                new_targets.update(discovered)

        except Exception as e:
            provider_successful = False
            self._log_provider_error(target, provider.get_name(), str(e))

        return new_targets, set(), provider_successful

    def _execute_provider_query(self, provider: BaseProvider, target: str, is_ip: bool) -> Optional[ProviderResult]:
        """
        The "worker" function that directly communicates with the provider to fetch data.
        """
        provider_name = provider.get_name()
        start_time = datetime.now(timezone.utc)

        if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
            return None

        try:
            if is_ip:
                result = provider.query_ip(target)
            else:
                result = provider.query_domain(target)

            if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
                return None

            relationship_count = result.get_relationship_count() if result else 0
            self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time)

            return result

        except Exception as e:
            self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time)
            return None

    def _create_large_entity_from_result(self, source_node: str, provider_name: str,
                                        provider_result: ProviderResult, depth: int) -> Tuple[str, Set[str]]:
        """
        Creates a large entity node, tags all member nodes, and stores original relationships.
        FIXED: Now stores original relationships for later restoration during extraction.
        """
        members = {rel.target_node for rel in provider_result.relationships
                if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)}

        if not members:
            return "", set()

        large_entity_id = f"le_{provider_name}_{source_node}"

        # FIXED: Store original relationships for each member
        member_relationships = {}
        for rel in provider_result.relationships:
            if rel.target_node in members:
                if rel.target_node not in member_relationships:
                    member_relationships[rel.target_node] = []
                member_relationships[rel.target_node].append({
                    'source_node': rel.source_node,
                    'target_node': rel.target_node,
                    'relationship_type': rel.relationship_type,
                    'provider': rel.provider,
                    'raw_data': rel.raw_data
                })

        self.graph.add_node(
            node_id=large_entity_id,
            node_type=NodeType.LARGE_ENTITY,
            attributes=[
                {"name": "count", "value": len(members), "type": "statistic"},
                {"name": "source_provider", "value": provider_name, "type": "metadata"},
                {"name": "discovery_depth", "value": depth, "type": "metadata"},
                {"name": "nodes", "value": list(members), "type": "metadata"},
                {"name": "original_relationships", "value": member_relationships, "type": "metadata"}  # FIXED: Store original relationships
            ],
            description=f"A collection of {len(members)} nodes discovered from {source_node} via {provider_name}."
        )

        for member_id in members:
            node_type = NodeType.IP if _is_valid_ip(member_id) else NodeType.DOMAIN
            self.graph.add_node(
                node_id=member_id,
                node_type=node_type,
                metadata={'large_entity_id': large_entity_id}
            )

        return large_entity_id, members
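
    # Shape of the 'original_relationships' attribute stored above, keyed by
    # member node (illustrative; the relationship_type shown is an assumed
    # example value):
    #
    #   {
    #       "sub.example.com": [
    #           {"source_node": "example.com",
    #            "target_node": "sub.example.com",
    #            "relationship_type": "crtsh_san",
    #            "provider": "crtsh",
    #            "raw_data": {...}},
    #       ],
    #   }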

    def extract_node_from_large_entity(self, large_entity_id: str, node_id: str) -> bool:
        """
        Removes a node from a large entity and restores its original relationships.
        FIXED: Now restores original relationships to make the node reachable.
        """
        if not self.graph.graph.has_node(node_id):
            return False

        node_data = self.graph.graph.nodes[node_id]
        metadata = node_data.get('metadata', {})

        if metadata.get('large_entity_id') != large_entity_id:
            return False

        # Remove the large entity tag
        del metadata['large_entity_id']
        self.graph.add_node(node_id, NodeType(node_data['type']), metadata=metadata)

        # FIXED: Restore original relationships if they exist
        if self.graph.graph.has_node(large_entity_id):
            le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
            original_relationships_attr = next((a for a in le_attrs if a['name'] == 'original_relationships'), None)

            if original_relationships_attr and node_id in original_relationships_attr['value']:
                # Restore all original relationships for this node
                for rel_data in original_relationships_attr['value'][node_id]:
                    self.graph.add_edge(
                        source_id=rel_data['source_node'],
                        target_id=rel_data['target_node'],
                        relationship_type=rel_data['relationship_type'],
                        source_provider=rel_data['provider'],
                        raw_data=rel_data['raw_data']
                    )

                    # Ensure both nodes exist in the graph
                    source_type = NodeType.IP if _is_valid_ip(rel_data['source_node']) else NodeType.DOMAIN
                    target_type = NodeType.IP if _is_valid_ip(rel_data['target_node']) else NodeType.DOMAIN
                    self.graph.add_node(rel_data['source_node'], source_type)
                    self.graph.add_node(rel_data['target_node'], target_type)

                # Update the large entity to remove this node from its list
                nodes_attr = next((a for a in le_attrs if a['name'] == 'nodes'), None)
                if nodes_attr and node_id in nodes_attr['value']:
                    nodes_attr['value'].remove(node_id)

                count_attr = next((a for a in le_attrs if a['name'] == 'count'), None)
                if count_attr:
                    count_attr['value'] = max(0, count_attr['value'] - 1)

                # Remove from original relationships tracking
                if node_id in original_relationships_attr['value']:
                    del original_relationships_attr['value'][node_id]

        # Re-enqueue the node for full processing
        is_ip = _is_valid_ip(node_id)
        eligible_providers = self._get_eligible_providers(node_id, is_ip, False, is_extracted=True)
        for provider in eligible_providers:
            provider_name = provider.get_name()
            priority = self._get_priority(provider_name)
            # Use current depth of the large entity if available, else 0
            depth = 0
            if self.graph.graph.has_node(large_entity_id):
                le_attrs = self.graph.graph.nodes[large_entity_id].get('attributes', [])
                depth_attr = next((a for a in le_attrs if a['name'] == 'discovery_depth'), None)
                if depth_attr:
                    depth = depth_attr['value']

            self.task_queue.put((time.time(), priority, (provider_name, node_id, depth)))
            self.total_tasks_ever_enqueued += 1

        return True
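
    # Illustrative usage (entity and node IDs assumed): after crt.sh returns
    # more results than large_entity_threshold for example.com, they are
    # grouped as "le_crtsh_example.com"; an analyst can pull one host back
    # out for full processing:
    #
    #   scanner.extract_node_from_large_entity("le_crtsh_example.com",
    #                                          "interesting.example.com")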
    def _process_provider_result_unified(self, target: str, provider: BaseProvider,
                                        provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
        """
        Process a unified ProviderResult object and update the graph.
        When a provider returns more eligible relationships than the configured
        threshold, the members are collapsed into a large entity container and
        their edges are dynamically re-routed to that container node.
        """
        provider_name = provider.get_name()
        discovered_targets = set()
        large_entity_id = ""
        large_entity_members = set()

        # Stop processing for non-correlation providers if termination was requested
        if self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
            return discovered_targets, False

        eligible_rel_count = sum(
            1 for rel in provider_result.relationships if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)
        )
        is_large_entity = eligible_rel_count > self.config.large_entity_threshold

        if is_large_entity:
            large_entity_id, large_entity_members = self._create_large_entity_from_result(
                target, provider_name, provider_result, current_depth
            )

        for i, relationship in enumerate(provider_result.relationships):
            # Re-check the stop signal periodically (every 5 relationships)
            if i % 5 == 0 and self._is_stop_requested() and not isinstance(provider, CorrelationProvider):
                break

            source_node_id = relationship.source_node
            target_node_id = relationship.target_node

            # Determine visual source and target, substituting the large entity ID where necessary
            visual_source = large_entity_id if source_node_id in large_entity_members else source_node_id
            visual_target = large_entity_id if target_node_id in large_entity_members else target_node_id

            # Prevent self-loops on the large entity node
            if visual_source == visual_target:
                continue

            # Determine node types for the actual (non-container) nodes
            source_type = NodeType.IP if _is_valid_ip(source_node_id) else NodeType.DOMAIN
            if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
                target_type = NodeType.ISP
            elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
                target_type = NodeType.CA
            elif provider_name == 'correlation':
                target_type = NodeType.CORRELATION_OBJECT
            elif _is_valid_ip(target_node_id):
                target_type = NodeType.IP
            else:
                target_type = NodeType.DOMAIN

            max_depth_reached = current_depth >= self.max_depth

            # Add the actual nodes to the graph (the UI may hide large entity members)
            self.graph.add_node(source_node_id, source_type)
            self.graph.add_node(target_node_id, target_type, metadata={'max_depth_reached': max_depth_reached})

            # Add the visual edge to the graph
            self.graph.add_edge(
                visual_source, visual_target,
                relationship.relationship_type,
                provider_name,
                relationship.raw_data
            )

            # Queue valid, in-depth targets for further processing, skipping large entity members
            if (_is_valid_domain(target_node_id) or _is_valid_ip(target_node_id)) and not max_depth_reached:
                if target_node_id not in large_entity_members:
                    discovered_targets.add(target_node_id)

        if large_entity_members:
            self.logger.logger.info(f"Enqueuing DNS and Correlation for {len(large_entity_members)} members of {large_entity_id}")
            for member in large_entity_members:
                for provider_name_to_run in ['dns', 'correlation']:
                    p_instance = next((p for p in self.providers if p.get_name() == provider_name_to_run), None)
                    if p_instance and p_instance.get_eligibility().get('domains' if _is_valid_domain(member) else 'ips'):
                        priority = self._get_priority(provider_name_to_run)
                        self.task_queue.put((time.time(), priority, (provider_name_to_run, member, current_depth)))
                        self.total_tasks_ever_enqueued += 1

        # Group attributes by node and merge them into the graph
        attributes_by_node = defaultdict(list)
        for attribute in provider_result.attributes:
            attr_dict = {
                "name": attribute.name, "value": attribute.value, "type": attribute.type,
                "provider": attribute.provider, "metadata": attribute.metadata
            }
            attributes_by_node[attribute.target_node].append(attr_dict)

        for node_id, node_attributes_list in attributes_by_node.items():
            if not self.graph.graph.has_node(node_id):
                node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
                self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
            else:
                existing_attrs = self.graph.graph.nodes[node_id].get('attributes', [])
                self.graph.graph.nodes[node_id]['attributes'] = existing_attrs + node_attributes_list

        return discovered_targets, is_large_entity

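    # Illustration of the re-routing above, assuming a large_entity_threshold of 3:
    # if a provider returns sub1..sub4.example.com for example.com, all four member
    # nodes are still added to the graph, but their visual edges are re-pointed at
    # the single container node returned by _create_large_entity_from_result(), so
    # the UI shows one edge (example.com -> <large entity>) instead of four. The
    # members are then re-enqueued only for the 'dns' and 'correlation' providers,
    # keeping the display compact without losing per-member data.
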
    def stop_scan(self) -> bool:
        """Request immediate scan termination with proper cleanup."""
        try:
            self.logger.logger.info("Scan termination requested by user")
            self._set_stop_signal()
            self.status = ScanStatus.STOPPED

            with self.processing_lock:
                self.currently_processing.clear()
            self.currently_processing_display = []

            # Drop all pending tasks by swapping in a fresh queue
            self.task_queue = PriorityQueue()

            if self.executor:
                try:
                    # cancel_futures requires Python 3.9+
                    self.executor.shutdown(wait=False, cancel_futures=True)
                except Exception:
                    pass

            self._update_session_state()
            return True

        except Exception as e:
            self.logger.logger.error(f"Error during scan termination: {e}")
            traceback.print_exc()
            return False

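    # Minimal usage sketch (hypothetical caller, e.g. a web route handler;
    # get_session_scanner is an assumed session_manager helper):
    #
    #     scanner = session_manager.get_session_scanner(session_id)
    #     if scanner and scanner.status == ScanStatus.RUNNING:
    #         scanner.stop_scan()  # safe to call repeatedly
    #
    # stop_scan() returns False only if cleanup itself raised an exception.
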
    def _update_session_state(self) -> None:
        """
        Push the current scanner state to Redis so the GUI can refresh.
        """
        if self.session_id:
            try:
                from core.session_manager import session_manager
                session_manager.update_session_scanner(self.session_id, self)
            except Exception:
                pass

    def get_scan_status(self) -> Dict[str, Any]:
        """Get current scan status with comprehensive processing information."""
        try:
            with self.processing_lock:
                currently_processing_count = len(self.currently_processing)
                currently_processing_list = list(self.currently_processing)

            return {
                'status': self.status,
                'target_domain': self.current_target,
                'current_depth': self.current_depth,
                'max_depth': self.max_depth,
                'current_indicator': self.current_indicator,
                'indicators_processed': self.indicators_processed,
                'indicators_completed': self.indicators_completed,
                'tasks_re_enqueued': self.tasks_re_enqueued,
                'progress_percentage': self._calculate_progress(),
                'total_tasks_ever_enqueued': self.total_tasks_ever_enqueued,
                'enabled_providers': [provider.get_name() for provider in self.providers],
                'graph_statistics': self.graph.get_statistics(),
                'task_queue_size': self.task_queue.qsize(),
                'currently_processing_count': currently_processing_count,
                'currently_processing': currently_processing_list[:5],  # Cap the list for display
                # The following keys duplicate values above under alternate names:
                'tasks_in_queue': self.task_queue.qsize(),
                'tasks_completed': self.indicators_completed,
                'tasks_skipped': self.tasks_skipped,
                'tasks_rescheduled': self.tasks_re_enqueued,
            }
        except Exception:
            traceback.print_exc()
            return {'status': 'error', 'message': 'Failed to get status'}

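    # Polling sketch (hypothetical GUI loop; field names match the dict above):
    #
    #     status = scanner.get_scan_status()
    #     print(f"{status['status']}: {status['progress_percentage']:.1f}% "
    #           f"({status['tasks_completed']}/{status['total_tasks_ever_enqueued']}), "
    #           f"{status['tasks_in_queue']} queued")
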
    def _initialize_provider_states(self, target: str) -> None:
        """
        Safely initialize the provider state tracking metadata for a node.
        """
        try:
            if not self.graph.graph.has_node(target):
                return

            node_data = self.graph.graph.nodes[target]
            if 'metadata' not in node_data:
                node_data['metadata'] = {}
            if 'provider_states' not in node_data['metadata']:
                node_data['metadata']['provider_states'] = {}
        except Exception as e:
            self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")

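    # After initialization the node's metadata contains at least:
    #
    #     self.graph.graph.nodes[target]['metadata']['provider_states'] == {}
    #
    # (other metadata keys, e.g. 'max_depth_reached', may coexist alongside it).
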
    def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool, is_extracted: bool = False) -> List:
        """
        Filter the configured providers down to those eligible for this target.
        """
        if dns_only:
            return [p for p in self.providers if p.get_name() == 'dns']

        eligible = []
        target_key = 'ips' if is_ip else 'domains'

        # Check whether the target is part of a large entity
        is_in_large_entity = False
        if self.graph.graph.has_node(target) and not is_extracted:
            metadata = self.graph.graph.nodes[target].get('metadata', {})
            if 'large_entity_id' in metadata:
                is_in_large_entity = True

        for provider in self.providers:
            try:
                # Members of a large entity only get the dns and correlation providers
                if is_in_large_entity and provider.get_name() not in ['dns', 'correlation']:
                    continue

                # Check if the provider supports this target type
                if not provider.get_eligibility().get(target_key, False):
                    continue

                # Check if the provider is available/configured
                if not provider.is_available():
                    continue

                # Skip providers that have already been queried successfully
                if not self._already_queried_provider(target, provider.get_name()):
                    eligible.append(provider)

            except Exception as e:
                self.logger.logger.warning(f"Error checking provider eligibility {provider.get_name()}: {e}")
                continue

        return eligible

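    # Based on its usage here and in _process_provider_result_unified, each
    # provider's get_eligibility() is assumed to return a mapping like:
    #
    #     {'domains': True, 'ips': False}   # e.g. a certificate-transparency source
    #
    # so target_key ('domains' or 'ips') selects the relevant flag for the target.
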
    def _already_queried_provider(self, target: str, provider_name: str) -> bool:
        """
        Check whether a provider has already been queried successfully for a target.
        """
        try:
            if not self.graph.graph.has_node(target):
                return False

            node_data = self.graph.graph.nodes[target]
            provider_states = node_data.get('metadata', {}).get('provider_states', {})
            provider_state = provider_states.get(provider_name)

            # Only count it as already queried if the run succeeded AND returned
            # results; successful-but-empty runs remain eligible for a retry.
            return (provider_state is not None and
                    provider_state.get('status') == 'success' and
                    provider_state.get('results_count', 0) > 0)
        except Exception as e:
            self.logger.logger.warning(f"Error checking provider state for {target}:{provider_name}: {e}")
            return False

    def _update_provider_state(self, target: str, provider_name: str, status: str,
                            results_count: int, error: Optional[str], start_time: datetime) -> None:
        """
        Record the outcome of a provider query in the target node's metadata.
        """
        try:
            if not self.graph.graph.has_node(target):
                self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
                return

            node_data = self.graph.graph.nodes[target]
            if 'metadata' not in node_data:
                node_data['metadata'] = {}
            if 'provider_states' not in node_data['metadata']:
                node_data['metadata']['provider_states'] = {}

            # Calculate the query duration safely
            try:
                duration_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
            except Exception:
                duration_ms = 0

            node_data['metadata']['provider_states'][provider_name] = {
                'status': status,
                'timestamp': start_time.isoformat(),
                'results_count': max(0, results_count),  # Ensure non-negative
                'error': str(error) if error else None,
                'duration_ms': duration_ms
            }

            # Update the last-modified time for forensic integrity
            self.last_modified = datetime.now(timezone.utc).isoformat()

        except Exception as e:
            self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")

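    # A recorded provider state ends up looking like this (illustrative values):
    #
    #     node['metadata']['provider_states']['dns'] = {
    #         'status': 'success',
    #         'timestamp': '2024-01-01T12:00:00+00:00',
    #         'results_count': 7,
    #         'error': None,
    #         'duration_ms': 142.3,
    #     }
    #
    # which is exactly the shape _already_queried_provider() inspects above.
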
    def _log_target_processing_error(self, target: str, error: str) -> None:
        self.logger.logger.error(f"Target processing failed for {target}: {error}")

    def _log_provider_error(self, target: str, provider_name: str, error: str) -> None:
        self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")

    def _calculate_progress(self) -> float:
        """
        Progress calculation that also accounts for correlation tasks
        added during the correlation phase.
        """
        try:
            if self.total_tasks_ever_enqueued == 0:
                return 0.0

            # Account for tasks still queued or in flight so we don't report 100% too early
            queue_size = max(0, self.task_queue.qsize())
            with self.processing_lock:
                active_tasks = len(self.currently_processing)

            # During the correlation (finalizing) phase, report progress conservatively
            if self.status == ScanStatus.FINALIZING:
                base_progress = (self.indicators_completed / max(self.total_tasks_ever_enqueued, 1)) * 100

                # While correlation tasks remain, cap progress at 95%
                if queue_size > 0 or active_tasks > 0:
                    return min(95.0, base_progress)
                else:
                    return min(100.0, base_progress)

            # Normal phase progress calculation
            adjusted_total = max(self.total_tasks_ever_enqueued,
                            self.indicators_completed + queue_size + active_tasks)

            if adjusted_total == 0:
                return 100.0

            progress = (self.indicators_completed / adjusted_total) * 100
            return max(0.0, min(100.0, progress))  # Clamp between 0 and 100

        except Exception as e:
            self.logger.logger.warning(f"Error calculating progress: {e}")
            return 0.0

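    # Worked example for the normal phase (illustrative numbers): with
    # total_tasks_ever_enqueued=40, indicators_completed=25, queue_size=10 and
    # active_tasks=2, adjusted_total = max(40, 25 + 10 + 2) = 40, so progress is
    # 25 / 40 * 100 = 62.5%. If late re-enqueues pushed completed + pending past
    # the historical total (say 25 + 20 + 2 = 47 > 40), adjusted_total grows to
    # 47, keeping the reported percentage from overstating completion.
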
    def get_graph_data(self) -> Dict[str, Any]:
        graph_data = self.graph.get_graph_data()
        graph_data['initial_targets'] = list(self.initial_targets)
        return graph_data

    def get_provider_info(self) -> Dict[str, Dict[str, Any]]:
        """Discover provider modules on disk and report their configuration and statistics."""
        info = {}
        provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
        for filename in os.listdir(provider_dir):
            if filename.endswith('_provider.py') and not filename.startswith('base'):
                module_name = f"providers.{filename[:-3]}"
                try:
                    module = importlib.import_module(module_name)
                    for attribute_name in dir(module):
                        attribute = getattr(module, attribute_name)
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                            provider_class = attribute
                            temp_provider = provider_class(name=attribute_name, session_config=self.config)
                            provider_name = temp_provider.get_name()
                            # Prefer live statistics from an active provider instance if one exists
                            live_provider = next((p for p in self.providers if p.get_name() == provider_name), None)
                            info[provider_name] = {
                                'display_name': temp_provider.get_display_name(),
                                'requires_api_key': temp_provider.requires_api_key(),
                                'statistics': live_provider.get_statistics() if live_provider else temp_provider.get_statistics(),
                                'enabled': self.config.is_provider_enabled(provider_name),
                                'rate_limit': self.config.get_rate_limit(provider_name),
                            }
                except Exception:
                    traceback.print_exc()
        return info
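
    # Usage sketch (hypothetical REPL session; keys match the dict built above):
    #
    #     for name, meta in scanner.get_provider_info().items():
    #         flag = 'on' if meta['enabled'] else 'off'
    #         print(f"{name:<12} [{flag}] key_required={meta['requires_api_key']} "
    #               f"rate_limit={meta['rate_limit']}")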