# dnsrecon/providers/correlation_provider.py

import hashlib
import random
import re
from typing import Dict, Any, List

from .base_provider import BaseProvider
from core.provider_result import ProviderResult
from core.graph_manager import NodeType, GraphManager


class CorrelationProvider(BaseProvider):
    """
    A provider that finds correlations between nodes in the graph.

    Instead of querying an external service, it reads the attributes stored
    on a graph node, records every usable attribute value in
    ``correlation_index`` and, once a value has been seen on more than one
    node, emits a synthetic ``corr_*`` node plus relationships that link the
    sharing nodes to it.
    """

    # Upper bound on the number of nodes reported for one shared value.
    MAX_CORRELATION_SIZE = 50

    # Placeholder strings that carry no information worth correlating on.
    _PLACEHOLDER_VALUES = frozenset(
        {'unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1'}
    )

    def __init__(self, name: str = "correlation", session_config=None):
        """
        Initialize the correlation provider.

        Args:
            name: Provider name forwarded to the base class.
            session_config: Optional session configuration forwarded to the
                base class.
        """
        super().__init__(name, session_config=session_config)
        # Set later through set_graph_manager(); None until then.
        self.graph: GraphManager | None = None
        # Maps attribute value -> {'nodes': set of node ids,
        #                          'sources': list of source-info dicts}.
        self.correlation_index = {}
        # Matches strings that start with "YYYY-MM-DD hh:mm:ss" (space or "T").
        self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
        # Attribute names containing any of these fragments are never
        # correlated (matched as substrings, see _should_exclude_attribute).
        self.EXCLUDED_KEYS = [
            'cert_source',
            'cert_issuer_ca_id',
            'cert_common_name',
            'cert_validity_period_days',
            'cert_issuer_name',
            'cert_entry_timestamp',
            'cert_serial_number',  # useless
            'cert_not_before',
            'cert_not_after',
            'dns_ttl',
            'timestamp',
            'last_update',
            'updated_timestamp',
            'discovery_timestamp',
            'query_timestamp',
        ]

    def get_name(self) -> str:
        """Return the provider name."""
        return "correlation"

    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        return "Correlation Engine"

    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        return False

    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        return {'domains': True, 'ips': True}

    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        return True

    def query_domain(self, domain: str) -> ProviderResult:
        """
        Query the provider for information about a domain.

        The domain is used as the graph node id whose attributes are scanned.
        """
        return self._find_correlations(domain)

    def query_ip(self, ip: str) -> ProviderResult:
        """
        Query the provider for information about an IP address.

        The IP is used as the graph node id whose attributes are scanned.
        """
        return self._find_correlations(ip)

    def set_graph_manager(self, graph_manager: GraphManager):
        """
        Set the graph manager for the provider to use.
        """
        self.graph = graph_manager

    def _find_correlations(self, node_id: str) -> ProviderResult:
        """
        Find correlations for a given node.

        Every usable attribute of ``node_id`` is recorded in
        ``correlation_index``. Each value that is now known on more than one
        node is emitted exactly once into the returned result, after the
        whole attribute list has been scanned, so the emitted sources are
        complete and a value occurring under several attributes of the same
        node does not produce duplicate output.

        Args:
            node_id: Id of the graph node to inspect.

        Returns:
            A ProviderResult holding the correlation attributes and
            relationships; empty when no graph is set, the node is unknown,
            its attributes are not a list, or an error occurred.
        """
        result = ProviderResult()

        if not self.graph or not self.graph.graph.has_node(node_id):
            return result

        try:
            node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])

            # Ensure attributes is a list (handle legacy data)
            if not isinstance(node_attributes, list):
                return result

            # Values of this node that are shared with at least one other
            # node, in first-seen order, mapped to their index entry.
            shared_values: Dict[Any, Dict[str, Any]] = {}

            for attr in node_attributes:
                if not isinstance(attr, dict):
                    continue

                attr_name = attr.get('name', '')
                attr_value = attr.get('value')
                attr_provider = attr.get('provider', 'unknown')

                if self._should_exclude_attribute(attr_name, attr_value):
                    continue

                entry = self.correlation_index.get(attr_value)
                if entry is None:
                    entry = {'nodes': set(), 'sources': []}
                    self.correlation_index[attr_value] = entry

                entry['nodes'].add(node_id)

                # Record each (node, provider/attribute path) only once.
                path = f"{attr_provider}_{attr_name}"
                already_recorded = any(
                    s['node_id'] == node_id and s['path'] == path
                    for s in entry['sources']
                )
                if not already_recorded:
                    entry['sources'].append({
                        'node_id': node_id,
                        'provider': attr_provider,
                        'attribute': attr_name,
                        'path': path,
                    })

                if len(entry['nodes']) > 1:
                    shared_values[attr_value] = entry

            for value, entry in shared_values.items():
                self._create_correlation_relationships(value, entry, result)

            correlations_found = len(shared_values)
            if correlations_found > 0:
                self.logger.logger.info(f"Found {correlations_found} correlations for node {node_id}")

        except Exception as e:
            # Best effort: a failure here must not break the calling scan.
            self.logger.logger.error(f"Error finding correlations for {node_id}: {e}")

        return result

    def _should_exclude_attribute(self, attr_name: str, attr_value: Any) -> bool:
        """
        Decide whether an attribute should be excluded from correlation.

        Args:
            attr_name: Name of the attribute.
            attr_value: Value of the attribute.

        Returns:
            True when the attribute must be skipped: its name is not a string
            or contains one of ``EXCLUDED_KEYS``; its value is not a plain
            str/int/float; or the value is a blank, timestamp-like,
            placeholder or over-long string, NaN, 0, 1, or a number whose
            magnitude exceeds 1,000,000.
        """
        # A non-string name cannot be substring-matched; skip the attribute
        # instead of raising and aborting the whole node.
        if not isinstance(attr_name, str):
            return True

        # Excluded keys match as substrings (which also covers exact matches).
        if any(excluded_key in attr_name for excluded_key in self.EXCLUDED_KEYS):
            return True

        # bool is a subclass of int, so it is rejected explicitly first;
        # everything that is not a plain scalar (including None) goes too.
        if isinstance(attr_value, bool) or not isinstance(attr_value, (str, int, float)):
            return True

        if isinstance(attr_value, str):
            # Blank strings would tie together every node lacking real data.
            if not attr_value.strip():
                return True

            # Date/timestamp strings
            if self.date_pattern.match(attr_value):
                return True

            # Common non-useful values
            if attr_value.lower() in self._PLACEHOLDER_VALUES:
                return True

            # Very long strings that are likely unique (> 100 chars)
            return len(attr_value) > 100

        # Numeric values from here on.
        # NaN never equals itself: it can never match another node and would
        # only add a fresh index entry on every occurrence.
        if attr_value != attr_value:
            return True

        # Very common values
        if attr_value in (0, 1):
            return True

        # Very large numbers (likely timestamps or unique IDs)
        return abs(attr_value) > 1000000

    def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
        """
        Add the correlation node and its relationships for one shared value.

        Args:
            value: The attribute value shared by several nodes.
            correlation_data: Index entry for ``value`` with the keys
                ``'nodes'`` (set of node ids) and ``'sources'`` (list of
                source-info dicts).
            result: ProviderResult that receives the new attribute and
                relationships; modified in place.
        """
        nodes = correlation_data['nodes']
        sources = correlation_data['sources']

        # Only create correlations if we have meaningful nodes (more than 1)
        if len(nodes) < 2:
            return

        # The built-in hash() of a str is salted per interpreter process, so
        # it would yield a different node id for the same value on every run.
        # A SHA-256 prefix keeps the "corr_<31-bit int>" format but is stable.
        digest = hashlib.sha256(
            str(value).encode('utf-8', errors='backslashreplace')
        ).digest()
        correlation_node_id = f"corr_{int.from_bytes(digest[:4], 'big') & 0x7FFFFFFF}"

        # Limit correlation size to prevent overly large correlation objects
        if len(nodes) > self.MAX_CORRELATION_SIZE:
            nodes = set(random.sample(list(nodes), self.MAX_CORRELATION_SIZE))
            # Filter sources to match sampled nodes
            sources = [s for s in sources if s['node_id'] in nodes]

        # Snapshot the data handed to the result: the index keeps mutating
        # its own set/list, which must not alter results already returned.
        correlated_nodes = sorted(nodes, key=str)
        source_snapshot = [dict(s) for s in sources]
        correlation_size = len(correlated_nodes)

        # Add the correlation node as an attribute to the result
        result.add_attribute(
            target_node=correlation_node_id,
            name="correlation_value",
            value=value,
            attr_type=type(value).__name__,
            provider=self.name,
            confidence=0.9,
            metadata={
                'correlated_nodes': correlated_nodes,
                'sources': source_snapshot,
                'correlation_size': correlation_size,
                'value_type': type(value).__name__
            }
        )

        # One relationship per node; the first source seen for a node
        # determines the relationship label.
        linked_nodes = set()

        for source in source_snapshot:
            node_id = source['node_id']
            if node_id in linked_nodes:
                continue
            linked_nodes.add(node_id)

            attribute = source['attribute']

            result.add_relationship(
                source_node=node_id,
                target_node=correlation_node_id,
                relationship_type=f"corr_{source['provider']}_{attribute}",
                provider=self.name,
                confidence=0.9,
                raw_data={
                    'correlation_value': value,
                    'original_attribute': attribute,
                    'correlation_type': 'attribute_matching',
                    'correlation_size': correlation_size
                }
            )