178 lines
6.3 KiB
Python
178 lines
6.3 KiB
Python
# dnsrecon/providers/correlation_provider.py
|
|
|
|
import re
|
|
from typing import Dict, Any, List
|
|
|
|
from .base_provider import BaseProvider
|
|
from core.provider_result import ProviderResult
|
|
from core.graph_manager import NodeType, GraphManager
|
|
|
|
class CorrelationProvider(BaseProvider):
    """
    A provider that finds correlations between nodes in the graph.

    Every queried node has its attribute values recorded in
    ``self.correlation_index``. Once the same value has been recorded for more
    than one node, a ``correlation_value`` attribute is emitted on a synthetic
    ``corr_<id>`` node, together with a relationship from every contributing
    node to it.
    """

    def __init__(self, name: str = "correlation", session_config=None):
        """
        Initialize the correlation provider.

        Args:
            name: Provider name, passed through to ``BaseProvider``.
            session_config: Session configuration, passed through to ``BaseProvider``.
        """
        super().__init__(name, session_config=session_config)
        # Set later via set_graph_manager(); queries return empty results until then.
        self.graph: GraphManager | None = None
        # attribute value -> {'nodes': set of node ids, 'sources': list of source dicts}.
        # Kept across queries so values seen on earlier nodes can be matched later.
        self.correlation_index = {}
        # String values starting with "YYYY-MM-DD HH:MM:SS" (space or 'T' separator)
        # are excluded from correlation.
        self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}')
        # Attributes whose name contains any of these substrings are never correlated.
        self.EXCLUDED_KEYS = [
            'cert_source',
            'cert_issuer_ca_id',
            'cert_common_name',
            'cert_validity_period_days',
            'cert_issuer_name',
            'cert_entry_timestamp',
            'cert_not_before',
            'cert_not_after',
            'dns_ttl',
            'timestamp',
            'last_update',
            'updated_timestamp',
            'discovery_timestamp',
            'query_timestamp',
        ]

    def get_name(self) -> str:
        """Return the provider name."""
        return "correlation"

    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        return "Correlation Engine"

    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        return False

    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        return {'domains': True, 'ips': True}

    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        return True

    def query_domain(self, domain: str) -> ProviderResult:
        """
        Query the provider for information about a domain.

        Args:
            domain: Node id of the domain in the graph.

        Returns:
            The correlations found for that node (see ``_find_correlations``).
        """
        return self._find_correlations(domain)

    def query_ip(self, ip: str) -> ProviderResult:
        """
        Query the provider for information about an IP address.

        Args:
            ip: Node id of the IP address in the graph.

        Returns:
            The correlations found for that node (see ``_find_correlations``).
        """
        return self._find_correlations(ip)

    def set_graph_manager(self, graph_manager: GraphManager):
        """
        Set the graph manager for the provider to use.

        Args:
            graph_manager: Manager whose ``graph`` is read for node attributes.
        """
        self.graph = graph_manager

    def _find_correlations(self, node_id: str) -> ProviderResult:
        """
        Index the attributes of ``node_id`` and emit any resulting correlations.

        Each attribute value that passes ``_should_exclude`` is recorded in
        ``self.correlation_index``. If that value is now shared by more than one
        node, correlation data for it is added to the returned result.

        Args:
            node_id: Identifier of a node expected to be present in the graph.

        Returns:
            A ``ProviderResult``. It is empty when no graph manager has been set
            or the node is not in the graph.
        """
        result = ProviderResult()

        # The graph manager is injected after construction; without it (or
        # without the node) there is nothing to read attributes from.
        if self.graph is None or not self.graph.graph.has_node(node_id):
            return result

        node_attributes = self.graph.graph.nodes[node_id].get('attributes', [])

        for attr in node_attributes:
            attr_name = attr.get('name')
            attr_value = attr.get('value')
            attr_provider = attr.get('provider', 'unknown')

            if self._should_exclude(attr_name, attr_value):
                continue

            entry = self.correlation_index.get(attr_value)
            if entry is None:
                entry = {'nodes': set(), 'sources': []}
                self.correlation_index[attr_value] = entry

            entry['nodes'].add(node_id)

            path = f"{attr_provider}_{attr_name}"
            # Record each (node, provider/attribute) origin of a value only once,
            # so re-querying a node does not duplicate its sources.
            already_recorded = any(
                source['node_id'] == node_id and source['path'] == path
                for source in entry['sources']
            )
            if not already_recorded:
                entry['sources'].append({
                    'node_id': node_id,
                    'provider': attr_provider,
                    'attribute': attr_name,
                    'path': path,
                })

            if len(entry['nodes']) > 1:
                self._create_correlation_relationships(attr_value, entry, result)

        return result

    def _should_exclude(self, attr_name: Any, attr_value: Any) -> bool:
        """
        Return True if an attribute must not take part in correlation.

        Args:
            attr_name: The attribute's ``name`` entry (may be missing/None).
            attr_value: The attribute's ``value`` entry.

        Returns:
            True for unnamed attributes, names matching ``EXCLUDED_KEYS``,
            non-scalar or boolean values, short/placeholder/date-like strings,
            and the numbers 0, 1 or any magnitude above 1,000,000.
        """
        # A substring test against None raises TypeError, and an unnamed
        # attribute could not be labelled in the relationship anyway.
        if not isinstance(attr_name, str):
            return True
        if any(excluded_key in attr_name for excluded_key in self.EXCLUDED_KEYS):
            return True

        # bool is a subclass of int, so reject it explicitly; None, lists,
        # dicts, etc. fail the isinstance test.
        if isinstance(attr_value, bool) or not isinstance(attr_value, (str, int, float)):
            return True

        if isinstance(attr_value, str):
            return (
                len(attr_value) < 4
                or self.date_pattern.match(attr_value) is not None
                or attr_value.lower() in ('unknown', 'none', 'null', 'n/a', 'true', 'false', '0', '1')
            )

        return attr_value in (0, 1) or abs(attr_value) > 1000000

    @staticmethod
    def _correlation_node_id(value: Any) -> str:
        """
        Return the correlation node id (``corr_<31-bit int>``) for ``value``.

        Args:
            value: The correlated attribute value; its ``str()`` form is hashed.

        Returns:
            An id that is identical for equal ``str(value)`` in every process.
        """
        import hashlib

        # The built-in hash() of str is salted per interpreter (PYTHONHASHSEED),
        # so it would assign the same value a different id in another process.
        digest = hashlib.sha256(str(value).encode('utf-8', 'surrogatepass')).digest()
        return f"corr_{int.from_bytes(digest[:4], 'big') & 0x7FFFFFFF}"

    def _create_correlation_relationships(self, value: Any, correlation_data: Dict[str, Any], result: ProviderResult):
        """
        Create correlation relationships and add them to the provider result.

        Args:
            value: The attribute value shared by several nodes.
            correlation_data: The ``correlation_index`` entry for ``value``,
                holding a ``'nodes'`` set and a ``'sources'`` list.
            result: Receives one ``correlation_value`` attribute on the
                correlation node and one relationship per source.
        """
        correlation_node_id = self._correlation_node_id(value)
        nodes = correlation_data['nodes']
        # Snapshot the sources: the index entry keeps growing as more nodes are
        # queried, and an already-returned result must not change with it.
        sources = [dict(source) for source in correlation_data['sources']]

        # Add the correlation node as an attribute to the result
        result.add_attribute(
            target_node=correlation_node_id,
            name="correlation_value",
            value=value,
            attr_type=str(type(value)),
            provider=self.name,
            confidence=0.9,
            metadata={
                # Sorted so the order does not depend on set iteration order.
                'correlated_nodes': sorted(nodes, key=str),
                'sources': sources,
            },
        )

        for source in sources:
            node_id = source['node_id']
            provider = source['provider']
            attribute = source['attribute']
            relationship_label = f"corr_{provider}_{attribute}"

            # Add the relationship to the result
            result.add_relationship(
                source_node=node_id,
                target_node=correlation_node_id,
                relationship_type=relationship_label,
                provider=self.name,
                confidence=0.9,
                raw_data={
                    'correlation_value': value,
                    'original_attribute': attribute,
                    'correlation_type': 'attribute_matching',
                },
            )