# dnsrecon/providers/correlation_provider.py

import hashlib
import re
from typing import Dict, Any, List

from .base_provider import BaseProvider
from core.provider_result import ProviderResult
from core.graph_manager import NodeType, GraphManager


class CorrelationProvider(BaseProvider):
    """
    A provider that finds correlations between nodes in the graph.

    Two nodes are considered correlated when they carry an attribute with the
    same value. Every value seen so far is remembered in
    ``self.correlation_index`` so that later queries can be matched against
    nodes processed earlier.
    """

    # Lower-cased string values that carry no information and must never be
    # used as a correlation key.
    _PLACEHOLDER_VALUES = frozenset(
        ['unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1']
    )

    def __init__(self, name: str = "correlation", session_config=None):
        """
        Initialize the correlation provider.

        Args:
            name: Provider name forwarded to ``BaseProvider.__init__``.
            session_config: Session configuration forwarded to
                ``BaseProvider.__init__``.
        """
        super().__init__(name, session_config=session_config)
        # Assigned later through set_graph_manager(); until then no
        # correlations can be computed.
        self.graph: GraphManager | None = None
        # Maps attribute value -> {'nodes': set of node ids,
        #                          'sources': list of source-info dicts}.
        self.correlation_index = {}
        # Matches values that start with "YYYY-MM-DD HH:MM:SS" (space or 'T'
        # separator); such timestamps are excluded from correlation.
        self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
        # Attribute names containing any of these keys (substring match) are
        # excluded from correlation.
        self.EXCLUDED_KEYS = [
            'cert_source',
            'cert_issuer_ca_id',
            'cert_common_name',
            'cert_validity_period_days',
            'cert_issuer_name',
            'cert_entry_timestamp',
            'cert_serial_number',  # useless
            'cert_not_before',
            'cert_not_after',
            'dns_ttl',
            'timestamp',
            'last_update',
            'updated_timestamp',
            'discovery_timestamp',
            'query_timestamp',
        ]

    def get_name(self) -> str:
        """Return the provider name."""
        return "correlation"

    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        return "Correlation Engine"

    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        return False

    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        return {'domains': True, 'ips': True}

    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        return True

    def query_domain(self, domain: str) -> ProviderResult:
        """
        Query the provider for information about a domain.

        Args:
            domain: Identifier of the domain node to correlate.

        Returns:
            The result produced by ``_find_correlations`` for that node.
        """
        return self._find_correlations(domain)

    def query_ip(self, ip: str) -> ProviderResult:
        """
        Query the provider for information about an IP address.

        Args:
            ip: Identifier of the IP node to correlate.

        Returns:
            The result produced by ``_find_correlations`` for that node.
        """
        return self._find_correlations(ip)

    def set_graph_manager(self, graph_manager: GraphManager):
        """
        Set the graph manager for the provider to use.

        Args:
            graph_manager: Graph manager whose ``graph`` attribute is read
                when looking up node attributes.
        """
        self.graph = graph_manager

    def _is_correlatable(self, attr_name: Any, attr_value: Any) -> bool:
        """
        Decide whether an attribute may be used as a correlation key.

        Args:
            attr_name: Name of the attribute as stored on the node.
            attr_value: Value of the attribute as stored on the node.

        Returns:
            True when the attribute should be indexed, False when it must be
            skipped.
        """
        # A missing or non-string name cannot be matched against the
        # exclusion list; skip the attribute instead of raising TypeError.
        if not isinstance(attr_name, str):
            return False
        if any(excluded_key in attr_name for excluded_key in self.EXCLUDED_KEYS):
            return False
        # bool is a subclass of int, so it has to be rejected before the
        # numeric branch below.
        if isinstance(attr_value, bool):
            return False
        if isinstance(attr_value, str):
            if len(attr_value) < 4:
                return False
            if self.date_pattern.match(attr_value):
                return False
            return attr_value.lower() not in self._PLACEHOLDER_VALUES
        if isinstance(attr_value, (int, float)):
            return not (
                attr_value == 0
                or attr_value == 1
                or abs(attr_value) > 1000000
            )
        # None and every other type (lists, dicts, ...) are not correlated.
        return False

    def _find_correlations(self, node_id: str) -> ProviderResult:
        """
        Find correlations for a given node.

        Each eligible attribute of the node is recorded in
        ``self.correlation_index``. Whenever a value is then known for more
        than one node, correlation relationships for that value are added to
        the returned result.

        Args:
            node_id: Identifier of the node whose attributes are inspected.

        Returns:
            A ``ProviderResult``; it is returned untouched when no graph
            manager has been set or the node is not present in the graph.
        """
        result = ProviderResult()

        # Without a graph manager, or for an unknown node, there is nothing
        # to correlate.
        if not self.graph or not self.graph.graph.has_node(node_id):
            return result

        node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])

        for attr in node_attributes:
            attr_name = attr.get('name')
            attr_value = attr.get('value')
            attr_provider = attr.get('provider', 'unknown')

            if not self._is_correlatable(attr_name, attr_value):
                continue

            entry = self.correlation_index.setdefault(
                attr_value, {'nodes': set(), 'sources': []}
            )
            entry['nodes'].add(node_id)

            source_info = {
                'node_id': node_id,
                'provider': attr_provider,
                'attribute': attr_name,
                'path': f"{attr_provider}_{attr_name}"
            }

            # Record each (node, provider/attribute path) pair only once so
            # that repeated queries do not duplicate sources.
            already_recorded = any(
                s['node_id'] == node_id and s['path'] == source_info['path']
                for s in entry['sources']
            )
            if not already_recorded:
                entry['sources'].append(source_info)

            if len(entry['nodes']) > 1:
                self._create_correlation_relationships(attr_value, entry, result)

        return result

    def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
        """
        Create correlation relationships and add them to the provider result.

        Args:
            value: The attribute value shared by the correlated nodes.
            correlation_data: Index entry for ``value`` with the keys
                ``'nodes'`` and ``'sources'``.
            result: Result object that receives one attribute for the
                correlation node and one relationship per source.
        """
        # The built-in hash() of a str is salted per process, which would
        # give the same value a different node id on every run. Derive the
        # id from a SHA-256 digest instead, keeping the "corr_<31-bit int>"
        # shape of the identifier.
        digest = hashlib.sha256(str(value).encode('utf-8', 'surrogatepass')).digest()
        correlation_node_id = f"corr_{int.from_bytes(digest[:4], 'big') & 0x7FFFFFFF}"

        nodes = correlation_data['nodes']
        sources = correlation_data['sources']

        # Add the correlation node as an attribute to the result
        result.add_attribute(
            target_node=correlation_node_id,
            name="correlation_value",
            value=value,
            attr_type=str(type(value)),
            provider=self.name,
            confidence=0.9,
            metadata={
                'correlated_nodes': list(nodes),
                'sources': sources,
            }
        )

        for source in sources:
            node_id = source['node_id']
            provider = source['provider']
            attribute = source['attribute']
            relationship_label = f"corr_{provider}_{attribute}"

            # Add the relationship to the result
            result.add_relationship(
                source_node=node_id,
                target_node=correlation_node_id,
                relationship_type=relationship_label,
                provider=self.name,
                confidence=0.9,
                raw_data={
                    'correlation_value': value,
                    'original_attribute': attribute,
                    'correlation_type': 'attribute_matching'
                }
            )