# dnscope/core/graph_manager.py
# 2025-09-16 23:17:23 +02:00
#
# 543 lines
# 25 KiB
# Python

# dnsrecon-reduced/core/graph_manager.py
"""
Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
Now fully compatible with the unified ProviderResult data model.
UPDATED: Fixed certificate styling and correlation edge labeling.
"""
import hashlib
import re
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple

import networkx as nx
class NodeType(Enum):
    """Closed set of node kinds; each member's value is the string stored on the node."""

    DOMAIN = "domain"
    IP = "ip"
    ASN = "asn"
    LARGE_ENTITY = "large_entity"
    CORRELATION_OBJECT = "correlation_object"

    def __repr__(self) -> str:
        # Render as the bare value (e.g. ``domain``) rather than the default
        # ``<NodeType.DOMAIN: 'domain'>`` form.
        return self.value
class GraphManager:
    """
    In-memory graph manager for DNSRecon infrastructure mapping.

    Wraps a NetworkX ``DiGraph`` whose nodes carry ``type``, ``attributes``
    (a list of attribute dictionaries), ``description`` and ``metadata``, and
    whose edges carry a relationship label, a confidence score and provider
    information. A side index (``correlation_index``) maps attribute values to
    the nodes that share them so correlation nodes can be created.

    NOTE(review): this class performs no locking of its own; callers that
    share one instance between threads must synchronise access externally.
    """

    # Kept as text so __setstate__ can recompile the pattern after unpickling.
    _DATE_PATTERN_SOURCE = r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}'

    # Relationship types starting with one of these are stored verbatim as the
    # edge label; everything else is prefixed with the provider name.
    _VERBATIM_LABEL_PREFIXES = ("c_", "corr_")

    def __init__(self):
        """Initialize an empty directed graph and an empty correlation index."""
        self.graph = nx.DiGraph()
        self.creation_time = datetime.now(timezone.utc).isoformat()
        self.last_modified = self.creation_time
        # value -> {'nodes': set of node ids, 'sources': list of source dicts}
        self.correlation_index = {}
        # Compiled once; used to skip timestamp-like strings during correlation.
        self.date_pattern = re.compile(self._DATE_PATTERN_SOURCE)
        # Attribute names that never take part in correlation.
        self.EXCLUDED_KEYS = ['confidence', 'provider', 'timestamp', 'type','crtsh_cert_validity_period_days']

    def __getstate__(self):
        """Return the instance state for pickling, without the compiled regex."""
        state = self.__dict__.copy()
        state.pop('date_pattern', None)
        return state

    def __setstate__(self, state):
        """Restore pickled state and recompile the date regex."""
        self.__dict__.update(state)
        self.date_pattern = re.compile(self._DATE_PATTERN_SOURCE)

    def process_correlations_for_node(self, node_id: str):
        """
        Index the attribute values of ``node_id`` and link shared values.

        Every eligible attribute value is recorded in ``correlation_index``
        together with the provider/attribute that produced it. As soon as a
        value is seen on more than one node, a correlation node and its edges
        are created (or refreshed).

        Skipped: attribute names in ``EXCLUDED_KEYS``, booleans, values that
        are not ``str``/``int``/``float``, strings shorter than 4 characters
        and strings that start with a timestamp.
        """
        if not self.graph.has_node(node_id):
            return

        node_attributes = self.graph.nodes[node_id].get('attributes', [])
        if not isinstance(node_attributes, list):
            # Non-list attribute containers cannot be iterated as attribute
            # dictionaries; nothing to correlate.
            return

        for attr in node_attributes:
            if not isinstance(attr, dict):
                continue
            attr_name = attr.get('name')
            attr_value = attr.get('value')
            attr_provider = attr.get('provider', 'unknown')

            if attr_name in self.EXCLUDED_KEYS:
                continue
            # bool is a subclass of int, so it has to be rejected explicitly.
            if isinstance(attr_value, bool) or not isinstance(attr_value, (str, int, float)):
                continue
            if isinstance(attr_value, str) and (len(attr_value) < 4 or self.date_pattern.match(attr_value)):
                continue

            entry = self.correlation_index.setdefault(attr_value, {'nodes': set(), 'sources': []})
            entry['nodes'].add(node_id)

            source_path = f"{attr_provider}_{attr_name}"
            already_tracked = any(
                s['node_id'] == node_id and s['path'] == source_path
                for s in entry['sources']
            )
            if not already_tracked:
                entry['sources'].append({
                    'node_id': node_id,
                    'provider': attr_provider,
                    'attribute': attr_name,
                    'path': source_path
                })

            if len(entry['nodes']) > 1:
                self._create_enhanced_correlation_node_and_edges(attr_value, entry)

    @staticmethod
    def _correlation_node_id(value) -> str:
        """
        Derive a deterministic correlation node id from ``value``.

        The built-in ``hash()`` of a string is salted per process, so it would
        yield a different id for the same value after a restart or unpickle.
        A SHA-256 digest is stable; the first 8 bytes (masked to 63 bits) keep
        the ``corr_<digits>`` shape.
        """
        digest = hashlib.sha256(str(value).encode('utf-8', 'backslashreplace')).digest()
        return f"corr_{int.from_bytes(digest[:8], 'big') & 0x7FFFFFFFFFFFFFFF}"

    def _create_enhanced_correlation_node_and_edges(self, value, correlation_data):
        """
        Create (or refresh) the correlation node for ``value`` and its edges.

        ``correlation_data`` is an entry of ``correlation_index``: ``'nodes'``
        holds the ids sharing the value, ``'sources'`` the provider/attribute
        records. One edge is added from each source node to the correlation
        node, labelled ``corr_<provider>_<attribute>``.
        """
        correlation_node_id = self._correlation_node_id(value)
        nodes = correlation_data['nodes']
        sources = correlation_data['sources']
        member_ids = sorted(nodes, key=str)

        if not self.graph.has_node(correlation_node_id):
            # Count how often each provider/attribute pair contributed.
            provider_counts = {}
            for source in sources:
                key = f"{source['provider']}_{source['attribute']}"
                provider_counts[key] = provider_counts.get(key, 0) + 1

            # The most frequent pair becomes the primary label.
            primary_source = max(provider_counts.items(), key=lambda x: x[1])[0] if provider_counts else "unknown_correlation"

            metadata = {
                'value': value,
                'correlated_nodes': member_ids,
                'sources': sources,
                'primary_source': primary_source,
                'correlation_count': len(nodes)
            }
            self.add_node(correlation_node_id, NodeType.CORRELATION_OBJECT, metadata=metadata)
            print(f"Created correlation node {correlation_node_id} for value '{value}' with {len(nodes)} nodes")
        else:
            # Keep the membership snapshot current when further nodes join.
            self.add_node(correlation_node_id, NodeType.CORRELATION_OBJECT, metadata={
                'correlated_nodes': member_ids,
                'correlation_count': len(nodes)
            })

        for source in sources:
            node_id = source['node_id']
            provider = source['provider']
            attribute = source['attribute']

            if self.graph.has_node(node_id) and not self.graph.has_edge(node_id, correlation_node_id):
                relationship_label = f"corr_{provider}_{attribute}"
                self.add_edge(
                    source_id=node_id,
                    target_id=correlation_node_id,
                    relationship_type=relationship_label,
                    confidence_score=0.9,
                    source_provider=provider,
                    raw_data={
                        'correlation_value': value,
                        'original_attribute': attribute,
                        'correlation_type': 'attribute_matching'
                    }
                )
                print(f"Added correlation edge: {node_id} -> {correlation_node_id} ({relationship_label})")

    def _has_direct_edge_bidirectional(self, node_a: str, node_b: str) -> bool:
        """Return True if an edge exists between the two nodes in either direction."""
        return (self.graph.has_edge(node_a, node_b) or
                self.graph.has_edge(node_b, node_a))

    def _correlation_value_matches_existing_node(self, correlation_value: str) -> bool:
        """
        Return True if any existing node id occurs (case-insensitively) as a
        substring of ``correlation_value``.
        """
        correlation_str = str(correlation_value).lower()
        return any(
            str(existing_node_id).lower() in correlation_str
            for existing_node_id in self.graph.nodes()
        )

    def _find_correlation_nodes_with_same_pattern(self, node_set: set) -> List[str]:
        """
        Return the correlation nodes whose neighbours (predecessors and
        successors combined) are exactly ``node_set``.
        """
        matching_nodes = []
        for corr_node_id in self.get_nodes_by_type(NodeType.CORRELATION_OBJECT):
            connected_nodes = set(self.graph.predecessors(corr_node_id))
            connected_nodes.update(self.graph.successors(corr_node_id))
            if connected_nodes == node_set:
                matching_nodes.append(corr_node_id)
        return matching_nodes

    def _merge_correlation_values(self, target_node_id: str, new_value: Any, corr_data: Dict) -> None:
        """
        Merge another correlation value into an existing correlation node.

        Args:
            target_node_id: Correlation node that receives the value.
            new_value: Value to append to the node's ``values`` list.
            corr_data: Correlation entry with ``'nodes'`` (any iterable of node
                ids, including the ``set`` used by ``correlation_index``) and
                ``'sources'`` (list of source dictionaries).

        Sources are de-duplicated on ``(node_id, path)`` in first-seen order
        and keep all their fields. Does nothing if the node does not exist.
        """
        if not self.graph.has_node(target_node_id):
            return

        node_data = self.graph.nodes[target_node_id]
        target_metadata = node_data.setdefault('metadata', {})

        existing_values = target_metadata.get('values', [])
        if not isinstance(existing_values, list):
            existing_values = [existing_values]
        if new_value not in existing_values:
            existing_values.append(new_value)

        merged_sources = []
        seen_sources = set()
        all_sources = list(target_metadata.get('sources', [])) + list(corr_data.get('sources', []))
        for source in all_sources:
            source_key = (source['node_id'], source.get('path', ''))
            if source_key in seen_sources:
                continue
            seen_sources.add(source_key)
            merged = dict(source)  # copy so the index entry is not aliased
            merged.setdefault('path', '')
            merged_sources.append(merged)

        # 'nodes' is a set in correlation_index, so it cannot be concatenated
        # to a list directly; sort it for a reproducible order.
        incoming_nodes = sorted(corr_data.get('nodes', []), key=str)
        correlated_nodes = list(dict.fromkeys(
            list(target_metadata.get('correlated_nodes', [])) + incoming_nodes
        ))

        target_metadata.update({
            'values': existing_values,
            'sources': merged_sources,
            'correlated_nodes': correlated_nodes,
            'merge_count': len(existing_values),
            'last_merge_timestamp': datetime.now(timezone.utc).isoformat()
        })

        node_data['description'] = (
            f"Correlation container with {len(existing_values)} merged values "
            f"across {len(correlated_nodes)} nodes"
        )
        self.last_modified = datetime.now(timezone.utc).isoformat()

    def add_node(self, node_id: str, node_type: "NodeType", attributes: Optional[List[Dict[str, Any]]] = None,
                 description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add a node, or merge new data into an existing one.

        For an existing node, attributes whose ``name`` is not yet present are
        appended, a non-empty ``description`` replaces the old one and
        ``metadata`` is merged key by key. ``node_type`` is only applied when
        the node is created. Correlation processing is NOT triggered here; call
        ``process_correlations_for_node`` separately.

        Returns:
            True if the node was newly created, False if it already existed.
        """
        is_new_node = not self.graph.has_node(node_id)
        if is_new_node:
            self.graph.add_node(node_id, type=node_type.value,
                                added_timestamp=datetime.now(timezone.utc).isoformat(),
                                attributes=attributes or [],
                                description=description,
                                metadata=metadata or {})
        else:
            node_data = self.graph.nodes[node_id]
            if attributes:
                existing_attributes = node_data.get('attributes', [])
                # Older data may still be stored in dictionary format.
                if not isinstance(existing_attributes, list):
                    existing_attributes = []

                # .get() so that an attribute without a name cannot raise.
                existing_attr_names = {attr.get('name') for attr in existing_attributes}
                for new_attr in attributes:
                    new_name = new_attr.get('name')
                    if new_name not in existing_attr_names:
                        existing_attributes.append(new_attr)
                        existing_attr_names.add(new_name)
                node_data['attributes'] = existing_attributes

            if description:
                node_data['description'] = description
            if metadata:
                existing_metadata = node_data.get('metadata', {})
                existing_metadata.update(metadata)
                node_data['metadata'] = existing_metadata

        self.last_modified = datetime.now(timezone.utc).isoformat()
        return is_new_node

    def add_edge(self, source_id: str, target_id: str, relationship_type: str,
                 confidence_score: float = 0.5, source_provider: str = "unknown",
                 raw_data: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add an edge between two existing nodes, or raise an existing edge's
        confidence.

        The stored label is ``relationship_type`` itself when it starts with
        ``c_`` or ``corr_`` (correlation labels already embed the provider),
        otherwise ``<source_provider>_<relationship_type>``.

        Returns:
            True if a new edge was created; False if either node is missing or
            the edge already existed (its confidence is raised if the new
            score is higher).
        """
        if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
            return False

        if relationship_type.startswith(self._VERBATIM_LABEL_PREFIXES):
            edge_label = relationship_type
        else:
            edge_label = f"{source_provider}_{relationship_type}"

        now = datetime.now(timezone.utc).isoformat()

        if self.graph.has_edge(source_id, target_id):
            edge_data = self.graph.edges[source_id, target_id]
            if confidence_score > edge_data.get('confidence_score', 0):
                edge_data['confidence_score'] = confidence_score
                edge_data['updated_timestamp'] = now
                edge_data['updated_by'] = source_provider
                self.last_modified = now
            return False

        self.graph.add_edge(source_id, target_id,
                            relationship_type=edge_label,
                            confidence_score=confidence_score,
                            source_provider=source_provider,
                            discovery_timestamp=now,
                            raw_data=raw_data or {})
        self.last_modified = now
        return True

    def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
        """
        Remove ``node_id_to_extract`` from a large entity's member list and
        update its count.

        Both attribute layouts are understood: the list of attribute
        dictionaries (members in the attribute named ``nodes``, size in the
        one named ``count``) and the older plain dictionary with ``nodes`` and
        ``count`` keys.

        Returns:
            False if the large entity does not exist, True otherwise. A member
            that is not listed is treated as already extracted.
        """
        if not self.graph.has_node(large_entity_id):
            return False

        attributes = self.graph.nodes[large_entity_id].get('attributes', [])

        count_attr = None
        if isinstance(attributes, dict):
            members = attributes.get('nodes')
        else:
            named = {a.get('name'): a for a in attributes if isinstance(a, dict)}
            nodes_attr = named.get('nodes')
            members = nodes_attr.get('value') if nodes_attr else None
            count_attr = named.get('count')

        if members is None or node_id_to_extract not in members:
            # This can happen if the node was already extracted, which is not an error.
            print(f"Warning: Node {node_id_to_extract} not found in the 'nodes' list of {large_entity_id}.")
            return True

        members.remove(node_id_to_extract)
        if isinstance(attributes, dict):
            attributes['count'] = len(members)
        elif count_attr is not None:
            count_attr['value'] = len(members)

        self.last_modified = datetime.now(timezone.utc).isoformat()
        return True

    def remove_node(self, node_id: str) -> bool:
        """
        Remove a node, its edges and its entries in the correlation index.

        Returns:
            False if the node does not exist, True once it has been removed.
        """
        if not self.graph.has_node(node_id):
            return False

        # NetworkX drops the connected edges together with the node.
        self.graph.remove_node(node_id)

        emptied_values = []
        for value, data in self.correlation_index.items():
            if isinstance(data, dict) and 'nodes' in data:
                if node_id in data['nodes']:
                    data['nodes'].discard(node_id)
                    data['sources'] = [s for s in data['sources'] if s['node_id'] != node_id]
                    if not data['nodes']:
                        emptied_values.append(value)
            elif isinstance(data, set) and node_id in data:
                # Legacy index layout: the entry is the bare set of node ids.
                data.discard(node_id)
                if not data:
                    emptied_values.append(value)

        # Deleted after the loop so the dict is not resized while iterating.
        for value in emptied_values:
            self.correlation_index.pop(value, None)

        self.last_modified = datetime.now(timezone.utc).isoformat()
        return True

    def get_node_count(self) -> int:
        """Return the total number of nodes in the graph."""
        return self.graph.number_of_nodes()

    def get_edge_count(self) -> int:
        """Return the total number of edges in the graph."""
        return self.graph.number_of_edges()

    def get_nodes_by_type(self, node_type: "NodeType") -> List[str]:
        """Return the ids of all nodes whose ``type`` equals ``node_type.value``."""
        return [n for n, d in self.graph.nodes(data=True) if d.get('type') == node_type.value]

    def get_neighbors(self, node_id: str) -> List[str]:
        """Return the unique predecessors and successors of a node ([] if unknown)."""
        if not self.graph.has_node(node_id):
            return []
        return list(set(self.graph.predecessors(node_id)) | set(self.graph.successors(node_id)))

    def get_high_confidence_edges(self, min_confidence: float = 0.8) -> List[Tuple[str, str, Dict]]:
        """Return ``(source, target, data)`` for edges with confidence >= ``min_confidence``."""
        return [(u, v, d) for u, v, d in self.graph.edges(data=True)
                if d.get('confidence_score', 0) >= min_confidence]

    @staticmethod
    def _certificate_color(attributes_list) -> Optional[Dict[str, str]]:
        """
        Pick the colour override for a domain node from its certificate
        attributes.

        Returns red when expired/invalid certificates are reported and no
        valid one is, grey when no certificate attribute exists at all, and
        None otherwise (the frontend then applies its default styling).
        """
        has_certificates = False
        has_valid_certificates = False
        has_expired_certificates = False

        for attr in attributes_list:
            if not isinstance(attr, dict):
                continue
            attr_name = str(attr.get('name') or '').lower()
            attr_provider = str(attr.get('provider') or '').lower()
            attr_value = attr.get('value')

            # Certificate data comes from the crt.sh provider or is named *cert*.
            if attr_provider == 'crtsh' or 'cert' in attr_name:
                has_certificates = True
                if attr_name == 'cert_is_currently_valid':
                    if attr_value is True:
                        has_valid_certificates = True
                    elif attr_value is False:
                        has_expired_certificates = True
                elif 'expires_soon' in attr_name and attr_value is True:
                    has_expired_certificates = True
                elif 'expired' in attr_name and attr_value is True:
                    has_expired_certificates = True

        if has_expired_certificates and not has_valid_certificates:
            return {'background': '#ff6b6b', 'border': '#cc5555'}
        if not has_certificates:
            return {'background': '#c7c7c7', 'border': '#999999'}
        return None

    def get_graph_data(self) -> Dict[str, Any]:
        """
        Export nodes, edges and basic statistics for frontend visualization.

        Domain nodes get a ``color`` entry based on their certificate
        attributes; every node additionally lists its incoming and outgoing
        edges.
        """
        nodes = []
        for node_id, attrs in self.graph.nodes(data=True):
            node_data = {'id': node_id, 'label': node_id, 'type': attrs.get('type', 'unknown'),
                         'attributes': attrs.get('attributes', []),
                         'description': attrs.get('description', ''),
                         'metadata': attrs.get('metadata', {}),
                         'added_timestamp': attrs.get('added_timestamp')}

            attributes_list = node_data['attributes']
            if node_data['type'] == 'domain' and isinstance(attributes_list, list):
                color = self._certificate_color(attributes_list)
                if color is not None:
                    node_data['color'] = color

            node_data['incoming_edges'] = [{'from': u, 'data': d} for u, _, d in self.graph.in_edges(node_id, data=True)]
            node_data['outgoing_edges'] = [{'to': v, 'data': d} for _, v, d in self.graph.out_edges(node_id, data=True)]
            nodes.append(node_data)

        edges = [
            {'from': source, 'to': target,
             'label': attrs.get('relationship_type', ''),
             'confidence_score': attrs.get('confidence_score', 0),
             'source_provider': attrs.get('source_provider', ''),
             'discovery_timestamp': attrs.get('discovery_timestamp')}
            for source, target, attrs in self.graph.edges(data=True)
        ]

        return {
            'nodes': nodes, 'edges': edges,
            'statistics': self.get_statistics()['basic_metrics']
        }

    def export_json(self) -> Dict[str, Any]:
        """Export the graph (NetworkX node-link format), metadata and statistics as a dictionary."""
        graph_data = nx.node_link_data(self.graph)
        return {
            'export_metadata': {
                'export_timestamp': datetime.now(timezone.utc).isoformat(),
                'graph_creation_time': self.creation_time,
                'last_modified': self.last_modified,
                'total_nodes': self.get_node_count(),
                'total_edges': self.get_edge_count(),
                'graph_format': 'dnsrecon_v1_unified_model'
            },
            'graph': graph_data,
            'statistics': self.get_statistics()
        }

    def _get_confidence_distribution(self) -> Dict[str, int]:
        """Count edges per confidence band: high (>= 0.8), medium (>= 0.6), low (rest)."""
        distribution = {'high': 0, 'medium': 0, 'low': 0}
        for _, _, data in self.graph.edges(data=True):
            confidence = data.get('confidence_score', 0)
            if confidence >= 0.8:
                distribution['high'] += 1
            elif confidence >= 0.6:
                distribution['medium'] += 1
            else:
                distribution['low'] += 1
        return distribution

    def get_statistics(self) -> Dict[str, Any]:
        """
        Return basic metrics plus node-type, relationship-type, confidence and
        provider distributions for the graph.
        """
        stats = {'basic_metrics': {'total_nodes': self.get_node_count(),
                                   'total_edges': self.get_edge_count(),
                                   'creation_time': self.creation_time,
                                   'last_modified': self.last_modified},
                 'node_type_distribution': {}, 'relationship_type_distribution': {},
                 'confidence_distribution': self._get_confidence_distribution(),
                 'provider_distribution': {}}

        for node_type in NodeType:
            stats['node_type_distribution'][node_type.value] = len(self.get_nodes_by_type(node_type))

        relationship_counts = stats['relationship_type_distribution']
        provider_counts = stats['provider_distribution']
        for _, _, data in self.graph.edges(data=True):
            rel_type = data.get('relationship_type', 'unknown')
            relationship_counts[rel_type] = relationship_counts.get(rel_type, 0) + 1
            provider = data.get('source_provider', 'unknown')
            provider_counts[provider] = provider_counts.get(provider, 0) + 1
        return stats

    def clear(self) -> None:
        """Remove all nodes, edges and correlation entries and reset the timestamps."""
        self.graph.clear()
        self.correlation_index.clear()
        self.creation_time = datetime.now(timezone.utc).isoformat()
        self.last_modified = self.creation_time