refinements for correlations running logic

This commit is contained in:
overcuriousity 2025-09-20 20:31:56 +02:00
parent 4a82c279ef
commit 602739246f
2 changed files with 206 additions and 69 deletions

View File

@ -193,6 +193,8 @@ class Scanner:
print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===") print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
correlation_provider_instance = None
for filename in os.listdir(provider_dir): for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'): if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}" module_name = f"providers.{filename[:-3]}"
@ -203,7 +205,6 @@ class Scanner:
attribute = getattr(module, attribute_name) attribute = getattr(module, attribute_name)
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider: if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
provider_class = attribute provider_class = attribute
# FIXED: Pass the 'name' argument during initialization
provider = provider_class(name=attribute_name, session_config=self.config) provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = provider.get_name() provider_name = provider.get_name()
@ -224,8 +225,13 @@ class Scanner:
if is_available: if is_available:
provider.set_stop_event(self.stop_event) provider.set_stop_event(self.stop_event)
# Special handling for correlation provider
if isinstance(provider, CorrelationProvider): if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph) provider.set_graph_manager(self.graph)
correlation_provider_instance = provider
print(f" ✓ Correlation provider configured with graph manager")
self.providers.append(provider) self.providers.append(provider)
print(f" ✓ Added to scanner") print(f" ✓ Added to scanner")
else: else:
@ -240,6 +246,11 @@ class Scanner:
print(f"=== PROVIDER INITIALIZATION COMPLETE ===") print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
print(f"Active providers: {[p.get_name() for p in self.providers]}") print(f"Active providers: {[p.get_name() for p in self.providers]}")
print(f"Provider count: {len(self.providers)}") print(f"Provider count: {len(self.providers)}")
# Verify correlation provider is properly configured
if correlation_provider_instance:
print(f"Correlation provider configured: {correlation_provider_instance.graph is not None}")
print("=" * 50) print("=" * 50)
def _status_logger_thread(self): def _status_logger_thread(self):
@ -617,16 +628,21 @@ class Scanner:
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None: def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
""" """
PHASE 2: Run correlation analysis on all discovered nodes. PHASE 2: Run correlation analysis on all discovered nodes.
This ensures correlations run after all other providers have completed. Enhanced with better error handling and progress tracking.
""" """
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None) correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
if not correlation_provider: if not correlation_provider:
print("No correlation provider found - skipping correlation phase") print("No correlation provider found - skipping correlation phase")
return return
# Ensure correlation provider has access to current graph state
correlation_provider.set_graph_manager(self.graph)
print(f"Correlation provider configured with graph containing {self.graph.get_node_count()} nodes")
# Get all nodes from the graph for correlation analysis # Get all nodes from the graph for correlation analysis
all_nodes = list(self.graph.graph.nodes()) all_nodes = list(self.graph.graph.nodes())
correlation_tasks = [] correlation_tasks = []
correlation_tasks_enqueued = 0
print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes") print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
@ -640,15 +656,22 @@ class Scanner:
priority = self._get_priority('correlation') priority = self._get_priority('correlation')
self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth))) self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
correlation_tasks.append(task_tuple) correlation_tasks.append(task_tuple)
correlation_tasks_enqueued += 1
self.total_tasks_ever_enqueued += 1 self.total_tasks_ever_enqueued += 1
print(f"Enqueued {len(correlation_tasks)} correlation tasks") print(f"Enqueued {correlation_tasks_enqueued} new correlation tasks")
# Process correlation tasks # Force session state update to reflect new task count
self._update_session_state()
# Process correlation tasks with enhanced tracking
consecutive_empty_iterations = 0 consecutive_empty_iterations = 0
max_empty_iterations = 20 # Shorter timeout for correlation phase max_empty_iterations = 20
correlation_completed = 0
correlation_errors = 0
while correlation_tasks: while correlation_tasks:
# Check if we should continue processing
queue_empty = self.task_queue.empty() queue_empty = self.task_queue.empty()
with self.processing_lock: with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0 no_active_processing = len(self.currently_processing) == 0
@ -656,6 +679,7 @@ class Scanner:
if queue_empty and no_active_processing: if queue_empty and no_active_processing:
consecutive_empty_iterations += 1 consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations: if consecutive_empty_iterations >= max_empty_iterations:
print(f"Correlation phase timeout - {len(correlation_tasks)} tasks remaining")
break break
time.sleep(0.1) time.sleep(0.1)
continue continue
@ -695,26 +719,45 @@ class Scanner:
self.current_indicator = target_item self.current_indicator = target_item
self._update_session_state() self._update_session_state()
# Process correlation task # Process correlation task with enhanced error handling
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth) try:
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
if success: if success:
processed_tasks.add(task_tuple) processed_tasks.add(task_tuple)
self.indicators_completed += 1 correlation_completed += 1
if task_tuple in correlation_tasks: self.indicators_completed += 1
correlation_tasks.remove(task_tuple) if task_tuple in correlation_tasks:
else: correlation_tasks.remove(task_tuple)
# For correlations, don't retry - just mark as completed else:
# For correlations, don't retry - just mark as completed
correlation_errors += 1
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
except Exception as e:
correlation_errors += 1
self.indicators_completed += 1 self.indicators_completed += 1
if task_tuple in correlation_tasks: if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple) correlation_tasks.remove(task_tuple)
self.logger.logger.warning(f"Correlation task failed for {target_item}: {e}")
finally: finally:
with self.processing_lock: with self.processing_lock:
processing_key = (provider_name, target_item) processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key) self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}") # Periodic progress update during correlation phase
if correlation_completed % 10 == 0 and correlation_completed > 0:
remaining = len(correlation_tasks)
print(f"Correlation progress: {correlation_completed} completed, {remaining} remaining")
print(f"Correlation phase complete:")
print(f" - Successfully processed: {correlation_completed}")
print(f" - Errors encountered: {correlation_errors}")
print(f" - Tasks remaining: {len(correlation_tasks)}")
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
""" """
@ -1143,6 +1186,10 @@ class Scanner:
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}") self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
def _calculate_progress(self) -> float: def _calculate_progress(self) -> float:
"""
Enhanced progress calculation that properly accounts for correlation tasks
added during the correlation phase.
"""
try: try:
if self.total_tasks_ever_enqueued == 0: if self.total_tasks_ever_enqueued == 0:
return 0.0 return 0.0
@ -1152,7 +1199,18 @@ class Scanner:
with self.processing_lock: with self.processing_lock:
active_tasks = len(self.currently_processing) active_tasks = len(self.currently_processing)
# Adjust total to account for remaining work # For correlation phase, be more conservative about progress calculation
if self.status == ScanStatus.FINALIZING:
# During correlation phase, show progress more conservatively
base_progress = (self.indicators_completed / max(self.total_tasks_ever_enqueued, 1)) * 100
# If we have active correlation tasks, cap progress at 95% until done
if queue_size > 0 or active_tasks > 0:
return min(95.0, base_progress)
else:
return min(100.0, base_progress)
# Normal phase progress calculation
adjusted_total = max(self.total_tasks_ever_enqueued, adjusted_total = max(self.total_tasks_ever_enqueued,
self.indicators_completed + queue_size + active_tasks) self.indicators_completed + queue_size + active_tasks)

View File

@ -78,90 +78,166 @@ class CorrelationProvider(BaseProvider):
def _find_correlations(self, node_id: str) -> ProviderResult: def _find_correlations(self, node_id: str) -> ProviderResult:
""" """
Find correlations for a given node. Find correlations for a given node with enhanced filtering and error handling.
""" """
result = ProviderResult() result = ProviderResult()
# FIXED: Ensure self.graph is not None before proceeding.
# Enhanced safety checks
if not self.graph or not self.graph.graph.has_node(node_id): if not self.graph or not self.graph.graph.has_node(node_id):
return result return result
node_attributes = self.graph.graph.nodes[node_id].get('attributes', []) try:
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
for attr in node_attributes: # Ensure attributes is a list (handle legacy data)
attr_name = attr.get('name') if not isinstance(node_attributes, list):
attr_value = attr.get('value') return result
attr_provider = attr.get('provider', 'unknown')
should_exclude = ( correlations_found = 0
any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS) or
not isinstance(attr_value, (str, int, float, bool)) or
attr_value is None or
isinstance(attr_value, bool) or
(isinstance(attr_value, str) and (
len(attr_value) < 4 or
self.date_pattern.match(attr_value) or
attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']
)) or
(isinstance(attr_value, (int, float)) and (
attr_value == 0 or
attr_value == 1 or
abs(attr_value) > 1000000
))
)
if should_exclude: for attr in node_attributes:
continue if not isinstance(attr, dict):
continue
if attr_value not in self.correlation_index: attr_name = attr.get('name', '')
self.correlation_index[attr_value] = { attr_value = attr.get('value')
'nodes': set(), attr_provider = attr.get('provider', 'unknown')
'sources': []
# Enhanced filtering logic
should_exclude = self._should_exclude_attribute(attr_name, attr_value)
if should_exclude:
continue
# Build correlation index
if attr_value not in self.correlation_index:
self.correlation_index[attr_value] = {
'nodes': set(),
'sources': []
}
self.correlation_index[attr_value]['nodes'].add(node_id)
source_info = {
'node_id': node_id,
'provider': attr_provider,
'attribute': attr_name,
'path': f"{attr_provider}_{attr_name}"
} }
self.correlation_index[attr_value]['nodes'].add(node_id) # Avoid duplicate sources
existing_sources = [s for s in self.correlation_index[attr_value]['sources']
if s['node_id'] == node_id and s['path'] == source_info['path']]
if not existing_sources:
self.correlation_index[attr_value]['sources'].append(source_info)
source_info = { # Create correlation if we have multiple nodes with this value
'node_id': node_id, if len(self.correlation_index[attr_value]['nodes']) > 1:
'provider': attr_provider, self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
'attribute': attr_name, correlations_found += 1
'path': f"{attr_provider}_{attr_name}"
}
existing_sources = [s for s in self.correlation_index[attr_value]['sources'] # Log correlation results
if s['node_id'] == node_id and s['path'] == source_info['path']] if correlations_found > 0:
if not existing_sources: self.logger.logger.info(f"Found {correlations_found} correlations for node {node_id}")
self.correlation_index[attr_value]['sources'].append(source_info)
except Exception as e:
self.logger.logger.error(f"Error finding correlations for {node_id}: {e}")
if len(self.correlation_index[attr_value]['nodes']) > 1:
self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
return result return result
def _should_exclude_attribute(self, attr_name: str, attr_value: Any) -> bool:
"""
Enhanced logic to determine if an attribute should be excluded from correlation.
"""
# Check against excluded keys (exact match or substring)
if any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS):
return True
# Value type filtering
if not isinstance(attr_value, (str, int, float, bool)) or attr_value is None:
return True
# Boolean values are not useful for correlation
if isinstance(attr_value, bool):
return True
# String value filtering
if isinstance(attr_value, str):
# Date/timestamp strings
if self.date_pattern.match(attr_value):
return True
# Common non-useful values
if attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']:
return True
# Very long strings that are likely unique (> 100 chars)
if len(attr_value) > 100:
return True
# Numeric value filtering
if isinstance(attr_value, (int, float)):
# Very common values
if attr_value in [0, 1]:
return True
# Very large numbers (likely timestamps or unique IDs)
if abs(attr_value) > 1000000:
return True
return False
def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult): def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
""" """
Create correlation relationships and add them to the provider result. Create correlation relationships with enhanced deduplication and validation.
""" """
correlation_node_id = f"corr_{hash(str(value)) & 0x7FFFFFFF}" correlation_node_id = f"corr_{hash(str(value)) & 0x7FFFFFFF}"
nodes = correlation_data['nodes'] nodes = correlation_data['nodes']
sources = correlation_data['sources'] sources = correlation_data['sources']
# Only create correlations if we have meaningful nodes (more than 1)
if len(nodes) < 2:
return
# Limit correlation size to prevent overly large correlation objects
MAX_CORRELATION_SIZE = 50
if len(nodes) > MAX_CORRELATION_SIZE:
# Sample the nodes to keep correlation manageable
import random
sampled_nodes = random.sample(list(nodes), MAX_CORRELATION_SIZE)
nodes = set(sampled_nodes)
# Filter sources to match sampled nodes
sources = [s for s in sources if s['node_id'] in nodes]
# Add the correlation node as an attribute to the result # Add the correlation node as an attribute to the result
result.add_attribute( result.add_attribute(
target_node=correlation_node_id, target_node=correlation_node_id,
name="correlation_value", name="correlation_value",
value=value, value=value,
attr_type=str(type(value)), attr_type=str(type(value).__name__),
provider=self.name, provider=self.name,
confidence=0.9, confidence=0.9,
metadata={ metadata={
'correlated_nodes': list(nodes), 'correlated_nodes': list(nodes),
'sources': sources, 'sources': sources,
'correlation_size': len(nodes),
'value_type': type(value).__name__
} }
) )
# Create relationships with source validation
created_relationships = set()
for source in sources: for source in sources:
node_id = source['node_id'] node_id = source['node_id']
provider = source['provider'] provider = source['provider']
attribute = source['attribute'] attribute = source['attribute']
# Skip if we've already created this relationship
relationship_key = (node_id, correlation_node_id)
if relationship_key in created_relationships:
continue
relationship_label = f"corr_{provider}_{attribute}" relationship_label = f"corr_{provider}_{attribute}"
# Add the relationship to the result # Add the relationship to the result
@ -174,6 +250,9 @@ class CorrelationProvider(BaseProvider):
raw_data={ raw_data={
'correlation_value': value, 'correlation_value': value,
'original_attribute': attribute, 'original_attribute': attribute,
'correlation_type': 'attribute_matching' 'correlation_type': 'attribute_matching',
'correlation_size': len(nodes)
} }
) )
created_relationships.add(relationship_key)