From ecfb27e02a4fb9ffcc6d60cc5655814d11dda103 Mon Sep 17 00:00:00 2001 From: overcuriousity Date: Wed, 17 Sep 2025 21:47:03 +0200 Subject: [PATCH] new scheduling, removed many debug prints --- core/graph_manager.py | 4 +- core/scanner.py | 154 ++++++++++++++++-------------------- core/session_manager.py | 4 +- providers/crtsh_provider.py | 4 +- providers/dns_provider.py | 4 +- 5 files changed, 74 insertions(+), 96 deletions(-) diff --git a/core/graph_manager.py b/core/graph_manager.py index 5c28c79..fe6028a 100644 --- a/core/graph_manager.py +++ b/core/graph_manager.py @@ -179,7 +179,7 @@ class GraphManager: } self.add_node(correlation_node_id, NodeType.CORRELATION_OBJECT, metadata=metadata) - print(f"Created correlation node {correlation_node_id} for value '{value}' with {len(nodes)} nodes") + #print(f"Created correlation node {correlation_node_id} for value '{value}' with {len(nodes)} nodes") # Create edges from each node to the correlation node for source in sources: @@ -204,7 +204,7 @@ class GraphManager: } ) - print(f"Added correlation edge: {node_id} -> {correlation_node_id} ({relationship_label})") + #print(f"Added correlation edge: {node_id} -> {correlation_node_id} ({relationship_label})") def _has_direct_edge_bidirectional(self, node_a: str, node_b: str) -> bool: diff --git a/core/scanner.py b/core/scanner.py index 9599752..57dc0e5 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -6,6 +6,7 @@ import os import importlib import redis import time +import random # Imported for jitter from typing import List, Set, Dict, Any, Tuple, Optional from concurrent.futures import ThreadPoolExecutor from collections import defaultdict @@ -206,7 +207,6 @@ class Scanner: def _status_logger_thread(self): """Periodically prints a clean, formatted scan status to the terminal.""" - # Color codes for improved display (from Document 2) HEADER = "\033[95m" CYAN = "\033[96m" GREEN = "\033[92m" @@ -218,10 +218,8 @@ class Scanner: last_status_str = "" while not self.status_logger_stop_event.is_set(): try: - # Use thread-safe copy of currently processing with self.processing_lock: in_flight_tasks = list(self.currently_processing) - # Update display list for consistent formatting self.currently_processing_display = in_flight_tasks.copy() status_str = ( @@ -237,23 +235,21 @@ class Scanner: print(f"\n{'-'*80}") print(status_str) if self.last_task_from_queue: - p, pn, ti, d = self.last_task_from_queue + # Unpack the new time-based queue item + _, p, (pn, ti, d) = self.last_task_from_queue print(f"{BLUE}Last task dequeued -> Prio:{p} | Provider:{pn} | Target:'{ti}' | Depth:{d}{ENDC}") if in_flight_tasks: print(f"{BOLD}{YELLOW}Currently Processing:{ENDC}") - # Display up to 3 currently processing tasks display_tasks = [f" - {p}: {t}" for p, t in in_flight_tasks[:3]] print("\n".join(display_tasks)) if len(in_flight_tasks) > 3: print(f" ... and {len(in_flight_tasks) - 3} more") print(f"{'-'*80}") last_status_str = status_str - except Exception: - # Silently fail to avoid crashing the logger pass - time.sleep(2) # Update interval + time.sleep(2) def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool: """ @@ -285,7 +281,8 @@ class Scanner: self.task_queue = PriorityQueue() self.target_retries.clear() self.scan_failed_due_to_retries = False - self.tasks_skipped = 0 # BUGFIX: Reset tasks_skipped for new scan + self.tasks_skipped = 0 + self.last_task_from_queue = None self._update_session_state() @@ -313,7 +310,6 @@ class Scanner: self.tasks_re_enqueued = 0 self.total_tasks_ever_enqueued = 0 self.current_indicator = self.current_target - self.last_task_from_queue = None self._update_session_state() self.logger = new_session() @@ -325,7 +321,6 @@ class Scanner: ) self.scan_thread.start() - # Start the status logger thread self.status_logger_stop_event.clear() self.status_logger_thread = threading.Thread(target=self._status_logger_thread, daemon=True) self.status_logger_thread.start() @@ -348,7 +343,10 @@ class Scanner: return 3 # Lowest priority def _execute_scan(self, target: str, max_depth: int) -> None: - """Execute the reconnaissance scan with proper termination handling.""" + """ + Execute the reconnaissance scan with a time-based, robust scheduler. + Handles rate-limiting via deferral and failures via exponential backoff. + """ self.executor = ThreadPoolExecutor(max_workers=self.max_workers) processed_tasks = set() @@ -356,7 +354,9 @@ class Scanner: initial_providers = self._get_eligible_providers(target, is_ip, False) for provider in initial_providers: provider_name = provider.get_name() - self.task_queue.put((self._get_priority(provider_name), (provider_name, target, 0))) + priority = self._get_priority(provider_name) + # OVERHAUL: Enqueue with current timestamp to run immediately + self.task_queue.put((time.time(), priority, (provider_name, target, 0))) self.total_tasks_ever_enqueued += 1 try: @@ -370,12 +370,24 @@ class Scanner: self.graph.add_node(target, node_type) self._initialize_provider_states(target) - while not self.task_queue.empty() and not self._is_stop_requested(): + while not self._is_stop_requested(): + if self.task_queue.empty() and not self.currently_processing: + break # Scan is complete + try: - priority, (provider_name, target_item, depth) = self.task_queue.get() - self.last_task_from_queue = (priority, provider_name, target_item, depth) + # OVERHAUL: Peek at the next task to see if it's ready to run + next_run_at, _, _ = self.task_queue.queue[0] + if next_run_at > time.time(): + time.sleep(0.1) # Sleep to prevent busy-waiting for future tasks + continue + + # Task is ready, so get it from the queue + run_at, priority, (provider_name, target_item, depth) = self.task_queue.get() + self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth)) + except IndexError: - break + time.sleep(0.1) # Queue is empty, but tasks might still be processing + continue task_tuple = (provider_name, target_item) if task_tuple in processed_tasks: @@ -385,14 +397,16 @@ class Scanner: if depth > max_depth: continue - + + # OVERHAUL: Handle rate limiting with time-based deferral if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60): - self.task_queue.put((priority + 1, (provider_name, target_item, depth))) # Postpone + defer_until = time.time() + 60 # Defer for 60 seconds + self.task_queue.put((defer_until, priority, (provider_name, target_item, depth))) + self.tasks_re_enqueued += 1 continue with self.processing_lock: - if self._is_stop_requested(): - break + if self._is_stop_requested(): break self.currently_processing.add(task_tuple) try: @@ -400,23 +414,24 @@ class Scanner: self.current_indicator = target_item self._update_session_state() - if self._is_stop_requested(): - break + if self._is_stop_requested(): break provider = next((p for p in self.providers if p.get_name() == provider_name), None) if provider: - new_targets, large_entity_members, success = self._query_single_provider_for_target(provider, target_item, depth) + new_targets, _, success = self._query_single_provider_for_target(provider, target_item, depth) - if self._is_stop_requested(): - break + if self._is_stop_requested(): break if not success: self.target_retries[task_tuple] += 1 if self.target_retries[task_tuple] <= self.config.max_retries_per_target: - self.task_queue.put((priority, (provider_name, target_item, depth))) + # OVERHAUL: Exponential backoff for retries + retry_count = self.target_retries[task_tuple] + backoff_delay = (2 ** retry_count) + random.uniform(0, 1) # Add jitter + retry_at = time.time() + backoff_delay + self.task_queue.put((retry_at, priority, (provider_name, target_item, depth))) self.tasks_re_enqueued += 1 - #self.total_tasks_ever_enqueued += 1 else: self.scan_failed_due_to_retries = True self._log_target_processing_error(str(task_tuple), "Max retries exceeded") @@ -425,15 +440,16 @@ class Scanner: self.indicators_completed += 1 if not self._is_stop_requested(): - all_new_targets = new_targets - for new_target in all_new_targets: + for new_target in new_targets: is_ip_new = _is_valid_ip(new_target) eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False) for p_new in eligible_providers_new: p_name_new = p_new.get_name() if (p_name_new, new_target) not in processed_tasks: new_depth = depth + 1 if new_target in new_targets else depth - self.task_queue.put((self._get_priority(p_name_new), (p_name_new, new_target, new_depth))) + new_priority = self._get_priority(p_name_new) + # OVERHAUL: Enqueue new tasks to run immediately + self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth))) self.total_tasks_ever_enqueued += 1 finally: with self.processing_lock: @@ -455,7 +471,6 @@ class Scanner: else: self.status = ScanStatus.COMPLETED - # Stop the status logger self.status_logger_stop_event.set() if self.status_logger_thread: self.status_logger_thread.join() @@ -665,15 +680,12 @@ class Scanner: self.currently_processing.clear() self.currently_processing_display = [] - discarded_tasks = [] - while not self.task_queue.empty(): - discarded_tasks.append(self.task_queue.get()) self.task_queue = PriorityQueue() if self.executor: try: self.executor.shutdown(wait=False, cancel_futures=True) - except Exception as e: + except Exception: pass self._update_session_state() @@ -722,7 +734,8 @@ class Scanner: eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False) for provider in eligible_providers: provider_name = provider.get_name() - self.task_queue.put((self._get_priority(provider_name), (provider_name, node_id_to_extract, current_depth))) + priority = self._get_priority(provider_name) + self.task_queue.put((time.time(), priority, (provider_name, node_id_to_extract, current_depth))) self.total_tasks_ever_enqueued += 1 if self.status != ScanStatus.RUNNING: @@ -747,7 +760,7 @@ class Scanner: try: from core.session_manager import session_manager session_manager.update_session_scanner(self.session_id, self) - except Exception as e: + except Exception: pass def get_scan_status(self) -> Dict[str, Any]: @@ -778,66 +791,44 @@ class Scanner: 'tasks_skipped': self.tasks_skipped, 'tasks_rescheduled': self.tasks_re_enqueued, } - except Exception as e: + except Exception: traceback.print_exc() - return { - 'status': 'error', 'target_domain': None, 'current_depth': 0, 'max_depth': 0, - 'current_indicator': '', 'indicators_processed': 0, 'indicators_completed': 0, - 'tasks_re_enqueued': 0, 'progress_percentage': 0.0, 'enabled_providers': [], - 'graph_statistics': {}, 'task_queue_size': 0, 'currently_processing_count': 0, - 'currently_processing': [], 'tasks_in_queue': 0, 'tasks_completed': 0, - 'tasks_skipped': 0, 'tasks_rescheduled': 0, - } + return { 'status': 'error', 'message': 'Failed to get status' } def _initialize_provider_states(self, target: str) -> None: """Initialize provider states for forensic tracking.""" - if not self.graph.graph.has_node(target): - return - + if not self.graph.graph.has_node(target): return node_data = self.graph.graph.nodes[target] - if 'metadata' not in node_data: - node_data['metadata'] = {} - if 'provider_states' not in node_data['metadata']: - node_data['metadata']['provider_states'] = {} + if 'metadata' not in node_data: node_data['metadata'] = {} + if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {} def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List: """Get providers eligible for querying this target.""" if dns_only: return [p for p in self.providers if p.get_name() == 'dns'] - eligible = [] target_key = 'ips' if is_ip else 'domains' - for provider in self.providers: if provider.get_eligibility().get(target_key): if not self._already_queried_provider(target, provider.get_name()): eligible.append(provider) - return eligible def _already_queried_provider(self, target: str, provider_name: str) -> bool: """Check if we already successfully queried a provider for a target.""" - if not self.graph.graph.has_node(target): - return False - + if not self.graph.graph.has_node(target): return False node_data = self.graph.graph.nodes[target] provider_states = node_data.get('metadata', {}).get('provider_states', {}) - provider_state = provider_states.get(provider_name) return provider_state is not None and provider_state.get('status') == 'success' def _update_provider_state(self, target: str, provider_name: str, status: str, results_count: int, error: Optional[str], start_time: datetime) -> None: """Update provider state in node metadata for forensic tracking.""" - if not self.graph.graph.has_node(target): - return - + if not self.graph.graph.has_node(target): return node_data = self.graph.graph.nodes[target] - if 'metadata' not in node_data: - node_data['metadata'] = {} - if 'provider_states' not in node_data['metadata']: - node_data['metadata']['provider_states'] = {} - + if 'metadata' not in node_data: node_data['metadata'] = {} + if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {} node_data['metadata']['provider_states'][provider_name] = { 'status': status, 'timestamp': start_time.isoformat(), @@ -847,51 +838,40 @@ class Scanner: } def _log_target_processing_error(self, target: str, error: str) -> None: - """Log target processing errors for forensic trail.""" self.logger.logger.error(f"Target processing failed for {target}: {error}") def _log_provider_error(self, target: str, provider_name: str, error: str) -> None: - """Log provider query errors for forensic trail.""" self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}") def _calculate_progress(self) -> float: - """Calculate scan progress percentage based on task completion.""" - if self.total_tasks_ever_enqueued == 0: - return 0.0 + if self.total_tasks_ever_enqueued == 0: return 0.0 return min(100.0, (self.indicators_completed / self.total_tasks_ever_enqueued) * 100) def get_graph_data(self) -> Dict[str, Any]: - """Get current graph data for visualization.""" graph_data = self.graph.get_graph_data() graph_data['initial_targets'] = list(self.initial_targets) return graph_data def export_results(self) -> Dict[str, Any]: - """Export complete scan results with forensic audit trail.""" graph_data = self.graph.export_json() audit_trail = self.logger.export_audit_trail() provider_stats = {} for provider in self.providers: provider_stats[provider.get_name()] = provider.get_statistics() - export_data = { + return { 'scan_metadata': { - 'target_domain': self.current_target, - 'max_depth': self.max_depth, - 'final_status': self.status, - 'total_indicators_processed': self.indicators_processed, - 'enabled_providers': list(provider_stats.keys()), - 'session_id': self.session_id + 'target_domain': self.current_target, 'max_depth': self.max_depth, + 'final_status': self.status, 'total_indicators_processed': self.indicators_processed, + 'enabled_providers': list(provider_stats.keys()), 'session_id': self.session_id }, 'graph_data': graph_data, 'forensic_audit': audit_trail, 'provider_statistics': provider_stats, 'scan_summary': self.logger.get_forensic_summary() } - return export_data def get_provider_info(self) -> Dict[str, Dict[str, Any]]: - """Get information about all available providers.""" info = {} provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') for filename in os.listdir(provider_dir): @@ -905,9 +885,7 @@ class Scanner: provider_class = attribute temp_provider = provider_class(name=attribute_name, session_config=self.config) provider_name = temp_provider.get_name() - live_provider = next((p for p in self.providers if p.get_name() == provider_name), None) - info[provider_name] = { 'display_name': temp_provider.get_display_name(), 'requires_api_key': temp_provider.requires_api_key(), @@ -915,6 +893,6 @@ class Scanner: 'enabled': self.config.is_provider_enabled(provider_name), 'rate_limit': self.config.get_rate_limit(provider_name), } - except Exception as e: + except Exception: traceback.print_exc() return info \ No newline at end of file diff --git a/core/session_manager.py b/core/session_manager.py index 30a1940..a1d916c 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -218,10 +218,10 @@ class SessionManager: # Only log occasionally to reduce noise if hasattr(self, '_last_update_log'): if time.time() - self._last_update_log > 5: # Log every 5 seconds max - print(f"Scanner state updated for session {session_id} (status: {scanner.status})") + #print(f"Scanner state updated for session {session_id} (status: {scanner.status})") self._last_update_log = time.time() else: - print(f"Scanner state updated for session {session_id} (status: {scanner.status})") + #print(f"Scanner state updated for session {session_id} (status: {scanner.status})") self._last_update_log = time.time() else: print(f"WARNING: Failed to save scanner state for session {session_id}") diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index 1e623c0..26e2590 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -115,7 +115,7 @@ class CrtShProvider(BaseProvider): try: if cache_status == "fresh": result = self._load_from_cache(cache_file) - self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}") + #self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}") else: # "stale" or "not_found" # Query the API for the latest certificates @@ -447,7 +447,7 @@ class CrtShProvider(BaseProvider): return is_not_expired except Exception as e: - self.logger.logger.debug(f"Certificate validity check failed: {e}") + #self.logger.logger.debug(f"Certificate validity check failed: {e}") return False def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]: diff --git a/providers/dns_provider.py b/providers/dns_provider.py index 2d03a9b..9ca0e35 100644 --- a/providers/dns_provider.py +++ b/providers/dns_provider.py @@ -67,9 +67,9 @@ class DNSProvider(BaseProvider): for record_type in ['A', 'AAAA', 'CNAME', 'MX', 'NS', 'SOA', 'TXT', 'SRV', 'CAA']: try: self._query_record(domain, record_type, result) - except resolver.NoAnswer: + #except resolver.NoAnswer: # This is not an error, just a confirmation that the record doesn't exist. - self.logger.logger.debug(f"No {record_type} record found for {domain}") + #self.logger.logger.debug(f"No {record_type} record found for {domain}") except Exception as e: self.failed_requests += 1 self.logger.logger.debug(f"{record_type} record query failed for {domain}: {e}")