dnscope/core/graph_manager.py
2025-09-20 16:52:05 +02:00

303 lines
12 KiB
Python

# dnsrecon-reduced/core/graph_manager.py
"""
Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
Now fully compatible with the unified ProviderResult data model.
FIXED: Added proper pickle support to prevent weakref serialization errors.
"""
import re
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple
import networkx as nx
class NodeType(Enum):
    """Closed set of node categories known to the graph model.

    Each member's value is the plain string form of the category, available
    through ``.value`` and also returned by ``repr()``.
    """

    DOMAIN = "domain"
    IP = "ip"
    ISP = "isp"
    CA = "ca"
    LARGE_ENTITY = "large_entity"
    CORRELATION_OBJECT = "correlation_object"

    def __repr__(self):
        # Show the bare string value instead of Enum's default
        # "<NodeType.DOMAIN: 'domain'>" representation.
        return self._value_
class GraphManager:
    """
    In-memory graph manager for DNSRecon infrastructure mapping.

    Wraps a NetworkX ``DiGraph`` whose nodes carry a type, a list of attribute
    dictionaries, a description and a metadata dictionary, and whose edges
    carry a relationship label, a confidence score and provenance fields.

    Pickling is customised (``__getstate__`` / ``__setstate__``): the graph is
    flattened to plain dicts/lists on the way out and rebuilt on the way in.

    NOTE: this class performs no locking of its own. Callers that share one
    instance between threads must provide their own synchronisation.
    """

    def __init__(self):
        """Initialize an empty directed graph and its bookkeeping timestamps."""
        self.graph = nx.DiGraph()
        self.creation_time = self._utc_now_iso()
        self.last_modified = self.creation_time

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def __getstate__(self):
        """
        Build the pickle state with the NetworkX graph flattened to plain data.

        Returns:
            A copy of the instance ``__dict__`` in which ``graph`` is replaced
            by ``_graph_nodes`` (node id -> attribute dict) and
            ``_graph_edges`` (list of source/target/attributes dicts).
        """
        state = self.__dict__.copy()
        graph = state.pop('graph', None)
        # Test against None rather than truthiness: an empty NetworkX graph is
        # falsy (len() == 0) and would otherwise be pickled as a raw graph.
        if graph is not None:
            state['_graph_nodes'] = {
                node_id: dict(attrs)
                for node_id, attrs in graph.nodes(data=True)
            }
            state['_graph_edges'] = [
                {'source': source, 'target': target, 'attributes': dict(attrs)}
                for source, target, attrs in graph.edges(data=True)
            ]
        return state

    def __setstate__(self, state):
        """
        Restore the instance from pickle state and rebuild the NetworkX graph.

        Accepts state produced by ``__getstate__``; a state that still carries
        a raw ``graph`` entry is tolerated and that entry is discarded in
        favour of a freshly built graph.
        """
        state = dict(state)  # do not mutate the caller's mapping
        nodes_data = state.pop('_graph_nodes', None) or {}
        edges_data = state.pop('_graph_edges', None) or []
        state.pop('graph', None)

        # Restore the plain attributes, then reconstruct the graph.
        self.__dict__.update(state)
        self.graph = nx.DiGraph()
        for node_id, attrs in nodes_data.items():
            self.graph.add_node(node_id, **attrs)
        for edge_data in edges_data:
            self.graph.add_edge(
                edge_data['source'],
                edge_data['target'],
                **edge_data['attributes']
            )

    def add_node(self, node_id: str, node_type: NodeType,
                 attributes: Optional[List[Dict[str, Any]]] = None,
                 description: str = "",
                 metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add a node, or merge new information into an existing node.

        Args:
            node_id: Identifier of the node.
            node_type: Node category; only used when the node is created.
            attributes: Attribute dictionaries. When merging into an existing
                node each one must have a ``'name'`` key; an attribute whose
                name is already present on the node is skipped.
            description: Replaces the stored description when non-empty.
            metadata: Keys are merged into the stored metadata dictionary.

        Returns:
            True if the node was newly created, False if it already existed.

        Raises:
            KeyError: if an attribute involved in a merge has no ``'name'``.
        """
        is_new_node = not self.graph.has_node(node_id)
        if is_new_node:
            self.graph.add_node(
                node_id,
                type=node_type.value,
                added_timestamp=self._utc_now_iso(),
                # Store copies: later merges append to / update these
                # containers and must not mutate the caller's objects.
                attributes=list(attributes) if attributes else [],
                description=description,
                metadata=dict(metadata) if metadata else {},
            )
        else:
            node_data = self.graph.nodes[node_id]
            if attributes:
                existing_attributes = node_data.get('attributes', [])
                # Older data may have stored attributes in a non-list format;
                # such data is dropped and replaced by a fresh list.
                if not isinstance(existing_attributes, list):
                    existing_attributes = []
                existing_attr_names = {attr['name'] for attr in existing_attributes}
                for new_attr in attributes:
                    if new_attr['name'] not in existing_attr_names:
                        existing_attributes.append(new_attr)
                        existing_attr_names.add(new_attr['name'])
                node_data['attributes'] = existing_attributes
            if description:
                node_data['description'] = description
            if metadata:
                existing_metadata = node_data.get('metadata', {})
                existing_metadata.update(metadata)
                node_data['metadata'] = existing_metadata

        self.last_modified = self._utc_now_iso()
        return is_new_node

    def add_edge(self, source_id: str, target_id: str, relationship_type: str,
                 confidence_score: float = 0.5, source_provider: str = "unknown",
                 raw_data: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add an edge between two existing nodes, or raise an existing edge's confidence.

        The relationship label is stored exactly as given. If an edge between
        the two nodes already exists, only its confidence score (plus
        ``updated_timestamp`` / ``updated_by``) is changed, and only when the
        new score is strictly higher; its relationship label is left untouched.

        Returns:
            True if a new edge was created; False if either endpoint is
            missing or the edge already existed.
        """
        if not (self.graph.has_node(source_id) and self.graph.has_node(target_id)):
            return False

        if self.graph.has_edge(source_id, target_id):
            edge_data = self.graph.edges[source_id, target_id]
            if confidence_score > edge_data.get('confidence_score', 0):
                now = self._utc_now_iso()
                edge_data['confidence_score'] = confidence_score
                edge_data['updated_timestamp'] = now
                edge_data['updated_by'] = source_provider
                # The graph content changed, so the modification time must too.
                self.last_modified = now
            return False

        now = self._utc_now_iso()
        self.graph.add_edge(
            source_id, target_id,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            source_provider=source_provider,
            discovery_timestamp=now,
            raw_data=raw_data or {},
        )
        self.last_modified = now
        return True

    def remove_node(self, node_id: str) -> bool:
        """
        Remove a node and its connected edges from the graph.

        Returns:
            True if the node existed and was removed, False otherwise.
        """
        if not self.graph.has_node(node_id):
            return False
        # NetworkX removes the node's incident edges along with the node.
        self.graph.remove_node(node_id)
        self.last_modified = self._utc_now_iso()
        return True

    def get_node_count(self) -> int:
        """Get total number of nodes in the graph."""
        return self.graph.number_of_nodes()

    def get_edge_count(self) -> int:
        """Get total number of edges in the graph."""
        return self.graph.number_of_edges()

    def get_nodes_by_type(self, node_type: NodeType) -> List[str]:
        """Get the ids of all nodes whose stored type equals ``node_type.value``."""
        return [n for n, d in self.graph.nodes(data=True) if d.get('type') == node_type.value]

    def get_high_confidence_edges(self, min_confidence: float = 0.8) -> List[Tuple[str, str, Dict]]:
        """
        Get edges whose confidence score is at or above a threshold.

        Edges without a stored score are treated as having a score of 0.

        Returns:
            A list of ``(source, target, edge_attributes)`` tuples.
        """
        return [(u, v, d) for u, v, d in self.graph.edges(data=True)
                if d.get('confidence_score', 0) >= min_confidence]

    def get_graph_data(self) -> Dict[str, Any]:
        """
        Export graph data formatted for frontend visualization.

        No visual styling is applied here; nodes and edges are returned with
        their raw stored values. The attribute lists, metadata dictionaries
        and per-edge ``data`` dictionaries in the result are the graph's own
        objects, not copies.

        Returns:
            A dict with ``'nodes'``, ``'edges'`` and ``'statistics'`` (the
            basic metrics only).
        """
        nodes = []
        for node_id, attrs in self.graph.nodes(data=True):
            metadata = attrs.get('metadata', {})
            nodes.append({
                'id': node_id,
                'label': node_id,
                'type': attrs.get('type', 'unknown'),
                'attributes': attrs.get('attributes', []),  # Raw attributes list
                'description': attrs.get('description', ''),
                'metadata': metadata,
                'added_timestamp': attrs.get('added_timestamp'),
                'max_depth_reached': metadata.get('max_depth_reached', False),
                'incoming_edges': [
                    {'from': u, 'data': d}
                    for u, _, d in self.graph.in_edges(node_id, data=True)
                ],
                'outgoing_edges': [
                    {'to': v, 'data': d}
                    for _, v, d in self.graph.out_edges(node_id, data=True)
                ],
            })

        edges = [
            {
                'from': source,
                'to': target,
                'label': attrs.get('relationship_type', ''),
                'confidence_score': attrs.get('confidence_score', 0),
                'source_provider': attrs.get('source_provider', ''),
                'discovery_timestamp': attrs.get('discovery_timestamp'),
            }
            for source, target, attrs in self.graph.edges(data=True)
        ]

        return {
            'nodes': nodes,
            'edges': edges,
            # Only the basic metrics are exported, so avoid computing the
            # full distributions that get_statistics() would build.
            'statistics': self._basic_metrics(),
        }

    def _basic_metrics(self) -> Dict[str, Any]:
        """Return node/edge totals together with the creation and modification times."""
        return {
            'total_nodes': self.get_node_count(),
            'total_edges': self.get_edge_count(),
            'creation_time': self.creation_time,
            'last_modified': self.last_modified,
        }

    def _get_confidence_distribution(self) -> Dict[str, int]:
        """
        Count edges per confidence band.

        Bands: ``high`` for scores >= 0.8, ``medium`` for scores >= 0.6,
        ``low`` otherwise. Edges without a score count as 0. An empty graph
        yields all-zero counts.
        """
        distribution = {'high': 0, 'medium': 0, 'low': 0}
        for _, _, data in self.graph.edges(data=True):
            confidence = data.get('confidence_score', 0)
            if confidence >= 0.8:
                distribution['high'] += 1
            elif confidence >= 0.6:
                distribution['medium'] += 1
            else:
                distribution['low'] += 1
        return distribution

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the graph.

        Returns:
            A dict with ``basic_metrics``, ``node_type_distribution`` (only
            ``NodeType`` values that occur at least once),
            ``relationship_type_distribution``, ``confidence_distribution``
            and ``provider_distribution``. For an empty graph the
            distributions are empty (confidence counts are all zero).
        """
        node_type_distribution: Dict[str, int] = {}
        for node_type in NodeType:
            count = len(self.get_nodes_by_type(node_type))
            if count > 0:  # Only include types that exist
                node_type_distribution[node_type.value] = count

        relationship_type_distribution: Dict[str, int] = {}
        provider_distribution: Dict[str, int] = {}
        for _, _, data in self.graph.edges(data=True):
            rel_type = data.get('relationship_type', 'unknown')
            relationship_type_distribution[rel_type] = (
                relationship_type_distribution.get(rel_type, 0) + 1
            )
            provider = data.get('source_provider', 'unknown')
            provider_distribution[provider] = provider_distribution.get(provider, 0) + 1

        return {
            'basic_metrics': self._basic_metrics(),
            'node_type_distribution': node_type_distribution,
            'relationship_type_distribution': relationship_type_distribution,
            'confidence_distribution': self._get_confidence_distribution(),
            'provider_distribution': provider_distribution,
        }

    def clear(self) -> None:
        """Clear all nodes and edges and reset the creation/modification timestamps."""
        self.graph.clear()
        self.creation_time = self._utc_now_iso()
        self.last_modified = self.creation_time