refinements for correlations running logic
This commit is contained in:
		
							parent
							
								
									4a82c279ef
								
							
						
					
					
						commit
						602739246f
					
				@ -193,6 +193,8 @@ class Scanner:
 | 
			
		||||
        
 | 
			
		||||
        print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
 | 
			
		||||
        
 | 
			
		||||
        correlation_provider_instance = None
 | 
			
		||||
        
 | 
			
		||||
        for filename in os.listdir(provider_dir):
 | 
			
		||||
            if filename.endswith('_provider.py') and not filename.startswith('base'):
 | 
			
		||||
                module_name = f"providers.{filename[:-3]}"
 | 
			
		||||
@ -203,7 +205,6 @@ class Scanner:
 | 
			
		||||
                        attribute = getattr(module, attribute_name)
 | 
			
		||||
                        if isinstance(attribute, type) and issubclass(attribute, BaseProvider) and attribute is not BaseProvider:
 | 
			
		||||
                            provider_class = attribute
 | 
			
		||||
                            # FIXED: Pass the 'name' argument during initialization
 | 
			
		||||
                            provider = provider_class(name=attribute_name, session_config=self.config)
 | 
			
		||||
                            provider_name = provider.get_name()
 | 
			
		||||
 | 
			
		||||
@ -224,8 +225,13 @@ class Scanner:
 | 
			
		||||
                                
 | 
			
		||||
                                if is_available:
 | 
			
		||||
                                    provider.set_stop_event(self.stop_event)
 | 
			
		||||
                                    
 | 
			
		||||
                                    # Special handling for correlation provider
 | 
			
		||||
                                    if isinstance(provider, CorrelationProvider):
 | 
			
		||||
                                        provider.set_graph_manager(self.graph)
 | 
			
		||||
                                        correlation_provider_instance = provider
 | 
			
		||||
                                        print(f"    ✓ Correlation provider configured with graph manager")
 | 
			
		||||
                                    
 | 
			
		||||
                                    self.providers.append(provider)
 | 
			
		||||
                                    print(f"    ✓ Added to scanner")
 | 
			
		||||
                                else:
 | 
			
		||||
@ -240,6 +246,11 @@ class Scanner:
 | 
			
		||||
        print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
 | 
			
		||||
        print(f"Active providers: {[p.get_name() for p in self.providers]}")
 | 
			
		||||
        print(f"Provider count: {len(self.providers)}")
 | 
			
		||||
        
 | 
			
		||||
        # Verify correlation provider is properly configured
 | 
			
		||||
        if correlation_provider_instance:
 | 
			
		||||
            print(f"Correlation provider configured: {correlation_provider_instance.graph is not None}")
 | 
			
		||||
        
 | 
			
		||||
        print("=" * 50)
 | 
			
		||||
 | 
			
		||||
    def _status_logger_thread(self):
 | 
			
		||||
@ -617,16 +628,21 @@ class Scanner:
 | 
			
		||||
    def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        PHASE 2: Run correlation analysis on all discovered nodes.
 | 
			
		||||
        This ensures correlations run after all other providers have completed.
 | 
			
		||||
        Enhanced with better error handling and progress tracking.
 | 
			
		||||
        """
 | 
			
		||||
        correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
 | 
			
		||||
        if not correlation_provider:
 | 
			
		||||
            print("No correlation provider found - skipping correlation phase")
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # Ensure correlation provider has access to current graph state
 | 
			
		||||
        correlation_provider.set_graph_manager(self.graph)
 | 
			
		||||
        print(f"Correlation provider configured with graph containing {self.graph.get_node_count()} nodes")
 | 
			
		||||
 | 
			
		||||
        # Get all nodes from the graph for correlation analysis
 | 
			
		||||
        all_nodes = list(self.graph.graph.nodes())
 | 
			
		||||
        correlation_tasks = []
 | 
			
		||||
        correlation_tasks_enqueued = 0
 | 
			
		||||
        
 | 
			
		||||
        print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
 | 
			
		||||
        
 | 
			
		||||
@ -640,15 +656,22 @@ class Scanner:
 | 
			
		||||
                priority = self._get_priority('correlation')
 | 
			
		||||
                self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
 | 
			
		||||
                correlation_tasks.append(task_tuple)
 | 
			
		||||
                correlation_tasks_enqueued += 1
 | 
			
		||||
                self.total_tasks_ever_enqueued += 1
 | 
			
		||||
        
 | 
			
		||||
        print(f"Enqueued {len(correlation_tasks)} correlation tasks")
 | 
			
		||||
        print(f"Enqueued {correlation_tasks_enqueued} new correlation tasks")
 | 
			
		||||
        
 | 
			
		||||
        # Process correlation tasks
 | 
			
		||||
        # Force session state update to reflect new task count
 | 
			
		||||
        self._update_session_state()
 | 
			
		||||
        
 | 
			
		||||
        # Process correlation tasks with enhanced tracking
 | 
			
		||||
        consecutive_empty_iterations = 0
 | 
			
		||||
        max_empty_iterations = 20  # Shorter timeout for correlation phase
 | 
			
		||||
        max_empty_iterations = 20
 | 
			
		||||
        correlation_completed = 0
 | 
			
		||||
        correlation_errors = 0
 | 
			
		||||
        
 | 
			
		||||
        while correlation_tasks:
 | 
			
		||||
            # Check if we should continue processing
 | 
			
		||||
            queue_empty = self.task_queue.empty()
 | 
			
		||||
            with self.processing_lock:
 | 
			
		||||
                no_active_processing = len(self.currently_processing) == 0
 | 
			
		||||
@ -656,6 +679,7 @@ class Scanner:
 | 
			
		||||
            if queue_empty and no_active_processing:
 | 
			
		||||
                consecutive_empty_iterations += 1
 | 
			
		||||
                if consecutive_empty_iterations >= max_empty_iterations:
 | 
			
		||||
                    print(f"Correlation phase timeout - {len(correlation_tasks)} tasks remaining")
 | 
			
		||||
                    break
 | 
			
		||||
                time.sleep(0.1)
 | 
			
		||||
                continue
 | 
			
		||||
@ -695,26 +719,45 @@ class Scanner:
 | 
			
		||||
                self.current_indicator = target_item
 | 
			
		||||
                self._update_session_state()
 | 
			
		||||
                
 | 
			
		||||
                # Process correlation task
 | 
			
		||||
                new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
 | 
			
		||||
                # Process correlation task with enhanced error handling
 | 
			
		||||
                try:
 | 
			
		||||
                    new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
 | 
			
		||||
                    
 | 
			
		||||
                if success:
 | 
			
		||||
                    processed_tasks.add(task_tuple)
 | 
			
		||||
                    self.indicators_completed += 1
 | 
			
		||||
                    if task_tuple in correlation_tasks:
 | 
			
		||||
                        correlation_tasks.remove(task_tuple)
 | 
			
		||||
                else:
 | 
			
		||||
                    # For correlations, don't retry - just mark as completed
 | 
			
		||||
                    if success:
 | 
			
		||||
                        processed_tasks.add(task_tuple)
 | 
			
		||||
                        correlation_completed += 1
 | 
			
		||||
                        self.indicators_completed += 1
 | 
			
		||||
                        if task_tuple in correlation_tasks:
 | 
			
		||||
                            correlation_tasks.remove(task_tuple)
 | 
			
		||||
                    else:
 | 
			
		||||
                        # For correlations, don't retry - just mark as completed
 | 
			
		||||
                        correlation_errors += 1
 | 
			
		||||
                        self.indicators_completed += 1
 | 
			
		||||
                        if task_tuple in correlation_tasks:
 | 
			
		||||
                            correlation_tasks.remove(task_tuple)
 | 
			
		||||
                            
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    correlation_errors += 1
 | 
			
		||||
                    self.indicators_completed += 1
 | 
			
		||||
                    if task_tuple in correlation_tasks:
 | 
			
		||||
                        correlation_tasks.remove(task_tuple)
 | 
			
		||||
                    self.logger.logger.warning(f"Correlation task failed for {target_item}: {e}")
 | 
			
		||||
                        
 | 
			
		||||
            finally:
 | 
			
		||||
                with self.processing_lock:
 | 
			
		||||
                    processing_key = (provider_name, target_item)
 | 
			
		||||
                    self.currently_processing.discard(processing_key)
 | 
			
		||||
 | 
			
		||||
        print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
 | 
			
		||||
            # Periodic progress update during correlation phase
 | 
			
		||||
            if correlation_completed % 10 == 0 and correlation_completed > 0:
 | 
			
		||||
                remaining = len(correlation_tasks)
 | 
			
		||||
                print(f"Correlation progress: {correlation_completed} completed, {remaining} remaining")
 | 
			
		||||
 | 
			
		||||
        print(f"Correlation phase complete:")
 | 
			
		||||
        print(f"  - Successfully processed: {correlation_completed}")
 | 
			
		||||
        print(f"  - Errors encountered: {correlation_errors}")
 | 
			
		||||
        print(f"  - Tasks remaining: {len(correlation_tasks)}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
 | 
			
		||||
        """
 | 
			
		||||
@ -1143,6 +1186,10 @@ class Scanner:
 | 
			
		||||
        self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
 | 
			
		||||
 | 
			
		||||
    def _calculate_progress(self) -> float:
 | 
			
		||||
        """
 | 
			
		||||
        Enhanced progress calculation that properly accounts for correlation tasks
 | 
			
		||||
        added during the correlation phase.
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            if self.total_tasks_ever_enqueued == 0:
 | 
			
		||||
                return 0.0
 | 
			
		||||
@ -1152,7 +1199,18 @@ class Scanner:
 | 
			
		||||
            with self.processing_lock:
 | 
			
		||||
                active_tasks = len(self.currently_processing)
 | 
			
		||||
            
 | 
			
		||||
            # Adjust total to account for remaining work
 | 
			
		||||
            # For correlation phase, be more conservative about progress calculation
 | 
			
		||||
            if self.status == ScanStatus.FINALIZING:
 | 
			
		||||
                # During correlation phase, show progress more conservatively
 | 
			
		||||
                base_progress = (self.indicators_completed / max(self.total_tasks_ever_enqueued, 1)) * 100
 | 
			
		||||
                
 | 
			
		||||
                # If we have active correlation tasks, cap progress at 95% until done
 | 
			
		||||
                if queue_size > 0 or active_tasks > 0:
 | 
			
		||||
                    return min(95.0, base_progress)
 | 
			
		||||
                else:
 | 
			
		||||
                    return min(100.0, base_progress)
 | 
			
		||||
            
 | 
			
		||||
            # Normal phase progress calculation
 | 
			
		||||
            adjusted_total = max(self.total_tasks_ever_enqueued, 
 | 
			
		||||
                            self.indicators_completed + queue_size + active_tasks)
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
@ -78,90 +78,166 @@ class CorrelationProvider(BaseProvider):
 | 
			
		||||
 | 
			
		||||
    def _find_correlations(self, node_id: str) -> ProviderResult:
 | 
			
		||||
        """
 | 
			
		||||
        Find correlations for a given node.
 | 
			
		||||
        Find correlations for a given node with enhanced filtering and error handling.
 | 
			
		||||
        """
 | 
			
		||||
        result = ProviderResult()
 | 
			
		||||
        # FIXED: Ensure self.graph is not None before proceeding.
 | 
			
		||||
        
 | 
			
		||||
        # Enhanced safety checks
 | 
			
		||||
        if not self.graph or not self.graph.graph.has_node(node_id):
 | 
			
		||||
            return result
 | 
			
		||||
 | 
			
		||||
        node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
 | 
			
		||||
        try:
 | 
			
		||||
            node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])
 | 
			
		||||
            
 | 
			
		||||
        for attr in node_attributes:
 | 
			
		||||
            attr_name = attr.get('name')
 | 
			
		||||
            attr_value = attr.get('value')
 | 
			
		||||
            attr_provider = attr.get('provider', 'unknown')
 | 
			
		||||
            # Ensure attributes is a list (handle legacy data)
 | 
			
		||||
            if not isinstance(node_attributes, list):
 | 
			
		||||
                return result
 | 
			
		||||
                
 | 
			
		||||
            should_exclude = (
 | 
			
		||||
                any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS) or
 | 
			
		||||
                not isinstance(attr_value, (str, int, float, bool)) or
 | 
			
		||||
                attr_value is None or
 | 
			
		||||
                isinstance(attr_value, bool) or
 | 
			
		||||
                (isinstance(attr_value, str) and (
 | 
			
		||||
                    len(attr_value) < 4 or
 | 
			
		||||
                    self.date_pattern.match(attr_value) or
 | 
			
		||||
                    attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']
 | 
			
		||||
                )) or
 | 
			
		||||
                (isinstance(attr_value, (int, float)) and (
 | 
			
		||||
                    attr_value == 0 or
 | 
			
		||||
                    attr_value == 1 or
 | 
			
		||||
                    abs(attr_value) > 1000000
 | 
			
		||||
                ))
 | 
			
		||||
            )
 | 
			
		||||
            correlations_found = 0
 | 
			
		||||
            
 | 
			
		||||
            if should_exclude:
 | 
			
		||||
                continue
 | 
			
		||||
            for attr in node_attributes:
 | 
			
		||||
                if not isinstance(attr, dict):
 | 
			
		||||
                    continue
 | 
			
		||||
                    
 | 
			
		||||
            if attr_value not in self.correlation_index:
 | 
			
		||||
                self.correlation_index[attr_value] = {
 | 
			
		||||
                    'nodes': set(),
 | 
			
		||||
                    'sources': []
 | 
			
		||||
                attr_name = attr.get('name', '')
 | 
			
		||||
                attr_value = attr.get('value')
 | 
			
		||||
                attr_provider = attr.get('provider', 'unknown')
 | 
			
		||||
 | 
			
		||||
                # Enhanced filtering logic
 | 
			
		||||
                should_exclude = self._should_exclude_attribute(attr_name, attr_value)
 | 
			
		||||
                
 | 
			
		||||
                if should_exclude:
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                # Build correlation index
 | 
			
		||||
                if attr_value not in self.correlation_index:
 | 
			
		||||
                    self.correlation_index[attr_value] = {
 | 
			
		||||
                        'nodes': set(),
 | 
			
		||||
                        'sources': []
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                self.correlation_index[attr_value]['nodes'].add(node_id)
 | 
			
		||||
 | 
			
		||||
                source_info = {
 | 
			
		||||
                    'node_id': node_id,
 | 
			
		||||
                    'provider': attr_provider,
 | 
			
		||||
                    'attribute': attr_name,
 | 
			
		||||
                    'path': f"{attr_provider}_{attr_name}"
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            self.correlation_index[attr_value]['nodes'].add(node_id)
 | 
			
		||||
                # Avoid duplicate sources
 | 
			
		||||
                existing_sources = [s for s in self.correlation_index[attr_value]['sources']
 | 
			
		||||
                                if s['node_id'] == node_id and s['path'] == source_info['path']]
 | 
			
		||||
                if not existing_sources:
 | 
			
		||||
                    self.correlation_index[attr_value]['sources'].append(source_info)
 | 
			
		||||
 | 
			
		||||
            source_info = {
 | 
			
		||||
                'node_id': node_id,
 | 
			
		||||
                'provider': attr_provider,
 | 
			
		||||
                'attribute': attr_name,
 | 
			
		||||
                'path': f"{attr_provider}_{attr_name}"
 | 
			
		||||
            }
 | 
			
		||||
                # Create correlation if we have multiple nodes with this value
 | 
			
		||||
                if len(self.correlation_index[attr_value]['nodes']) > 1:
 | 
			
		||||
                    self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
 | 
			
		||||
                    correlations_found += 1
 | 
			
		||||
                    
 | 
			
		||||
            existing_sources = [s for s in self.correlation_index[attr_value]['sources']
 | 
			
		||||
                              if s['node_id'] == node_id and s['path'] == source_info['path']]
 | 
			
		||||
            if not existing_sources:
 | 
			
		||||
                self.correlation_index[attr_value]['sources'].append(source_info)
 | 
			
		||||
            # Log correlation results
 | 
			
		||||
            if correlations_found > 0:
 | 
			
		||||
                self.logger.logger.info(f"Found {correlations_found} correlations for node {node_id}")
 | 
			
		||||
                
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            self.logger.logger.error(f"Error finding correlations for {node_id}: {e}")
 | 
			
		||||
            
 | 
			
		||||
            if len(self.correlation_index[attr_value]['nodes']) > 1:
 | 
			
		||||
                self._create_correlation_relationships(attr_value, self.correlation_index[attr_value], result)
 | 
			
		||||
        return result
 | 
			
		||||
    
 | 
			
		||||
    def _should_exclude_attribute(self, attr_name: str, attr_value: Any) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Enhanced logic to determine if an attribute should be excluded from correlation.
 | 
			
		||||
        """
 | 
			
		||||
        # Check against excluded keys (exact match or substring)
 | 
			
		||||
        if any(excluded_key in attr_name or attr_name == excluded_key for excluded_key in self.EXCLUDED_KEYS):
 | 
			
		||||
            return True
 | 
			
		||||
        
 | 
			
		||||
        # Value type filtering
 | 
			
		||||
        if not isinstance(attr_value, (str, int, float, bool)) or attr_value is None:
 | 
			
		||||
            return True
 | 
			
		||||
        
 | 
			
		||||
        # Boolean values are not useful for correlation
 | 
			
		||||
        if isinstance(attr_value, bool):
 | 
			
		||||
            return True
 | 
			
		||||
            
 | 
			
		||||
        # String value filtering
 | 
			
		||||
        if isinstance(attr_value, str):                
 | 
			
		||||
            # Date/timestamp strings
 | 
			
		||||
            if self.date_pattern.match(attr_value):
 | 
			
		||||
                return True
 | 
			
		||||
                
 | 
			
		||||
            # Common non-useful values
 | 
			
		||||
            if attr_value.lower() in ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']:
 | 
			
		||||
                return True
 | 
			
		||||
                
 | 
			
		||||
            # Very long strings that are likely unique (> 100 chars)
 | 
			
		||||
            if len(attr_value) > 100:
 | 
			
		||||
                return True
 | 
			
		||||
        
 | 
			
		||||
        # Numeric value filtering  
 | 
			
		||||
        if isinstance(attr_value, (int, float)):
 | 
			
		||||
            # Very common values
 | 
			
		||||
            if attr_value in [0, 1]:
 | 
			
		||||
                return True
 | 
			
		||||
                
 | 
			
		||||
            # Very large numbers (likely timestamps or unique IDs)
 | 
			
		||||
            if abs(attr_value) > 1000000:
 | 
			
		||||
                return True
 | 
			
		||||
        
 | 
			
		||||
        return False
 | 
			
		||||
    
 | 
			
		||||
    def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
 | 
			
		||||
        """
 | 
			
		||||
        Create correlation relationships and add them to the provider result.
 | 
			
		||||
        Create correlation relationships with enhanced deduplication and validation.
 | 
			
		||||
        """
 | 
			
		||||
        correlation_node_id = f"corr_{hash(str(value)) & 0x7FFFFFFF}"
 | 
			
		||||
        nodes = correlation_data['nodes']
 | 
			
		||||
        sources = correlation_data['sources']
 | 
			
		||||
        
 | 
			
		||||
        # Only create correlations if we have meaningful nodes (more than 1)
 | 
			
		||||
        if len(nodes) < 2:
 | 
			
		||||
            return
 | 
			
		||||
            
 | 
			
		||||
        # Limit correlation size to prevent overly large correlation objects
 | 
			
		||||
        MAX_CORRELATION_SIZE = 50
 | 
			
		||||
        if len(nodes) > MAX_CORRELATION_SIZE:
 | 
			
		||||
            # Sample the nodes to keep correlation manageable
 | 
			
		||||
            import random
 | 
			
		||||
            sampled_nodes = random.sample(list(nodes), MAX_CORRELATION_SIZE)
 | 
			
		||||
            nodes = set(sampled_nodes)
 | 
			
		||||
            # Filter sources to match sampled nodes
 | 
			
		||||
            sources = [s for s in sources if s['node_id'] in nodes]
 | 
			
		||||
 | 
			
		||||
        # Add the correlation node as an attribute to the result
 | 
			
		||||
        result.add_attribute(
 | 
			
		||||
            target_node=correlation_node_id,
 | 
			
		||||
            name="correlation_value",
 | 
			
		||||
            value=value,
 | 
			
		||||
            attr_type=str(type(value)),
 | 
			
		||||
            attr_type=str(type(value).__name__),
 | 
			
		||||
            provider=self.name,
 | 
			
		||||
            confidence=0.9,
 | 
			
		||||
            metadata={
 | 
			
		||||
                'correlated_nodes': list(nodes),
 | 
			
		||||
                'sources': sources,
 | 
			
		||||
                'correlation_size': len(nodes),
 | 
			
		||||
                'value_type': type(value).__name__
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Create relationships with source validation
 | 
			
		||||
        created_relationships = set()
 | 
			
		||||
        
 | 
			
		||||
        for source in sources:
 | 
			
		||||
            node_id = source['node_id']
 | 
			
		||||
            provider = source['provider']
 | 
			
		||||
            attribute = source['attribute']
 | 
			
		||||
            
 | 
			
		||||
            # Skip if we've already created this relationship
 | 
			
		||||
            relationship_key = (node_id, correlation_node_id)
 | 
			
		||||
            if relationship_key in created_relationships:
 | 
			
		||||
                continue
 | 
			
		||||
                
 | 
			
		||||
            relationship_label = f"corr_{provider}_{attribute}"
 | 
			
		||||
 | 
			
		||||
            # Add the relationship to the result
 | 
			
		||||
@ -174,6 +250,9 @@ class CorrelationProvider(BaseProvider):
 | 
			
		||||
                raw_data={
 | 
			
		||||
                    'correlation_value': value,
 | 
			
		||||
                    'original_attribute': attribute,
 | 
			
		||||
                    'correlation_type': 'attribute_matching'
 | 
			
		||||
                    'correlation_type': 'attribute_matching',
 | 
			
		||||
                    'correlation_size': len(nodes)
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
            
 | 
			
		||||
            created_relationships.add(relationship_key)
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user