Correlation-phase hardening for the scanner and the correlation provider.

core/scanner.py
  * Provider initialisation remembers the CorrelationProvider instance and
    prints whether it ended up with a graph attached.
  * _run_correlation_phase re-attaches the graph before enqueueing, counts
    enqueued / completed / failed correlation tasks, logs (instead of
    propagating) an exception raised by a single correlation task, and
    prints a summary when the phase ends.
  * _calculate_progress caps the reported progress at 95% while the scan
    status is FINALIZING and work is still queued or active.

providers/correlation_provider.py
  * _find_correlations skips attribute containers that are not lists and
    attributes that are not dicts, delegates value filtering to the new
    _should_exclude_attribute helper, and logs failures instead of raising.
  * _create_correlation_relationships limits a correlation to 50 nodes
    (a deterministic, sorted subset) and emits at most one relationship per
    node in a single call.

diff --git a/core/scanner.py b/core/scanner.py
index 1bf75cc..5ebc1a4 100644
--- a/core/scanner.py
+++ b/core/scanner.py
@@ -193,6 +193,8 @@ class Scanner:
 
         print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
 
+        correlation_provider_instance = None
+
         for filename in os.listdir(provider_dir):
             if filename.endswith('_provider.py') and not filename.startswith('base'):
                 module_name = f"providers.{filename[:-3]}"
@@ -203,7 +205,6 @@ class Scanner:
                         attribute = getattr(module, attribute_name)
                         if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
                             provider_class = attribute
-                            # FIXED: Pass the 'name' argument during initialization
                             provider = provider_class(name=attribute_name, session_config=self.config)
                             provider_name = provider.get_name()
 
@@ -224,8 +225,13 @@ class Scanner:
 
                             if is_available:
                                 provider.set_stop_event(self.stop_event)
+
+                                # Special handling for correlation provider
                                 if isinstance(provider, CorrelationProvider):
                                     provider.set_graph_manager(self.graph)
+                                    correlation_provider_instance = provider
+                                    print(f" ✓ Correlation provider configured with graph manager")
+
                                 self.providers.append(provider)
                                 print(f" ✓ Added to scanner")
                             else:
@@ -240,6 +246,11 @@ class Scanner:
         print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
         print(f"Active providers: {[p.get_name() for p in self.providers]}")
         print(f"Provider count: {len(self.providers)}")
+
+        # Verify correlation provider is properly configured
+        if correlation_provider_instance:
+            print(f"Correlation provider configured: {correlation_provider_instance.graph is not None}")
+
         print("=" * 50)
 
     def _status_logger_thread(self):
@@ -617,16 +628,21 @@ class Scanner:
     def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
         """
         PHASE 2: Run correlation analysis on all discovered nodes.
-        This ensures correlations run after all other providers have completed.
+        Enhanced with better error handling and progress tracking.
         """
         correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
         if not correlation_provider:
             print("No correlation provider found - skipping correlation phase")
             return
 
+        # Ensure correlation provider has access to current graph state
+        correlation_provider.set_graph_manager(self.graph)
+        print(f"Correlation provider configured with graph containing {self.graph.get_node_count()} nodes")
+
         # Get all nodes from the graph for correlation analysis
         all_nodes = list(self.graph.graph.nodes())
         correlation_tasks = []
+        correlation_tasks_enqueued = 0
 
         print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
 
@@ -640,15 +656,22 @@ class Scanner:
                 priority = self._get_priority('correlation')
                 self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
                 correlation_tasks.append(task_tuple)
+                correlation_tasks_enqueued += 1
                 self.total_tasks_ever_enqueued += 1
 
-        print(f"Enqueued {len(correlation_tasks)} correlation tasks")
+        print(f"Enqueued {correlation_tasks_enqueued} new correlation tasks")
 
-        # Process correlation tasks
+        # Force session state update to reflect new task count
+        self._update_session_state()
+
+        # Process correlation tasks with enhanced tracking
         consecutive_empty_iterations = 0
-        max_empty_iterations = 20 # Shorter timeout for correlation phase
+        max_empty_iterations = 20
+        correlation_completed = 0
+        correlation_errors = 0
 
         while correlation_tasks:
+            # Check if we should continue processing
             queue_empty = self.task_queue.empty()
             with self.processing_lock:
                 no_active_processing = len(self.currently_processing) == 0
@@ -656,6 +679,7 @@ class Scanner:
             if queue_empty and no_active_processing:
                 consecutive_empty_iterations += 1
                 if consecutive_empty_iterations >= max_empty_iterations:
+                    print(f"Correlation phase timeout - {len(correlation_tasks)} tasks remaining")
                     break
                 time.sleep(0.1)
                 continue
@@ -695,26 +719,45 @@ class Scanner:
                 self.current_indicator = target_item
                 self._update_session_state()
 
-                # Process correlation task
-                new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
-
-                if success:
-                    processed_tasks.add(task_tuple)
-                    self.indicators_completed += 1
-                    if task_tuple in correlation_tasks:
-                        correlation_tasks.remove(task_tuple)
-                else:
-                    # For correlations, don't retry - just mark as completed
-                    self.indicators_completed += 1
-                    if task_tuple in correlation_tasks:
-                        correlation_tasks.remove(task_tuple)
+                # Process correlation task with enhanced error handling
+                try:
+                    new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
+                    if success:
+                        processed_tasks.add(task_tuple)
+                        correlation_completed += 1
+                        self.indicators_completed += 1
+                        if task_tuple in correlation_tasks:
+                            correlation_tasks.remove(task_tuple)
+                    else:
+                        # For correlations, don't retry - just mark as completed
+                        correlation_errors += 1
+                        self.indicators_completed += 1
+                        if task_tuple in correlation_tasks:
+                            correlation_tasks.remove(task_tuple)
+
+                except Exception as e:
+                    correlation_errors += 1
+                    self.indicators_completed += 1
+                    if task_tuple in correlation_tasks:
+                        correlation_tasks.remove(task_tuple)
+                    self.logger.logger.warning(f"Correlation task failed for {target_item}: {e}")
+
 
             finally:
                 with self.processing_lock:
                     processing_key = (provider_name, target_item)
                     self.currently_processing.discard(processing_key)
 
-        print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
+            # Periodic progress update during correlation phase
+            if correlation_completed % 10 == 0 and correlation_completed > 0:
+                remaining = len(correlation_tasks)
+                print(f"Correlation progress: {correlation_completed} completed, {remaining} remaining")
+
+        print(f"Correlation phase complete:")
+        print(f" - Successfully processed: {correlation_completed}")
+        print(f" - Errors encountered: {correlation_errors}")
+        print(f" - Tasks remaining: {len(correlation_tasks)}")
+
 
     def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
         """
@@ -1143,6 +1186,10 @@ class Scanner:
             self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
 
     def _calculate_progress(self) -> float:
+        """
+        Enhanced progress calculation that properly accounts for correlation tasks
+        added during the correlation phase.
+        """
         try:
             if self.total_tasks_ever_enqueued == 0:
                 return 0.0
@@ -1152,7 +1199,18 @@ class Scanner:
             with self.processing_lock:
                 active_tasks = len(self.currently_processing)
 
-            # Adjust total to account for remaining work
+            # For correlation phase, be more conservative about progress calculation
+            if self.status == ScanStatus.FINALIZING:
+                # During correlation phase, show progress more conservatively
+                base_progress = (self.indicators_completed / max(self.total_tasks_ever_enqueued, 1)) * 100
+
+                # If we have active correlation tasks, cap progress at 95% until done
+                if queue_size > 0 or active_tasks > 0:
+                    return min(95.0, base_progress)
+                else:
+                    return min(100.0, base_progress)
+
+            # Normal phase progress calculation
             adjusted_total = max(self.total_tasks_ever_enqueued,
                                  self.indicators_completed + queue_size + active_tasks)
 
diff --git a/providers/correlation_provider.py b/providers/correlation_provider.py
index 9ae6eeb..6a91fb2 100644
--- a/providers/correlation_provider.py
+++ b/providers/correlation_provider.py
@@ -78,90 +78,166 @@ class CorrelationProvider(BaseProvider):
     def _find_correlations(self, node_id: str) -> ProviderResult:
         """
-        Find correlations for a given node.
+        Find correlations for a given node with enhanced filtering and error handling.
         """
         result = ProviderResult()
-        # FIXED: Ensure self.graph is not None before proceeding.
+
+        # Enhanced safety checks
         if not self.graph or not self.graph.graph.has_node(node_id):
             return result
 
-        node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
+        try:
+            node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
+
+            # Ensure attributes is a list (handle legacy data)
+            if not isinstance(node_attributes, list):
+                return result
+
+            correlations_found = 0
+
+            for attr in node_attributes:
+                if not isinstance(attr, dict):
+                    continue
+
+                attr_name = attr.get('name') or ''  # None would break the excluded-key substring test
+                attr_value = attr.get('value')
+                attr_provider = attr.get('provider', 'unknown')
 
-        for attr in node_attributes:
-            attr_name = attr.get('name')
-            attr_value = attr.get('value')
-            attr_provider = attr.get('provider', 'unknown')
+                # Enhanced filtering logic
+                should_exclude = self._should_exclude_attribute(attr_name, attr_value)
+
+                if should_exclude:
+                    continue
 
-            should_exclude = (
-                any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS) or
-                not isinstance(attr_value, (str, int, float, bool)) or
-                attr_value is None or
-                isinstance(attr_value, bool) or
-                (isinstance(attr_value, str) and (
-                    len(attr_value) < 4 or
-                    self.date_pattern.match(attr_value) or
-                    attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']
-                )) or
-                (isinstance(attr_value, (int, float)) and (
-                    attr_value == 0 or
-                    attr_value == 1 or
-                    abs(attr_value) > 1000000
-                ))
-            )
+                # Build correlation index
+                if attr_value not in self.correlation_index:
+                    self.correlation_index[attr_value] = {
+                        'nodes': set(),
+                        'sources': []
+                    }
 
-            if should_exclude:
-                continue
+                self.correlation_index[attr_value]['nodes'].add(node_id)
 
-            if attr_value not in self.correlation_index:
-                self.correlation_index[attr_value] = {
-                    'nodes': set(),
-                    'sources': []
+                source_info = {
+                    'node_id': node_id,
+                    'provider': attr_provider,
+                    'attribute': attr_name,
+                    'path': f"{attr_provider}_{attr_name}"
                 }
 
-            self.correlation_index[attr_value]['nodes'].add(node_id)
+                # Avoid duplicate sources
+                existing_sources = [s for s in self.correlation_index[attr_value]['sources']
+                                    if s['node_id'] == node_id and s['path'] == source_info['path']]
+                if not existing_sources:
+                    self.correlation_index[attr_value]['sources'].append(source_info)
 
-            source_info = {
-                'node_id': node_id,
-                'provider': attr_provider,
-                'attribute': attr_name,
-                'path': f"{attr_provider}_{attr_name}"
-            }
-
-            existing_sources = [s for s in self.correlation_index[attr_value]['sources']
-                                if s['node_id'] == node_id and s['path'] == source_info['path']]
-            if not existing_sources:
-                self.correlation_index[attr_value]['sources'].append(source_info)
-
-            if len(self.correlation_index[attr_value]['nodes']) > 1:
-                self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
+                # Create correlation if we have multiple nodes with this value
+                if len(self.correlation_index[attr_value]['nodes']) > 1:
+                    self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
+                    correlations_found += 1
+
+            # Log correlation results
+            if correlations_found > 0:
+                self.logger.logger.info(f"Found {correlations_found} correlations for node {node_id}")
+
+        except Exception as e:
+            self.logger.logger.error(f"Error finding correlations for {node_id}: {e}")
+
 
         return result
-
+
+    def _should_exclude_attribute(self, attr_name: str, attr_value: Any) -> bool:
+        """
+        Enhanced logic to determine if an attribute should be excluded from correlation.
+        """
+        # Check against excluded keys (exact match or substring)
+        if any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS):
+            return True
+
+        # Value type filtering
+        if not isinstance(attr_value, (str, int, float, bool)) or attr_value is None:
+            return True
+
+        # Boolean values are not useful for correlation
+        if isinstance(attr_value, bool):
+            return True
+
+        # String value filtering
+        if isinstance(attr_value, str):
+            # Date/timestamp strings
+            if self.date_pattern.match(attr_value):
+                return True
+
+            # Common non-useful values
+            if attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']:
+                return True
+
+            # Very long strings that are likely unique (> 100 chars)
+            if len(attr_value) > 100:
+                return True
+
+        # Numeric value filtering
+        if isinstance(attr_value, (int, float)):
+            # Very common values
+            if attr_value in [0, 1]:
+                return True
+
+            # Very large numbers (likely timestamps or unique IDs)
+            if abs(attr_value) > 1000000:
+                return True
+
+        return False
+
     def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
         """
-        Create correlation relationships and add them to the provider result.
+        Create correlation relationships with enhanced deduplication and validation.
         """
         correlation_node_id = f"corr_{hash(str(value)) & 0x7FFFFFFF}"
         nodes = correlation_data['nodes']
         sources = correlation_data['sources']
+
+        # Only create correlations if we have meaningful nodes (more than 1)
+        if len(nodes) < 2:
+            return
+
+        # Limit correlation size to prevent overly large correlation objects
+        MAX_CORRELATION_SIZE = 50
+        if len(nodes) > MAX_CORRELATION_SIZE:
+            # Keep a deterministic subset so repeated runs yield the same
+            # correlation object for the same value
+            kept_nodes = sorted(nodes, key=str)[:MAX_CORRELATION_SIZE]
+            nodes = set(kept_nodes)
+            # Filter sources to match kept nodes
+            sources = [s for s in sources if s['node_id'] in nodes]
 
         # Add the correlation node as an attribute to the result
         result.add_attribute(
             target_node=correlation_node_id,
             name="correlation_value",
             value=value,
-            attr_type=str(type(value)),
+            attr_type=type(value).__name__,
             provider=self.name,
             confidence=0.9,
             metadata={
                 'correlated_nodes': list(nodes),
                 'sources': sources,
+                'correlation_size': len(nodes),
+                'value_type': type(value).__name__
             }
         )
 
+        # Create relationships with source validation
+        created_relationships = set()
+
         for source in sources:
             node_id = source['node_id']
             provider = source['provider']
             attribute = source['attribute']
+
+            # Skip if we've already created this relationship
+            relationship_key = (node_id, correlation_node_id)
+            if relationship_key in created_relationships:
+                continue
+
             relationship_label = f"corr_{provider}_{attribute}"
 
             # Add the relationship to the result
@@ -174,6 +250,9 @@ class CorrelationProvider(BaseProvider):
                 raw_data={
                     'correlation_value': value,
                     'original_attribute': attribute,
-                    'correlation_type': 'attribute_matching'
+                    'correlation_type': 'attribute_matching',
+                    'correlation_size': len(nodes)
                 }
-            )
\ No newline at end of file
+            )
+
+            created_relationships.add(relationship_key)
\ No newline at end of file