diff --git a/config.py b/config.py
index ef62486..3f4eb3f 100644
--- a/config.py
+++ b/config.py
@@ -34,7 +34,7 @@ class Config:
             'crtsh': 5,
             'shodan': 60,
             'dns': 100,
-            'correlation': 1000  # Set a high limit as it's a local operation
+            'correlation': 0  # Set to 0 to make sure correlations run last
         }

         # --- Provider Settings ---
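Context for the config change: the scanner orders work as `(run_at, priority, task)` tuples in a `PriorityQueue`, so among tasks that are ready, a smaller priority number runs sooner. A minimal stdlib-only sketch (the task strings are illustrative, not the project's API) shows why the explicit `correlation` priority of 100 introduced in `_get_priority` below pushes correlation tasks behind everything else:

```python
# Minimal sketch of the (run_at, priority, task) ordering used by the scheduler.
import time
from queue import PriorityQueue

q = PriorityQueue()
now = time.time()
q.put((now, 100, "correlation:example.com"))  # large number -> runs last
q.put((now, 1, "dns:example.com"))            # small number -> runs first
q.put((now, 3, "shodan:example.com"))

while not q.empty():
    run_at, priority, task = q.get()
    print(priority, task)
# 1 dns:example.com
# 3 shodan:example.com
# 100 correlation:example.com
```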
diff --git a/core/scanner.py b/core/scanner.py
index 30819d5..0c8f474 100644
--- a/core/scanner.py
+++ b/core/scanner.py
@@ -258,20 +258,36 @@ class Scanner:
             time.sleep(2)

     def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
-        """
-        Starts a new reconnaissance scan.
-        """
         if self.scan_thread and self.scan_thread.is_alive():
+            self.logger.logger.info("Stopping existing scan before starting new one")
             self._set_stop_signal()
             self.status = ScanStatus.STOPPED
+
+            # Clean up processing state
             with self.processing_lock:
                 self.currently_processing.clear()
                 self.currently_processing_display = []
-            self.task_queue = PriorityQueue()
+
+            # Clear task queue
+            while not self.task_queue.empty():
+                try:
+                    self.task_queue.get_nowait()
+                except Exception:
+                    break
+
+            # Shutdown executor
             if self.executor:
-                self.executor.shutdown(wait=False, cancel_futures=True)
-                self.executor = None
-            self.scan_thread.join(5.0)
+                try:
+                    self.executor.shutdown(wait=False, cancel_futures=True)
+                except Exception:
+                    pass
+                finally:
+                    self.executor = None
+
+            # Wait for scan thread to finish (with timeout)
+            self.scan_thread.join(timeout=5.0)
+            if self.scan_thread.is_alive():
+                self.logger.logger.warning("Previous scan thread did not terminate cleanly")

             self.status = ScanStatus.IDLE
             self.stop_event.clear()
@@ -294,20 +310,40 @@ class Scanner:

         try:
             if not hasattr(self, 'providers') or not self.providers:
+                self.logger.logger.error("No providers available for scanning")
                 return False

+            available_providers = [p for p in self.providers if p.is_available()]
+            if not available_providers:
+                self.logger.logger.error("No providers are currently available/configured")
+                return False
+
             if clear_graph:
                 self.graph.clear()
                 self.initial_targets.clear()

             if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
-                node_data = self.graph.graph.nodes[force_rescan_target]
-                if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
-                    node_data['metadata']['provider_states'] = {}
+                try:
+                    node_data = self.graph.graph.nodes[force_rescan_target]
+                    if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
+                        node_data['metadata']['provider_states'] = {}
+                        self.logger.logger.info(f"Cleared provider states for forced rescan of {force_rescan_target}")
+                except Exception as e:
+                    self.logger.logger.warning(f"Error clearing provider states for {force_rescan_target}: {e}")

-            self.current_target = target.lower().strip()
+            target = target.lower().strip()
+            if not target:
+                self.logger.logger.error("Empty target provided")
+                return False
+
+            from utils.helpers import is_valid_target
+            if not is_valid_target(target):
+                self.logger.logger.error(f"Invalid target format: {target}")
+                return False
+
+            self.current_target = target
             self.initial_targets.add(self.current_target)

-            self.max_depth = max_depth
+            self.max_depth = max(1, min(5, max_depth))  # Clamp depth between 1-5
             self.current_depth = 0

             self.total_indicators_found = 0
@@ -320,56 +356,77 @@ class Scanner:
             self._update_session_state()
             self.logger = new_session()

-            self.scan_thread = threading.Thread(
-                target=self._execute_scan,
-                args=(self.current_target, max_depth),
-                daemon=True
-            )
-            self.scan_thread.start()
+            try:
+                self.scan_thread = threading.Thread(
+                    target=self._execute_scan,
+                    args=(self.current_target, self.max_depth),
+                    daemon=True,
+                    name=f"ScanThread-{self.session_id or 'default'}"
+                )
+                self.scan_thread.start()

-            self.status_logger_stop_event.clear()
-            self.status_logger_thread = threading.Thread(target=self._status_logger_thread, daemon=True)
-            self.status_logger_thread.start()
+                self.status_logger_stop_event.clear()
+                self.status_logger_thread = threading.Thread(
+                    target=self._status_logger_thread,
+                    daemon=True,
+                    name=f"StatusLogger-{self.session_id or 'default'}"
+                )
+                self.status_logger_thread.start()

-            return True
+                self.logger.logger.info(f"Scan started successfully for {target} with depth {self.max_depth}")
+                return True
+
+            except Exception as e:
+                self.logger.logger.error(f"Error starting scan threads: {e}")
+                self.status = ScanStatus.FAILED
+                self._update_session_state()
+                return False

         except Exception as e:
+            self.logger.logger.error(f"Error in scan startup: {e}")
             traceback.print_exc()
             self.status = ScanStatus.FAILED
             self._update_session_state()
             return False

     def _get_priority(self, provider_name):
+        if provider_name == 'correlation':
+            return 100  # Highest priority number = lowest priority (runs last)
+
         rate_limit = self.config.get_rate_limit(provider_name)

-        # Define the logarithmic scale
-        if rate_limit < 10:
-            return 10  # Highest priority number (lowest priority) for very low rate limits
-
-        # Calculate logarithmic value and map to priority levels
-        # Lower rate limits get higher priority numbers (lower priority)
-        log_value = math.log10(rate_limit)
-        priority = 10 - int(log_value * 2)  # Scale factor to get more granular levels
-
-        # Ensure priority is within a reasonable range (1-10)
-        priority = max(1, min(10, priority))
-
-        return priority
+        # Handle edge cases
+        if rate_limit <= 0:
+            return 90  # Very low priority for invalid/disabled providers
+
+        if provider_name == 'dns':
+            return 1  # DNS is fastest, should run first
+        elif provider_name == 'shodan':
+            return 3  # Shodan is medium speed, good priority
+        elif provider_name == 'crtsh':
+            return 5  # crt.sh is slower, lower priority
+        else:
+            # For any other providers, use rate limit as a guide
+            if rate_limit >= 100:
+                return 2  # High rate limit = high priority
+            elif rate_limit >= 50:
+                return 4  # Medium-high rate limit = medium-high priority
+            elif rate_limit >= 20:
+                return 6  # Medium rate limit = medium priority
+            elif rate_limit >= 5:
+                return 8  # Low rate limit = low priority
+            else:
+                return 10  # Very low rate limit = very low priority

     def _execute_scan(self, target: str, max_depth: int) -> None:
-        """
-        Execute the reconnaissance scan with a time-based, robust scheduler.
-        Handles rate-limiting via deferral and failures via exponential backoff.
-        """
         self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
-        processed_tasks = set()
+        processed_tasks = set()  # FIXED: Now includes depth to avoid incorrect skipping

         is_ip = _is_valid_ip(target)
         initial_providers = self._get_eligible_providers(target, is_ip, False)
         for provider in initial_providers:
             provider_name = provider.get_name()
             priority = self._get_priority(provider_name)
-            # OVERHAUL: Enqueue with current timestamp to run immediately
             self.task_queue.put((time.time(), priority, (provider_name, target, 0)))
             self.total_tasks_ever_enqueued += 1
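The rewritten `_get_priority` trades the logarithmic mapping for explicit tiers. A standalone sketch of the same tier logic, written as a hypothetical free function so the mapping can be exercised in isolation (rate limits are requests per minute):

```python
# Standalone sketch of the tiered priority mapping from _get_priority above.
def get_priority(provider_name: str, rate_limit: int) -> int:
    if provider_name == 'correlation':
        return 100          # always scheduled last
    if rate_limit <= 0:
        return 90           # invalid/disabled providers
    fixed = {'dns': 1, 'shodan': 3, 'crtsh': 5}
    if provider_name in fixed:
        return fixed[provider_name]
    # Unknown providers: the coarser the rate limit, the lower the priority
    for threshold, priority in ((100, 2), (50, 4), (20, 6), (5, 8)):
        if rate_limit >= threshold:
            return priority
    return 10

assert get_priority('dns', 100) == 1
assert get_priority('correlation', 0) == 100
assert get_priority('newprovider', 30) == 6
```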
- """ self.executor = ThreadPoolExecutor(max_workers=self.max_workers) - processed_tasks = set() + processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping is_ip = _is_valid_ip(target) initial_providers = self._get_eligible_providers(target, is_ip, False) for provider in initial_providers: provider_name = provider.get_name() priority = self._get_priority(provider_name) - # OVERHAUL: Enqueue with current timestamp to run immediately self.task_queue.put((time.time(), priority, (provider_name, target, 0))) self.total_tasks_ever_enqueued += 1 @@ -383,101 +440,156 @@ class Scanner: node_type = NodeType.IP if is_ip else NodeType.DOMAIN self.graph.add_node(target, node_type) self._initialize_provider_states(target) + consecutive_empty_iterations = 0 + max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion while not self._is_stop_requested(): - if self.task_queue.empty() and not self.currently_processing: - break # Scan is complete + queue_empty = self.task_queue.empty() + with self.processing_lock: + no_active_processing = len(self.currently_processing) == 0 + + if queue_empty and no_active_processing: + consecutive_empty_iterations += 1 + if consecutive_empty_iterations >= max_empty_iterations: + break # Scan is complete + time.sleep(0.1) + continue + else: + consecutive_empty_iterations = 0 + # FIXED: Safe task retrieval without race conditions try: - # OVERHAUL: Peek at the next task to see if it's ready to run - next_run_at, _, _ = self.task_queue.queue[0] - if next_run_at > time.time(): - time.sleep(0.1) # Sleep to prevent busy-waiting for future tasks - continue + # Use timeout to avoid blocking indefinitely + run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1) - # Task is ready, so get it from the queue - run_at, priority, (provider_name, target_item, depth) = self.task_queue.get() - self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth)) - - except IndexError: - time.sleep(0.1) # Queue is empty, but tasks might still be processing + # FIXED: Check if task is ready to run + current_time = time.time() + if run_at > current_time: + # Task is not ready yet, re-queue it and continue + self.task_queue.put((run_at, priority, (provider_name, target_item, depth))) + time.sleep(min(0.5, run_at - current_time)) # Sleep until closer to run time + continue + + except: # Queue is empty or timeout occurred + time.sleep(0.1) continue - task_tuple = (provider_name, target_item) + self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth)) + + # FIXED: Include depth in processed tasks to avoid incorrect skipping + task_tuple = (provider_name, target_item, depth) if task_tuple in processed_tasks: self.tasks_skipped += 1 - self.indicators_completed +=1 + self.indicators_completed += 1 continue + # FIXED: Proper depth checking if depth > max_depth: + self.tasks_skipped += 1 + self.indicators_completed += 1 continue - # OVERHAUL: Handle rate limiting with time-based deferral + # FIXED: Rate limiting with proper time-based deferral if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60): - defer_until = time.time() + 60 # Defer for 60 seconds + defer_until = time.time() + 60 # Defer for 60 seconds self.task_queue.put((defer_until, priority, (provider_name, target_item, depth))) self.tasks_re_enqueued += 1 continue + # FIXED: Thread-safe processing state management with self.processing_lock: - if self._is_stop_requested(): 
@@ -486,14 +598,19 @@ class Scanner:
                 self.status = ScanStatus.COMPLETED

             self.status_logger_stop_event.set()
-            if self.status_logger_thread:
-                self.status_logger_thread.join()
+            if self.status_logger_thread and self.status_logger_thread.is_alive():
+                self.status_logger_thread.join(timeout=2.0)  # Don't wait forever

             self._update_session_state()
             self.logger.log_scan_complete()
+
             if self.executor:
-                self.executor.shutdown(wait=False, cancel_futures=True)
-                self.executor = None
+                try:
+                    self.executor.shutdown(wait=False, cancel_futures=True)
+                except Exception as e:
+                    self.logger.logger.warning(f"Error shutting down executor: {e}")
+                finally:
+                    self.executor = None

     def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
         """
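The teardown in this hunk bounds every wait: `join(timeout=...)` plus a liveness check instead of an unbounded join, and a guarded `shutdown(cancel_futures=True)`. A minimal sketch of the same pattern with a hypothetical worker thread (`cancel_futures` requires Python 3.9+):

```python
# Minimal sketch of the bounded-join teardown pattern, with a hypothetical worker.
import threading
import time
from concurrent.futures import ThreadPoolExecutor

stop_event = threading.Event()

def worker() -> None:
    while not stop_event.is_set():
        time.sleep(0.1)

t = threading.Thread(target=worker, daemon=True)
t.start()

stop_event.set()
t.join(timeout=2.0)  # never wait forever on a stuck thread
if t.is_alive():
    print("warning: thread did not terminate cleanly")

executor = ThreadPoolExecutor(max_workers=2)
try:
    executor.shutdown(wait=False, cancel_futures=True)  # Python 3.9+
finally:
    executor = None
```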
self.logger.logger.error(f"Scan failed: {e}") finally: + # FIXED: Comprehensive cleanup with self.processing_lock: self.currently_processing.clear() self.currently_processing_display = [] + # FIXED: Clear any remaining tasks from queue to prevent memory leaks + while not self.task_queue.empty(): + try: + self.task_queue.get_nowait() + except: + break + if self._is_stop_requested(): self.status = ScanStatus.STOPPED elif self.scan_failed_due_to_retries: @@ -486,14 +598,19 @@ class Scanner: self.status = ScanStatus.COMPLETED self.status_logger_stop_event.set() - if self.status_logger_thread: - self.status_logger_thread.join() + if self.status_logger_thread and self.status_logger_thread.is_alive(): + self.status_logger_thread.join(timeout=2.0) # Don't wait forever self._update_session_state() self.logger.log_scan_complete() + if self.executor: - self.executor.shutdown(wait=False, cancel_futures=True) - self.executor = None + try: + self.executor.shutdown(wait=False, cancel_futures=True) + except Exception as e: + self.logger.logger.warning(f"Error shutting down executor: {e}") + finally: + self.executor = None def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: """ @@ -837,47 +954,108 @@ class Scanner: return { 'status': 'error', 'message': 'Failed to get status' } def _initialize_provider_states(self, target: str) -> None: - """Initialize provider states for forensic tracking.""" - if not self.graph.graph.has_node(target): return - node_data = self.graph.graph.nodes[target] - if 'metadata' not in node_data: node_data['metadata'] = {} - if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {} + """ + FIXED: Safer provider state initialization with error handling. + """ + try: + if not self.graph.graph.has_node(target): + return + + node_data = self.graph.graph.nodes[target] + if 'metadata' not in node_data: + node_data['metadata'] = {} + if 'provider_states' not in node_data['metadata']: + node_data['metadata']['provider_states'] = {} + except Exception as e: + self.logger.logger.warning(f"Error initializing provider states for {target}: {e}") + def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List: - """Get providers eligible for querying this target.""" + """ + FIXED: Improved provider eligibility checking with better filtering. 
+ """ if dns_only: return [p for p in self.providers if p.get_name() == 'dns'] + eligible = [] target_key = 'ips' if is_ip else 'domains' + for provider in self.providers: - if provider.get_eligibility().get(target_key): + try: + # Check if provider supports this target type + if not provider.get_eligibility().get(target_key, False): + continue + + # Check if provider is available/configured + if not provider.is_available(): + continue + + # Check if we already successfully queried this provider if not self._already_queried_provider(target, provider.get_name()): eligible.append(provider) + + except Exception as e: + self.logger.logger.warning(f"Error checking provider eligibility {provider.get_name()}: {e}") + continue + return eligible def _already_queried_provider(self, target: str, provider_name: str) -> bool: - """Check if we already successfully queried a provider for a target.""" - if not self.graph.graph.has_node(target): return False - node_data = self.graph.graph.nodes[target] - provider_states = node_data.get('metadata', {}).get('provider_states', {}) - provider_state = provider_states.get(provider_name) - return provider_state is not None and provider_state.get('status') == 'success' + """ + FIXED: More robust check for already queried providers with proper error handling. + """ + try: + if not self.graph.graph.has_node(target): + return False + + node_data = self.graph.graph.nodes[target] + provider_states = node_data.get('metadata', {}).get('provider_states', {}) + provider_state = provider_states.get(provider_name) + + # Only consider it already queried if it was successful + return (provider_state is not None and + provider_state.get('status') == 'success' and + provider_state.get('results_count', 0) > 0) + except Exception as e: + self.logger.logger.warning(f"Error checking provider state for {target}:{provider_name}: {e}") + return False def _update_provider_state(self, target: str, provider_name: str, status: str, - results_count: int, error: Optional[str], start_time: datetime) -> None: - """Update provider state in node metadata for forensic tracking.""" - if not self.graph.graph.has_node(target): return - node_data = self.graph.graph.nodes[target] - if 'metadata' not in node_data: node_data['metadata'] = {} - if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {} - node_data['metadata']['provider_states'][provider_name] = { - 'status': status, - 'timestamp': start_time.isoformat(), - 'results_count': results_count, - 'error': error, - 'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000 - } - + results_count: int, error: Optional[str], start_time: datetime) -> None: + """ + FIXED: More robust provider state updates with validation. 
+ """ + try: + if not self.graph.graph.has_node(target): + self.logger.logger.warning(f"Cannot update provider state: node {target} not found") + return + + node_data = self.graph.graph.nodes[target] + if 'metadata' not in node_data: + node_data['metadata'] = {} + if 'provider_states' not in node_data['metadata']: + node_data['metadata']['provider_states'] = {} + + # Calculate duration safely + try: + duration_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000 + except Exception: + duration_ms = 0 + + node_data['metadata']['provider_states'][provider_name] = { + 'status': status, + 'timestamp': start_time.isoformat(), + 'results_count': max(0, results_count), # Ensure non-negative + 'error': str(error) if error else None, + 'duration_ms': duration_ms + } + + # Update last modified time for forensic integrity + self.last_modified = datetime.now(timezone.utc).isoformat() + + except Exception as e: + self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}") + def _log_target_processing_error(self, target: str, error: str) -> None: self.logger.logger.error(f"Target processing failed for {target}: {error}") @@ -885,8 +1063,28 @@ class Scanner: self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}") def _calculate_progress(self) -> float: - if self.total_tasks_ever_enqueued == 0: return 0.0 - return min(100.0, (self.indicators_completed / self.total_tasks_ever_enqueued) * 100) + try: + if self.total_tasks_ever_enqueued == 0: + return 0.0 + + # Add small buffer for tasks still in queue to avoid showing 100% too early + queue_size = max(0, self.task_queue.qsize()) + with self.processing_lock: + active_tasks = len(self.currently_processing) + + # Adjust total to account for remaining work + adjusted_total = max(self.total_tasks_ever_enqueued, + self.indicators_completed + queue_size + active_tasks) + + if adjusted_total == 0: + return 100.0 + + progress = (self.indicators_completed / adjusted_total) * 100 + return max(0.0, min(100.0, progress)) # Clamp between 0 and 100 + + except Exception as e: + self.logger.logger.warning(f"Error calculating progress: {e}") + return 0.0 def get_graph_data(self) -> Dict[str, Any]: graph_data = self.graph.get_graph_data() diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index 54ebb79..6d16009 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -39,6 +39,7 @@ class ShodanProvider(BaseProvider): return False try: response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5) + self.logger.logger.debug("Shodan is reacheable") return response.status_code == 200 except requests.exceptions.RequestException: return False