refinements for correlations running logic

This commit is contained in:
overcuriousity
2025-09-20 20:31:56 +02:00
parent 4a82c279ef
commit 602739246f
2 changed files with 206 additions and 69 deletions

View File

@@ -193,6 +193,8 @@ class Scanner:
print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
correlation_provider_instance = None
for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}"
@@ -203,7 +205,6 @@ class Scanner:
attribute = getattr(module, attribute_name)
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
provider_class = attribute
# FIXED: Pass the 'name' argument during initialization
provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = provider.get_name()
@@ -224,8 +225,13 @@ class Scanner:
if is_available:
provider.set_stop_event(self.stop_event)
# Special handling for correlation provider
if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph)
correlation_provider_instance = provider
print(f" ✓ Correlation provider configured with graph manager")
self.providers.append(provider)
print(f" ✓ Added to scanner")
else:
@@ -240,6 +246,11 @@ class Scanner:
print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
print(f"Active providers: {[p.get_name() for p in self.providers]}")
print(f"Provider count: {len(self.providers)}")
# Verify correlation provider is properly configured
if correlation_provider_instance:
print(f"Correlation provider configured: {correlation_provider_instance.graph is not None}")
print("=" * 50)
def _status_logger_thread(self):
@@ -617,16 +628,21 @@ class Scanner:
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
"""
PHASE 2: Run correlation analysis on all discovered nodes.
This ensures correlations run after all other providers have completed.
Enhanced with better error handling and progress tracking.
"""
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
if not correlation_provider:
print("No correlation provider found - skipping correlation phase")
return
# Ensure correlation provider has access to current graph state
correlation_provider.set_graph_manager(self.graph)
print(f"Correlation provider configured with graph containing {self.graph.get_node_count()} nodes")
# Get all nodes from the graph for correlation analysis
all_nodes = list(self.graph.graph.nodes())
correlation_tasks = []
correlation_tasks_enqueued = 0
print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
@@ -640,15 +656,22 @@ class Scanner:
priority = self._get_priority('correlation')
self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
correlation_tasks.append(task_tuple)
correlation_tasks_enqueued += 1
self.total_tasks_ever_enqueued += 1
print(f"Enqueued {len(correlation_tasks)} correlation tasks")
print(f"Enqueued {correlation_tasks_enqueued} new correlation tasks")
# Process correlation tasks
# Force session state update to reflect new task count
self._update_session_state()
# Process correlation tasks with enhanced tracking
consecutive_empty_iterations = 0
max_empty_iterations = 20 # Shorter timeout for correlation phase
max_empty_iterations = 20
correlation_completed = 0
correlation_errors = 0
while correlation_tasks:
# Check if we should continue processing
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
@@ -656,6 +679,7 @@ class Scanner:
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
print(f"Correlation phase timeout - {len(correlation_tasks)} tasks remaining")
break
time.sleep(0.1)
continue
@@ -695,26 +719,45 @@ class Scanner:
self.current_indicator = target_item
self._update_session_state()
# Process correlation task
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
if success:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
else:
# For correlations, don't retry - just mark as completed
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
# Process correlation task with enhanced error handling
try:
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
if success:
processed_tasks.add(task_tuple)
correlation_completed += 1
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
else:
# For correlations, don't retry - just mark as completed
correlation_errors += 1
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
except Exception as e:
correlation_errors += 1
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
self.logger.logger.warning(f"Correlation task failed for {target_item}: {e}")
finally:
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
# Periodic progress update during correlation phase
if correlation_completed % 10 == 0 and correlation_completed > 0:
remaining = len(correlation_tasks)
print(f"Correlation progress: {correlation_completed} completed, {remaining} remaining")
print(f"Correlation phase complete:")
print(f" - Successfully processed: {correlation_completed}")
print(f" - Errors encountered: {correlation_errors}")
print(f" - Tasks remaining: {len(correlation_tasks)}")
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
"""
@@ -1143,6 +1186,10 @@ class Scanner:
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
def _calculate_progress(self) -> float:
"""
Enhanced progress calculation that properly accounts for correlation tasks
added during the correlation phase.
"""
try:
if self.total_tasks_ever_enqueued == 0:
return 0.0
@@ -1152,7 +1199,18 @@ class Scanner:
with self.processing_lock:
active_tasks = len(self.currently_processing)
# Adjust total to account for remaining work
# For correlation phase, be more conservative about progress calculation
if self.status == ScanStatus.FINALIZING:
# During correlation phase, show progress more conservatively
base_progress = (self.indicators_completed / max(self.total_tasks_ever_enqueued, 1)) * 100
# If we have active correlation tasks, cap progress at 95% until done
if queue_size > 0 or active_tasks > 0:
return min(95.0, base_progress)
else:
return min(100.0, base_progress)
# Normal phase progress calculation
adjusted_total = max(self.total_tasks_ever_enqueued,
self.indicators_completed + queue_size + active_tasks)