fix running of correlations after stop

This commit is contained in:
overcuriousity 2025-09-20 20:16:23 +02:00
parent 71a05f5b32
commit 4a82c279ef

View File

@ -27,6 +27,7 @@ class ScanStatus:
"""Enumeration of scan statuses.""" """Enumeration of scan statuses."""
IDLE = "idle" IDLE = "idle"
RUNNING = "running" RUNNING = "running"
FINALIZING = "finalizing" # New state for post-scan analysis
COMPLETED = "completed" COMPLETED = "completed"
FAILED = "failed" FAILED = "failed"
STOPPED = "stopped" STOPPED = "stopped"
@ -450,12 +451,10 @@ class Scanner:
def _execute_scan(self, target: str, max_depth: int) -> None: def _execute_scan(self, target: str, max_depth: int) -> None:
self.executor = ThreadPoolExecutor(max_workers=self.max_workers) self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping processed_tasks = set()
is_ip = _is_valid_ip(target) is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False) initial_providers = [p for p in self._get_eligible_providers(target, is_ip, False) if not isinstance(p, CorrelationProvider)]
# FIXED: Filter out correlation provider from initial providers
initial_providers = [p for p in initial_providers if not isinstance(p, CorrelationProvider)]
for provider in initial_providers: for provider in initial_providers:
provider_name = provider.get_name() provider_name = provider.get_name()
@ -474,9 +473,8 @@ class Scanner:
self.graph.add_node(target, node_type) self.graph.add_node(target, node_type)
self._initialize_provider_states(target) self._initialize_provider_states(target)
consecutive_empty_iterations = 0 consecutive_empty_iterations = 0
max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion max_empty_iterations = 50
# PHASE 1: Run all non-correlation providers
print(f"\n=== PHASE 1: Running non-correlation providers ===") print(f"\n=== PHASE 1: Running non-correlation providers ===")
while not self._is_stop_requested(): while not self._is_stop_requested():
queue_empty = self.task_queue.empty() queue_empty = self.task_queue.empty()
@ -486,57 +484,39 @@ class Scanner:
if queue_empty and no_active_processing: if queue_empty and no_active_processing:
consecutive_empty_iterations += 1 consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations: if consecutive_empty_iterations >= max_empty_iterations:
break # Phase 1 complete break
time.sleep(0.1) time.sleep(0.1)
continue continue
else: else:
consecutive_empty_iterations = 0 consecutive_empty_iterations = 0
# Process tasks (same logic as before, but correlations are filtered out)
try: try:
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1) run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
if provider_name == 'correlation': continue
# Skip correlation tasks during Phase 1
if provider_name == 'correlation':
continue
# Check if task is ready to run
current_time = time.time() current_time = time.time()
if run_at > current_time: if run_at > current_time:
self.task_queue.put((run_at, priority, (provider_name, target_item, depth))) self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
time.sleep(min(0.5, run_at - current_time)) time.sleep(min(0.5, run_at - current_time))
continue continue
except:
except: # Queue is empty or timeout occurred
time.sleep(0.1) time.sleep(0.1)
continue continue
self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth)) self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
# Skip if already processed
task_tuple = (provider_name, target_item, depth) task_tuple = (provider_name, target_item, depth)
if task_tuple in processed_tasks: if task_tuple in processed_tasks or depth > max_depth:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
# Skip if depth exceeded
if depth > max_depth:
self.tasks_skipped += 1 self.tasks_skipped += 1
self.indicators_completed += 1 self.indicators_completed += 1
continue continue
# Rate limiting with proper time-based deferral
if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60): if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
defer_until = time.time() + 60 defer_until = time.time() + 60
self.task_queue.put((defer_until, priority, (provider_name, target_item, depth))) self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1 self.tasks_re_enqueued += 1
continue continue
# Thread-safe processing state management
with self.processing_lock: with self.processing_lock:
if self._is_stop_requested(): if self._is_stop_requested(): break
break
processing_key = (provider_name, target_item) processing_key = (provider_name, target_item)
if processing_key in self.currently_processing: if processing_key in self.currently_processing:
self.tasks_skipped += 1 self.tasks_skipped += 1
@ -548,29 +528,21 @@ class Scanner:
self.current_depth = depth self.current_depth = depth
self.current_indicator = target_item self.current_indicator = target_item
self._update_session_state() self._update_session_state()
if self._is_stop_requested(): break
if self._is_stop_requested():
break
provider = next((p for p in self.providers if p.get_name() == provider_name), None) provider = next((p for p in self.providers if p.get_name() == provider_name), None)
if provider and not isinstance(provider, CorrelationProvider): if provider and not isinstance(provider, CorrelationProvider):
new_targets, _, success = self._process_provider_task(provider, target_item, depth) new_targets, _, success = self._process_provider_task(provider, target_item, depth)
if self._is_stop_requested(): break
if self._is_stop_requested():
break
if not success: if not success:
retry_key = (provider_name, target_item, depth) retry_key = (provider_name, target_item, depth)
self.target_retries[retry_key] += 1 self.target_retries[retry_key] += 1
if self.target_retries[retry_key] <= self.config.max_retries_per_target: if self.target_retries[retry_key] <= self.config.max_retries_per_target:
retry_count = self.target_retries[retry_key] retry_count = self.target_retries[retry_key]
backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1)) backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
retry_at = time.time() + backoff_delay self.task_queue.put((time.time() + backoff_delay, priority, (provider_name, target_item, depth)))
self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1 self.tasks_re_enqueued += 1
self.logger.logger.debug(f"Retrying {provider_name}:{target_item} in {backoff_delay:.1f}s (attempt {retry_count})")
else: else:
self.scan_failed_due_to_retries = True self.scan_failed_due_to_retries = True
self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded") self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded")
@ -578,49 +550,33 @@ class Scanner:
processed_tasks.add(task_tuple) processed_tasks.add(task_tuple)
self.indicators_completed += 1 self.indicators_completed += 1
# Enqueue new targets with proper depth tracking
if not self._is_stop_requested(): if not self._is_stop_requested():
for new_target in new_targets: for new_target in new_targets:
is_ip_new = _is_valid_ip(new_target) is_ip_new = _is_valid_ip(new_target)
eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False) eligible_providers_new = [p for p in self._get_eligible_providers(new_target, is_ip_new, False) if not isinstance(p, CorrelationProvider)]
# FIXED: Filter out correlation providers in Phase 1
eligible_providers_new = [p for p in eligible_providers_new if not isinstance(p, CorrelationProvider)]
for p_new in eligible_providers_new: for p_new in eligible_providers_new:
p_name_new = p_new.get_name() p_name_new = p_new.get_name()
new_depth = depth + 1 new_depth = depth + 1
new_task_tuple = (p_name_new, new_target, new_depth) if (p_name_new, new_target, new_depth) not in processed_tasks and new_depth <= max_depth:
self.task_queue.put((time.time(), self._get_priority(p_name_new), (p_name_new, new_target, new_depth)))
if new_task_tuple not in processed_tasks and new_depth <= max_depth:
new_priority = self._get_priority(p_name_new)
self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
self.total_tasks_ever_enqueued += 1 self.total_tasks_ever_enqueued += 1
else: else:
self.logger.logger.warning(f"Provider {provider_name} not found in active providers")
self.tasks_skipped += 1 self.tasks_skipped += 1
self.indicators_completed += 1 self.indicators_completed += 1
finally: finally:
with self.processing_lock: with self.processing_lock:
processing_key = (provider_name, target_item) self.currently_processing.discard((provider_name, target_item))
self.currently_processing.discard(processing_key)
# This code runs after the main loop finishes or is stopped.
self.status = ScanStatus.FINALIZING
self._update_session_state()
self.logger.logger.info("Scan stopped or completed. Entering finalization phase.")
except Exception as e: if self.status in [ScanStatus.FINALIZING, ScanStatus.COMPLETED, ScanStatus.STOPPED]:
traceback.print_exc() print(f"\n=== PHASE 2: Running correlation analysis ===")
self.status = ScanStatus.FAILED self._run_correlation_phase(max_depth, processed_tasks)
self.logger.logger.error(f"Scan failed: {e}")
finally:
# Comprehensive cleanup (same as before)
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
except:
break
# Determine the final status *after* finalization.
if self._is_stop_requested(): if self._is_stop_requested():
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED
elif self.scan_failed_due_to_retries: elif self.scan_failed_due_to_retries:
@ -628,17 +584,25 @@ class Scanner:
else: else:
self.status = ScanStatus.COMPLETED self.status = ScanStatus.COMPLETED
if self.status in [ScanStatus.COMPLETED, ScanStatus.STOPPED]: except Exception as e:
print(f"\n=== PHASE 2: Running correlation analysis ===") traceback.print_exc()
self._run_correlation_phase(max_depth, processed_tasks) self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}")
finally:
# The 'finally' block is now only for guaranteed cleanup.
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
while not self.task_queue.empty():
try: self.task_queue.get_nowait()
except: break
self.status_logger_stop_event.set() self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive(): if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0) self.status_logger_thread.join(timeout=2.0)
self._update_session_state()
self.logger.log_scan_complete()
# The executor shutdown now happens *after* the correlation phase has run.
if self.executor: if self.executor:
try: try:
self.executor.shutdown(wait=False, cancel_futures=True) self.executor.shutdown(wait=False, cancel_futures=True)
@ -646,6 +610,9 @@ class Scanner:
self.logger.logger.warning(f"Error shutting down executor: {e}") self.logger.logger.warning(f"Error shutting down executor: {e}")
finally: finally:
self.executor = None self.executor = None
self._update_session_state()
self.logger.log_scan_complete()
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None: def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
""" """