# DNScope/providers/correlation_provider.py

import hashlib
import random
import re
from typing import Dict, Any, List

from .base_provider import BaseProvider
from core.provider_result import ProviderResult
from core.graph_manager import NodeType, GraphManager


class CorrelationProvider(BaseProvider):
    """
    A provider that finds correlations between nodes in the graph.

    Two nodes are considered correlated when they carry an attribute with
    the same value. Every attribute value seen so far is remembered in
    ``self.correlation_index``, which maps the value to::

        {'nodes': set of node ids, 'sources': list of source dicts}

    As soon as a value is shared by more than one node, a correlation
    attribute and one relationship per contributing node are written to the
    returned ``ProviderResult``.
    """

    # Upper bound on the number of nodes reported in a single correlation.
    MAX_CORRELATION_SIZE = 50

    # Placeholder strings that carry no information and must never correlate.
    _NON_INFORMATIVE_STRINGS = frozenset(
        ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']
    )

    def __init__(self, name: str = "correlation", session_config=None):
        """
        Initialize the correlation provider.

        Args:
            name: Provider name forwarded to ``BaseProvider``.
            session_config: Optional session configuration forwarded to
                ``BaseProvider``.
        """
        super().__init__(name, session_config=session_config)
        self.graph: GraphManager | None = None
        # attribute value -> {'nodes': set[str], 'sources': list[dict]}
        self.correlation_index = {}
        # Matches values that start with "YYYY-MM-DD HH:MM:SS" or
        # "YYYY-MM-DDTHH:MM:SS".
        self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
        # Attribute names containing any of these substrings are ignored.
        self.EXCLUDED_KEYS = [
            'cert_source',
            'cert_issuer_ca_id',
            'cert_common_name',
            'cert_validity_period_days',
            'cert_issuer_name',
            'cert_entry_timestamp',
            'cert_serial_number',  # useless
            'cert_not_before',
            'cert_not_after',
            'dns_ttl',
            'timestamp',
            'last_update',
            'updated_timestamp',
            'discovery_timestamp',
            'query_timestamp',
        ]

    def get_name(self) -> str:
        """Return the provider name."""
        return "correlation"

    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        return "Correlation Engine"

    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        return False

    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        return {'domains': True, 'ips': True}

    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        return True

    def query_domain(self, domain: str) -> ProviderResult:
        """
        Query the provider for information about a domain.

        The domain is used as the node id whose attributes are correlated.
        """
        return self._find_correlations(domain)

    def query_ip(self, ip: str) -> ProviderResult:
        """
        Query the provider for information about an IP address.

        The IP is used as the node id whose attributes are correlated.
        """
        return self._find_correlations(ip)

    def set_graph_manager(self, graph_manager: GraphManager):
        """
        Set the graph manager for the provider to use.
        """
        self.graph = graph_manager

    def _find_correlations(self, node_id: str) -> ProviderResult:
        """
        Find correlations for a given node.

        Every usable attribute of the node is added to the correlation index.
        When the attribute's value is now shared by more than one node, the
        correlation is written to the result.

        Args:
            node_id: Id of the node in ``self.graph.graph`` to inspect.

        Returns:
            A ``ProviderResult``; it is empty when no graph manager is set,
            the node is not in the graph, or the node's ``attributes`` entry
            is not a list. Exceptions raised while processing are logged and
            whatever was collected up to that point is returned.
        """
        result = ProviderResult()

        if not self.graph or not self.graph.graph.has_node(node_id):
            return result

        try:
            node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])

            # Ensure attributes is a list (handle legacy data)
            if not isinstance(node_attributes, list):
                return result

            correlations_found = 0

            for attr in node_attributes:
                if not isinstance(attr, dict):
                    continue

                attr_name = attr.get('name', '')
                attr_value = attr.get('value')
                attr_provider = attr.get('provider', 'unknown')

                if self._should_exclude_attribute(attr_name, attr_value):
                    continue

                entry = self.correlation_index.setdefault(
                    attr_value, {'nodes': set(), 'sources': []}
                )
                entry['nodes'].add(node_id)

                source_info = {
                    'node_id': node_id,
                    'provider': attr_provider,
                    'attribute': attr_name,
                    'path': f"{attr_provider}_{attr_name}"
                }

                # Avoid duplicate sources (same node reporting the same path)
                already_recorded = any(
                    s['node_id'] == node_id and s['path'] == source_info['path']
                    for s in entry['sources']
                )
                if not already_recorded:
                    entry['sources'].append(source_info)

                # Create correlation if we have multiple nodes with this value
                if len(entry['nodes']) > 1:
                    self._create_correlation_relationships(attr_value, entry, result)
                    correlations_found += 1

            if correlations_found > 0:
                self.logger.logger.info(f"Found {correlations_found} correlations for node {node_id}")

        except Exception as e:
            # Provider boundary: report the failure and return what we have.
            self.logger.logger.error(f"Error finding correlations for {node_id}: {e}")

        return result

    def _should_exclude_attribute(self, attr_name: str, attr_value: Any) -> bool:
        """
        Determine if an attribute should be excluded from correlation.

        Args:
            attr_name: Name of the attribute.
            attr_value: Value of the attribute.

        Returns:
            True when the attribute must be ignored, i.e. when
            - the name is not a string, or contains one of ``EXCLUDED_KEYS``;
            - the value is a bool or is not a str/int/float (this covers
              ``None`` and containers);
            - the value is a string that is blank, starts with a date-time,
              is a placeholder such as ``'unknown'``, or is longer than
              100 characters;
            - the value is a number equal to 0 or 1, or whose magnitude
              exceeds 1,000,000.
        """
        # A non-string name cannot be matched against the exclusion list;
        # skip the attribute instead of raising and aborting the whole node.
        if not isinstance(attr_name, str):
            return True

        # Substring match (an exact match is a special case of it)
        if any(excluded_key in attr_name for excluded_key in self.EXCLUDED_KEYS):
            return True

        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(attr_value, bool) or not isinstance(attr_value, (str, int, float)):
            return True

        if isinstance(attr_value, str):
            normalized = attr_value.strip().lower()
            # Blank strings would correlate every node with an empty field.
            if not normalized:
                return True
            # Date/timestamp strings
            if self.date_pattern.match(attr_value):
                return True
            # Common non-useful values
            if normalized in self._NON_INFORMATIVE_STRINGS:
                return True
            # Very long strings that are likely unique (> 100 chars)
            return len(attr_value) > 100

        # Numeric values: very common values, or very large numbers
        # (likely timestamps or unique IDs)
        return attr_value in (0, 1) or abs(attr_value) > 1000000

    @staticmethod
    def _correlation_node_id(value: Any) -> str:
        """
        Build the id of the correlation node for ``value``.

        The id is derived from a SHA-256 digest rather than the built-in
        ``hash()``, because string hashes are salted per interpreter process
        and would give the same value a different id on every run.
        """
        digest = hashlib.sha256(str(value).encode('utf-8')).digest()
        return f"corr_{int.from_bytes(digest[:4], 'big') & 0x7FFFFFFF}"

    def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
        """
        Add the correlation for ``value`` to ``result``.

        One ``correlation_value`` attribute is added on the correlation node,
        followed by one relationship from each contributing node to that
        correlation node.

        Args:
            value: The shared attribute value.
            correlation_data: Index entry for the value, with keys ``'nodes'``
                (set of node ids) and ``'sources'`` (list of source dicts).
            result: The ``ProviderResult`` to append to.
        """
        correlation_node_id = self._correlation_node_id(value)
        nodes = correlation_data['nodes']
        sources = correlation_data['sources']

        # Only create correlations if we have meaningful nodes (more than 1)
        if len(nodes) < 2:
            return

        # Limit correlation size to prevent overly large correlation objects
        if len(nodes) > self.MAX_CORRELATION_SIZE:
            nodes = set(random.sample(list(nodes), self.MAX_CORRELATION_SIZE))
            # Keep only the sources that belong to the sampled nodes
            sources = [s for s in sources if s['node_id'] in nodes]

        correlation_size = len(nodes)

        result.add_attribute(
            target_node=correlation_node_id,
            name="correlation_value",
            value=value,
            attr_type=str(type(value).__name__),
            provider=self.name,
            confidence=0.9,
            metadata={
                'correlated_nodes': list(nodes),
                # Copy, so later index updates do not alter this result.
                'sources': [dict(s) for s in sources],
                'correlation_size': correlation_size,
                'value_type': type(value).__name__
            }
        )

        # One relationship per node; the first source seen for a node wins.
        created_relationships = set()

        for source in sources:
            node_id = source['node_id']
            provider = source['provider']
            attribute = source['attribute']

            relationship_key = (node_id, correlation_node_id)
            if relationship_key in created_relationships:
                continue

            result.add_relationship(
                source_node=node_id,
                target_node=correlation_node_id,
                relationship_type=f"corr_{provider}_{attribute}",
                provider=self.name,
                confidence=0.9,
                raw_data={
                    'correlation_value': value,
                    'original_attribute': attribute,
                    'correlation_type': 'attribute_matching',
                    'correlation_size': correlation_size
                }
            )

            created_relationships.add(relationship_key)