# DNScope-reduced/core/graph_manager.py

"""
Graph data model for DNScope using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
Now fully compatible with the unified ProviderResult data model.
UPDATED: Fixed correlation exclusion keys to match actual attribute names.
UPDATED: Removed export_json() method - now handled by ExportManager.
"""
import re
from collections import Counter
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple

import networkx as nx

# Confidence thresholds shared by the bucketing in
# ``GraphManager._get_confidence_distribution`` and the default of
# ``GraphManager.get_high_confidence_edges``.
_HIGH_CONFIDENCE_THRESHOLD = 0.8
_MEDIUM_CONFIDENCE_THRESHOLD = 0.6


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 formatted string."""
    return datetime.now(timezone.utc).isoformat()


class NodeType(Enum):
    """Enumeration of supported node types."""

    DOMAIN = "domain"
    IP = "ip"
    ISP = "isp"
    CA = "ca"
    LARGE_ENTITY = "large_entity"
    CORRELATION_OBJECT = "correlation_object"

    def __repr__(self) -> str:
        # Represent a member by its raw string value (e.g. ``domain``).
        return self.value


class GraphManager:
    """
    Graph manager for DNScope infrastructure mapping.

    Wraps a NetworkX ``DiGraph`` held in memory. Nodes carry a type, a list
    of attribute dictionaries, a description and a metadata dictionary;
    edges carry a relationship label, a confidence score and provenance
    fields. Compatible with the unified ProviderResult data model.

    NOTE: this class performs no locking of its own. Callers that share one
    instance between threads must synchronise access themselves.
    """

    def __init__(self):
        """Initialize an empty directed graph and its timestamps."""
        self.graph = nx.DiGraph()
        self.creation_time = _utc_now_iso()
        self.last_modified = self.creation_time

    def _touch(self) -> None:
        """Record that the graph was modified just now."""
        self.last_modified = _utc_now_iso()

    def _basic_metrics(self) -> Dict[str, Any]:
        """Return node/edge totals together with the graph timestamps."""
        return {
            'total_nodes': self.get_node_count(),
            'total_edges': self.get_edge_count(),
            'creation_time': self.creation_time,
            'last_modified': self.last_modified,
        }

    def add_node(self, node_id: str, node_type: NodeType,
                 attributes: Optional[List[Dict[str, Any]]] = None,
                 description: str = "",
                 metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add a node to the graph, or merge new data into an existing node.

        For an existing node, attributes whose ``'name'`` is not already
        present are appended, a non-empty ``description`` replaces the old
        one, and ``metadata`` is merged key by key. The stored node type is
        not changed for an existing node.

        Args:
            node_id: Identifier of the node.
            node_type: Type recorded when the node is first created.
            attributes: Attribute dictionaries; each one is expected to have
                a ``'name'`` key (a missing key raises ``KeyError`` while
                merging into an existing node).
            description: Free-text description of the node.
            metadata: Extra metadata to store or merge.

        Returns:
            True if the node was newly created, False if it already existed.
        """
        is_new_node = not self.graph.has_node(node_id)

        if is_new_node:
            self.graph.add_node(
                node_id,
                type=node_type.value,
                added_timestamp=_utc_now_iso(),
                # Shallow copies: later merges must not mutate the caller's
                # own list/dict objects.
                attributes=list(attributes) if attributes else [],
                description=description,
                metadata=dict(metadata) if metadata else {},
            )
        else:
            node = self.graph.nodes[node_id]

            if attributes:
                existing_attributes = node.get('attributes', [])
                # Old data might still be stored in a non-list format;
                # in that case start over with an empty list.
                if not isinstance(existing_attributes, list):
                    existing_attributes = []

                # Names already present, for O(1) duplicate checks.
                known_names = {attr['name'] for attr in existing_attributes}
                for new_attr in attributes:
                    if new_attr['name'] not in known_names:
                        existing_attributes.append(new_attr)
                        known_names.add(new_attr['name'])

                node['attributes'] = existing_attributes

            if description:
                node['description'] = description

            if metadata:
                existing_metadata = node.get('metadata', {})
                existing_metadata.update(metadata)
                node['metadata'] = existing_metadata

        self._touch()
        return is_new_node

    def add_edge(self, source_id: str, target_id: str, relationship_type: str,
                 confidence_score: float = 0.5,
                 source_provider: str = "unknown",
                 raw_data: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add an edge between two existing nodes, or raise its confidence.

        The relationship label is stored exactly as given (no formatting).
        If the edge already exists, only its confidence score is updated,
        and only when the new score is strictly higher; the relationship
        type and raw data of the existing edge are left untouched.

        Args:
            source_id: Identifier of the source node.
            target_id: Identifier of the target node.
            relationship_type: Raw relationship label.
            confidence_score: Confidence assigned to the relationship.
            source_provider: Name of the provider reporting the edge.
            raw_data: Raw provider data stored on a newly created edge.

        Returns:
            True if a new edge was created; False if either node is missing
            or the edge already existed (whether or not it was updated).
        """
        if not (self.graph.has_node(source_id) and self.graph.has_node(target_id)):
            return False

        if self.graph.has_edge(source_id, target_id):
            edge = self.graph.edges[source_id, target_id]
            if confidence_score > edge.get('confidence_score', 0):
                now = _utc_now_iso()
                edge['confidence_score'] = confidence_score
                edge['updated_timestamp'] = now
                edge['updated_by'] = source_provider
                # The graph changed, so keep last_modified in step with the
                # other mutating methods.
                self.last_modified = now
            return False

        now = _utc_now_iso()
        self.graph.add_edge(
            source_id,
            target_id,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            source_provider=source_provider,
            discovery_timestamp=now,
            raw_data=raw_data or {},
        )
        self.last_modified = now
        return True

    def remove_node(self, node_id: str) -> bool:
        """
        Remove a node and its connected edges from the graph.

        Returns:
            True if the node existed and was removed, False otherwise.
        """
        if not self.graph.has_node(node_id):
            return False

        # NetworkX removes the edges attached to the node as well.
        self.graph.remove_node(node_id)
        self._touch()
        return True

    def get_node_count(self) -> int:
        """Get total number of nodes in the graph."""
        return self.graph.number_of_nodes()

    def get_edge_count(self) -> int:
        """Get total number of edges in the graph."""
        return self.graph.number_of_edges()

    def get_nodes_by_type(self, node_type: NodeType) -> List[str]:
        """Get the identifiers of all nodes of a specific type."""
        wanted = node_type.value
        return [
            node_id
            for node_id, data in self.graph.nodes(data=True)
            if data.get('type') == wanted
        ]

    def get_high_confidence_edges(
            self, min_confidence: float = _HIGH_CONFIDENCE_THRESHOLD
    ) -> List[Tuple[str, str, Dict]]:
        """
        Get edges whose confidence score is at least ``min_confidence``.

        Edges without a stored score are treated as having confidence 0.

        Returns:
            A list of ``(source, target, edge_data)`` tuples.
        """
        return [
            (source, target, data)
            for source, target, data in self.graph.edges(data=True)
            if data.get('confidence_score', 0) >= min_confidence
        ]

    def get_graph_data(self) -> Dict[str, Any]:
        """
        Export graph data formatted for frontend visualization.

        No visual styling is applied here; nodes and edges are exported with
        their raw stored values.

        Returns:
            A dictionary with ``'nodes'``, ``'edges'`` and ``'statistics'``
            (the basic metrics only).
        """
        nodes = []
        for node_id, attrs in self.graph.nodes(data=True):
            metadata = attrs.get('metadata', {})
            nodes.append({
                'id': node_id,
                'label': node_id,
                'type': attrs.get('type', 'unknown'),
                'attributes': attrs.get('attributes', []),  # Raw attributes list
                'description': attrs.get('description', ''),
                'metadata': metadata,
                'added_timestamp': attrs.get('added_timestamp'),
                'max_depth_reached': metadata.get('max_depth_reached', False),
                # Edge data dictionaries are the live objects stored in the
                # graph, not copies.
                'incoming_edges': [
                    {'from': source, 'data': data}
                    for source, _, data in self.graph.in_edges(node_id, data=True)
                ],
                'outgoing_edges': [
                    {'to': target, 'data': data}
                    for _, target, data in self.graph.out_edges(node_id, data=True)
                ],
            })

        edges = [
            {
                'from': source,
                'to': target,
                'label': attrs.get('relationship_type', ''),
                'confidence_score': attrs.get('confidence_score', 0),
                'source_provider': attrs.get('source_provider', ''),
                'discovery_timestamp': attrs.get('discovery_timestamp'),
            }
            for source, target, attrs in self.graph.edges(data=True)
        ]

        return {
            'nodes': nodes,
            'edges': edges,
            'statistics': self._basic_metrics(),
        }

    def _get_confidence_distribution(self) -> Dict[str, int]:
        """
        Count edges per confidence bucket.

        Buckets: ``'high'`` (>= 0.8), ``'medium'`` (>= 0.6) and ``'low'``
        (everything else, including edges without a score). An empty graph
        yields all zeros.
        """
        distribution = {'high': 0, 'medium': 0, 'low': 0}

        for _, _, data in self.graph.edges(data=True):
            confidence = data.get('confidence_score', 0)
            if confidence >= _HIGH_CONFIDENCE_THRESHOLD:
                distribution['high'] += 1
            elif confidence >= _MEDIUM_CONFIDENCE_THRESHOLD:
                distribution['medium'] += 1
            else:
                distribution['low'] += 1

        return distribution

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics about the graph.

        Returns:
            A dictionary with ``'basic_metrics'``,
            ``'node_type_distribution'``, ``'relationship_type_distribution'``,
            ``'confidence_distribution'`` and ``'provider_distribution'``.
            The distributions are empty dictionaries (or all-zero buckets for
            confidence) when the graph has no matching data. Only node types
            that are members of ``NodeType`` and occur at least once appear
            in ``'node_type_distribution'``.
        """
        # One pass over the nodes instead of one scan per NodeType member.
        node_type_counts = Counter(
            data.get('type') for _, data in self.graph.nodes(data=True)
        )

        relationship_counts: Counter = Counter()
        provider_counts: Counter = Counter()
        for _, _, data in self.graph.edges(data=True):
            relationship_counts[data.get('relationship_type', 'unknown')] += 1
            provider_counts[data.get('source_provider', 'unknown')] += 1

        return {
            'basic_metrics': self._basic_metrics(),
            # Keep NodeType declaration order and skip types with no nodes.
            'node_type_distribution': {
                node_type.value: node_type_counts[node_type.value]
                for node_type in NodeType
                if node_type_counts[node_type.value] > 0
            },
            'relationship_type_distribution': dict(relationship_counts),
            'confidence_distribution': self._get_confidence_distribution(),
            'provider_distribution': dict(provider_counts),
        }

    def clear(self) -> None:
        """Clear all nodes and edges from the graph and reset both timestamps."""
        self.graph.clear()
        self.creation_time = _utc_now_iso()
        self.last_modified = self.creation_time