Refinements to the correlation-phase execution logic
This commit is contained in:
parent
4a82c279ef
commit
602739246f
@ -193,6 +193,8 @@ class Scanner:
|
|||||||
|
|
||||||
print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
|
print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
|
||||||
|
|
||||||
|
correlation_provider_instance = None
|
||||||
|
|
||||||
for filename in os.listdir(provider_dir):
|
for filename in os.listdir(provider_dir):
|
||||||
if filename.endswith('_provider.py') and not filename.startswith('base'):
|
if filename.endswith('_provider.py') and not filename.startswith('base'):
|
||||||
module_name = f"providers.{filename[:-3]}"
|
module_name = f"providers.{filename[:-3]}"
|
||||||
@ -203,7 +205,6 @@ class Scanner:
|
|||||||
attribute = getattr(module, attribute_name)
|
attribute = getattr(module, attribute_name)
|
||||||
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
|
if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
|
||||||
provider_class = attribute
|
provider_class = attribute
|
||||||
# FIXED: Pass the 'name' argument during initialization
|
|
||||||
provider = provider_class(name=attribute_name, session_config=self.config)
|
provider = provider_class(name=attribute_name, session_config=self.config)
|
||||||
provider_name = provider.get_name()
|
provider_name = provider.get_name()
|
||||||
|
|
||||||
@ -224,8 +225,13 @@ class Scanner:
|
|||||||
|
|
||||||
if is_available:
|
if is_available:
|
||||||
provider.set_stop_event(self.stop_event)
|
provider.set_stop_event(self.stop_event)
|
||||||
|
|
||||||
|
# Special handling for correlation provider
|
||||||
if isinstance(provider, CorrelationProvider):
|
if isinstance(provider, CorrelationProvider):
|
||||||
provider.set_graph_manager(self.graph)
|
provider.set_graph_manager(self.graph)
|
||||||
|
correlation_provider_instance = provider
|
||||||
|
print(f" ✓ Correlation provider configured with graph manager")
|
||||||
|
|
||||||
self.providers.append(provider)
|
self.providers.append(provider)
|
||||||
print(f" ✓ Added to scanner")
|
print(f" ✓ Added to scanner")
|
||||||
else:
|
else:
|
||||||
@ -240,6 +246,11 @@ class Scanner:
|
|||||||
print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
|
print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
|
||||||
print(f"Active providers: {[p.get_name() for p in self.providers]}")
|
print(f"Active providers: {[p.get_name() for p in self.providers]}")
|
||||||
print(f"Provider count: {len(self.providers)}")
|
print(f"Provider count: {len(self.providers)}")
|
||||||
|
|
||||||
|
# Verify correlation provider is properly configured
|
||||||
|
if correlation_provider_instance:
|
||||||
|
print(f"Correlation provider configured: {correlation_provider_instance.graph is not None}")
|
||||||
|
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
def _status_logger_thread(self):
|
def _status_logger_thread(self):
|
||||||
@ -617,16 +628,21 @@ class Scanner:
|
|||||||
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
|
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
|
||||||
"""
|
"""
|
||||||
PHASE 2: Run correlation analysis on all discovered nodes.
|
PHASE 2: Run correlation analysis on all discovered nodes.
|
||||||
This ensures correlations run after all other providers have completed.
|
Enhanced with better error handling and progress tracking.
|
||||||
"""
|
"""
|
||||||
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
|
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
|
||||||
if not correlation_provider:
|
if not correlation_provider:
|
||||||
print("No correlation provider found - skipping correlation phase")
|
print("No correlation provider found - skipping correlation phase")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Ensure correlation provider has access to current graph state
|
||||||
|
correlation_provider.set_graph_manager(self.graph)
|
||||||
|
print(f"Correlation provider configured with graph containing {self.graph.get_node_count()} nodes")
|
||||||
|
|
||||||
# Get all nodes from the graph for correlation analysis
|
# Get all nodes from the graph for correlation analysis
|
||||||
all_nodes = list(self.graph.graph.nodes())
|
all_nodes = list(self.graph.graph.nodes())
|
||||||
correlation_tasks = []
|
correlation_tasks = []
|
||||||
|
correlation_tasks_enqueued = 0
|
||||||
|
|
||||||
print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
|
print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
|
||||||
|
|
||||||
@ -640,15 +656,22 @@ class Scanner:
|
|||||||
priority = self._get_priority('correlation')
|
priority = self._get_priority('correlation')
|
||||||
self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
|
self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
|
||||||
correlation_tasks.append(task_tuple)
|
correlation_tasks.append(task_tuple)
|
||||||
|
correlation_tasks_enqueued += 1
|
||||||
self.total_tasks_ever_enqueued += 1
|
self.total_tasks_ever_enqueued += 1
|
||||||
|
|
||||||
print(f"Enqueued {len(correlation_tasks)} correlation tasks")
|
print(f"Enqueued {correlation_tasks_enqueued} new correlation tasks")
|
||||||
|
|
||||||
# Process correlation tasks
|
# Force session state update to reflect new task count
|
||||||
|
self._update_session_state()
|
||||||
|
|
||||||
|
# Process correlation tasks with enhanced tracking
|
||||||
consecutive_empty_iterations = 0
|
consecutive_empty_iterations = 0
|
||||||
max_empty_iterations = 20 # Shorter timeout for correlation phase
|
max_empty_iterations = 20
|
||||||
|
correlation_completed = 0
|
||||||
|
correlation_errors = 0
|
||||||
|
|
||||||
while correlation_tasks:
|
while correlation_tasks:
|
||||||
|
# Check if we should continue processing
|
||||||
queue_empty = self.task_queue.empty()
|
queue_empty = self.task_queue.empty()
|
||||||
with self.processing_lock:
|
with self.processing_lock:
|
||||||
no_active_processing = len(self.currently_processing) == 0
|
no_active_processing = len(self.currently_processing) == 0
|
||||||
@ -656,6 +679,7 @@ class Scanner:
|
|||||||
if queue_empty and no_active_processing:
|
if queue_empty and no_active_processing:
|
||||||
consecutive_empty_iterations += 1
|
consecutive_empty_iterations += 1
|
||||||
if consecutive_empty_iterations >= max_empty_iterations:
|
if consecutive_empty_iterations >= max_empty_iterations:
|
||||||
|
print(f"Correlation phase timeout - {len(correlation_tasks)} tasks remaining")
|
||||||
break
|
break
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
continue
|
continue
|
||||||
@ -695,26 +719,45 @@ class Scanner:
|
|||||||
self.current_indicator = target_item
|
self.current_indicator = target_item
|
||||||
self._update_session_state()
|
self._update_session_state()
|
||||||
|
|
||||||
# Process correlation task
|
# Process correlation task with enhanced error handling
|
||||||
|
try:
|
||||||
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
|
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
processed_tasks.add(task_tuple)
|
processed_tasks.add(task_tuple)
|
||||||
|
correlation_completed += 1
|
||||||
self.indicators_completed += 1
|
self.indicators_completed += 1
|
||||||
if task_tuple in correlation_tasks:
|
if task_tuple in correlation_tasks:
|
||||||
correlation_tasks.remove(task_tuple)
|
correlation_tasks.remove(task_tuple)
|
||||||
else:
|
else:
|
||||||
# For correlations, don't retry - just mark as completed
|
# For correlations, don't retry - just mark as completed
|
||||||
|
correlation_errors += 1
|
||||||
self.indicators_completed += 1
|
self.indicators_completed += 1
|
||||||
if task_tuple in correlation_tasks:
|
if task_tuple in correlation_tasks:
|
||||||
correlation_tasks.remove(task_tuple)
|
correlation_tasks.remove(task_tuple)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
correlation_errors += 1
|
||||||
|
self.indicators_completed += 1
|
||||||
|
if task_tuple in correlation_tasks:
|
||||||
|
correlation_tasks.remove(task_tuple)
|
||||||
|
self.logger.logger.warning(f"Correlation task failed for {target_item}: {e}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
with self.processing_lock:
|
with self.processing_lock:
|
||||||
processing_key = (provider_name, target_item)
|
processing_key = (provider_name, target_item)
|
||||||
self.currently_processing.discard(processing_key)
|
self.currently_processing.discard(processing_key)
|
||||||
|
|
||||||
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
|
# Periodic progress update during correlation phase
|
||||||
|
if correlation_completed % 10 == 0 and correlation_completed > 0:
|
||||||
|
remaining = len(correlation_tasks)
|
||||||
|
print(f"Correlation progress: {correlation_completed} completed, {remaining} remaining")
|
||||||
|
|
||||||
|
print(f"Correlation phase complete:")
|
||||||
|
print(f" - Successfully processed: {correlation_completed}")
|
||||||
|
print(f" - Errors encountered: {correlation_errors}")
|
||||||
|
print(f" - Tasks remaining: {len(correlation_tasks)}")
|
||||||
|
|
||||||
|
|
||||||
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
|
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
|
||||||
"""
|
"""
|
||||||
@ -1143,6 +1186,10 @@ class Scanner:
|
|||||||
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
|
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
|
||||||
|
|
||||||
def _calculate_progress(self) -> float:
|
def _calculate_progress(self) -> float:
|
||||||
|
"""
|
||||||
|
Enhanced progress calculation that properly accounts for correlation tasks
|
||||||
|
added during the correlation phase.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if self.total_tasks_ever_enqueued == 0:
|
if self.total_tasks_ever_enqueued == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
@ -1152,7 +1199,18 @@ class Scanner:
|
|||||||
with self.processing_lock:
|
with self.processing_lock:
|
||||||
active_tasks = len(self.currently_processing)
|
active_tasks = len(self.currently_processing)
|
||||||
|
|
||||||
# Adjust total to account for remaining work
|
# For correlation phase, be more conservative about progress calculation
|
||||||
|
if self.status == ScanStatus.FINALIZING:
|
||||||
|
# During correlation phase, show progress more conservatively
|
||||||
|
base_progress = (self.indicators_completed / max(self.total_tasks_ever_enqueued, 1)) * 100
|
||||||
|
|
||||||
|
# If we have active correlation tasks, cap progress at 95% until done
|
||||||
|
if queue_size > 0 or active_tasks > 0:
|
||||||
|
return min(95.0, base_progress)
|
||||||
|
else:
|
||||||
|
return min(100.0, base_progress)
|
||||||
|
|
||||||
|
# Normal phase progress calculation
|
||||||
adjusted_total = max(self.total_tasks_ever_enqueued,
|
adjusted_total = max(self.total_tasks_ever_enqueued,
|
||||||
self.indicators_completed + queue_size + active_tasks)
|
self.indicators_completed + queue_size + active_tasks)
|
||||||
|
|
||||||
|
|||||||
@ -78,40 +78,38 @@ class CorrelationProvider(BaseProvider):
|
|||||||
|
|
||||||
def _find_correlations(self, node_id: str) -> ProviderResult:
|
def _find_correlations(self, node_id: str) -> ProviderResult:
|
||||||
"""
|
"""
|
||||||
Find correlations for a given node.
|
Find correlations for a given node with enhanced filtering and error handling.
|
||||||
"""
|
"""
|
||||||
result = ProviderResult()
|
result = ProviderResult()
|
||||||
# FIXED: Ensure self.graph is not None before proceeding.
|
|
||||||
|
# Enhanced safety checks
|
||||||
if not self.graph or not self.graph.graph.has_node(node_id):
|
if not self.graph or not self.graph.graph.has_node(node_id):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
try:
|
||||||
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
|
node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
|
||||||
|
|
||||||
|
# Ensure attributes is a list (handle legacy data)
|
||||||
|
if not isinstance(node_attributes, list):
|
||||||
|
return result
|
||||||
|
|
||||||
|
correlations_found = 0
|
||||||
|
|
||||||
for attr in node_attributes:
|
for attr in node_attributes:
|
||||||
attr_name = attr.get('name')
|
if not isinstance(attr, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
attr_name = attr.get('name', '')
|
||||||
attr_value = attr.get('value')
|
attr_value = attr.get('value')
|
||||||
attr_provider = attr.get('provider', 'unknown')
|
attr_provider = attr.get('provider', 'unknown')
|
||||||
|
|
||||||
should_exclude = (
|
# Enhanced filtering logic
|
||||||
any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS) or
|
should_exclude = self._should_exclude_attribute(attr_name, attr_value)
|
||||||
not isinstance(attr_value, (str, int, float, bool)) or
|
|
||||||
attr_value is None or
|
|
||||||
isinstance(attr_value, bool) or
|
|
||||||
(isinstance(attr_value, str) and (
|
|
||||||
len(attr_value) < 4 or
|
|
||||||
self.date_pattern.match(attr_value) or
|
|
||||||
attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']
|
|
||||||
)) or
|
|
||||||
(isinstance(attr_value, (int, float)) and (
|
|
||||||
attr_value == 0 or
|
|
||||||
attr_value == 1 or
|
|
||||||
abs(attr_value) > 1000000
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_exclude:
|
if should_exclude:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Build correlation index
|
||||||
if attr_value not in self.correlation_index:
|
if attr_value not in self.correlation_index:
|
||||||
self.correlation_index[attr_value] = {
|
self.correlation_index[attr_value] = {
|
||||||
'nodes': set(),
|
'nodes': set(),
|
||||||
@ -127,41 +125,119 @@ class CorrelationProvider(BaseProvider):
|
|||||||
'path': f"{attr_provider}_{attr_name}"
|
'path': f"{attr_provider}_{attr_name}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Avoid duplicate sources
|
||||||
existing_sources = [s for s in self.correlation_index[attr_value]['sources']
|
existing_sources = [s for s in self.correlation_index[attr_value]['sources']
|
||||||
if s['node_id'] == node_id and s['path'] == source_info['path']]
|
if s['node_id'] == node_id and s['path'] == source_info['path']]
|
||||||
if not existing_sources:
|
if not existing_sources:
|
||||||
self.correlation_index[attr_value]['sources'].append(source_info)
|
self.correlation_index[attr_value]['sources'].append(source_info)
|
||||||
|
|
||||||
|
# Create correlation if we have multiple nodes with this value
|
||||||
if len(self.correlation_index[attr_value]['nodes']) > 1:
|
if len(self.correlation_index[attr_value]['nodes']) > 1:
|
||||||
self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
|
self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
|
||||||
|
correlations_found += 1
|
||||||
|
|
||||||
|
# Log correlation results
|
||||||
|
if correlations_found > 0:
|
||||||
|
self.logger.logger.info(f"Found {correlations_found} correlations for node {node_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.logger.error(f"Error finding correlations for {node_id}: {e}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _should_exclude_attribute(self, attr_name: str, attr_value: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Enhanced logic to determine if an attribute should be excluded from correlation.
|
||||||
|
"""
|
||||||
|
# Check against excluded keys (exact match or substring)
|
||||||
|
if any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Value type filtering
|
||||||
|
if not isinstance(attr_value, (str, int, float, bool)) or attr_value is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Boolean values are not useful for correlation
|
||||||
|
if isinstance(attr_value, bool):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# String value filtering
|
||||||
|
if isinstance(attr_value, str):
|
||||||
|
# Date/timestamp strings
|
||||||
|
if self.date_pattern.match(attr_value):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Common non-useful values
|
||||||
|
if attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Very long strings that are likely unique (> 100 chars)
|
||||||
|
if len(attr_value) > 100:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Numeric value filtering
|
||||||
|
if isinstance(attr_value, (int, float)):
|
||||||
|
# Very common values
|
||||||
|
if attr_value in [0, 1]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Very large numbers (likely timestamps or unique IDs)
|
||||||
|
if abs(attr_value) > 1000000:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
|
def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
|
||||||
"""
|
"""
|
||||||
Create correlation relationships and add them to the provider result.
|
Create correlation relationships with enhanced deduplication and validation.
|
||||||
"""
|
"""
|
||||||
correlation_node_id = f"corr_{hash(str(value)) & 0x7FFFFFFF}"
|
correlation_node_id = f"corr_{hash(str(value)) & 0x7FFFFFFF}"
|
||||||
nodes = correlation_data['nodes']
|
nodes = correlation_data['nodes']
|
||||||
sources = correlation_data['sources']
|
sources = correlation_data['sources']
|
||||||
|
|
||||||
|
# Only create correlations if we have meaningful nodes (more than 1)
|
||||||
|
if len(nodes) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Limit correlation size to prevent overly large correlation objects
|
||||||
|
MAX_CORRELATION_SIZE = 50
|
||||||
|
if len(nodes) > MAX_CORRELATION_SIZE:
|
||||||
|
# Sample the nodes to keep correlation manageable
|
||||||
|
import random
|
||||||
|
sampled_nodes = random.sample(list(nodes), MAX_CORRELATION_SIZE)
|
||||||
|
nodes = set(sampled_nodes)
|
||||||
|
# Filter sources to match sampled nodes
|
||||||
|
sources = [s for s in sources if s['node_id'] in nodes]
|
||||||
|
|
||||||
# Add the correlation node as an attribute to the result
|
# Add the correlation node as an attribute to the result
|
||||||
result.add_attribute(
|
result.add_attribute(
|
||||||
target_node=correlation_node_id,
|
target_node=correlation_node_id,
|
||||||
name="correlation_value",
|
name="correlation_value",
|
||||||
value=value,
|
value=value,
|
||||||
attr_type=str(type(value)),
|
attr_type=str(type(value).__name__),
|
||||||
provider=self.name,
|
provider=self.name,
|
||||||
confidence=0.9,
|
confidence=0.9,
|
||||||
metadata={
|
metadata={
|
||||||
'correlated_nodes': list(nodes),
|
'correlated_nodes': list(nodes),
|
||||||
'sources': sources,
|
'sources': sources,
|
||||||
|
'correlation_size': len(nodes),
|
||||||
|
'value_type': type(value).__name__
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create relationships with source validation
|
||||||
|
created_relationships = set()
|
||||||
|
|
||||||
for source in sources:
|
for source in sources:
|
||||||
node_id = source['node_id']
|
node_id = source['node_id']
|
||||||
provider = source['provider']
|
provider = source['provider']
|
||||||
attribute = source['attribute']
|
attribute = source['attribute']
|
||||||
|
|
||||||
|
# Skip if we've already created this relationship
|
||||||
|
relationship_key = (node_id, correlation_node_id)
|
||||||
|
if relationship_key in created_relationships:
|
||||||
|
continue
|
||||||
|
|
||||||
relationship_label = f"corr_{provider}_{attribute}"
|
relationship_label = f"corr_{provider}_{attribute}"
|
||||||
|
|
||||||
# Add the relationship to the result
|
# Add the relationship to the result
|
||||||
@ -174,6 +250,9 @@ class CorrelationProvider(BaseProvider):
|
|||||||
raw_data={
|
raw_data={
|
||||||
'correlation_value': value,
|
'correlation_value': value,
|
||||||
'original_attribute': attribute,
|
'original_attribute': attribute,
|
||||||
'correlation_type': 'attribute_matching'
|
'correlation_type': 'attribute_matching',
|
||||||
|
'correlation_size': len(nodes)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
created_relationships.add(relationship_key)
|
||||||
Loading…
x
Reference in New Issue
Block a user