"""
Graph data model for DNSRecon using NetworkX.
Manages in-memory graph storage with confidence scoring and forensic metadata.
"""
import re
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple

import networkx as nx


# Strings starting with a "YYYY-MM-DD HH:MM:SS" (space or "T" separator) timestamp
# are never used for correlation.
_DATE_REGEX = r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}'

# Metadata paths containing any of these substrings hold noisy, non-unique values.
# NOTE: this is a plain substring test, so e.g. "last_update" also matches "date".
_NOISY_PATH_KEYWORDS = ('count', 'total', 'timestamp', 'date')

# Lower-cased string values treated as low-entropy and never correlated.
_LOW_ENTROPY_VALUES = frozenset({'true', 'false', 'unknown', 'none', 'crt.sh'})

# Integers whose absolute value is below this threshold are never correlated.
_MIN_CORRELATABLE_INT = 9999


class NodeType(Enum):
    """Enumeration of supported node types."""
    DOMAIN = "domain"
    IP = "ip"
    ASN = "asn"
    LARGE_ENTITY = "large_entity"
    CORRELATION_OBJECT = "correlation_object"

    def __repr__(self):
        return self.value


class RelationshipType(Enum):
    """Enumeration of supported relationship types with confidence scores."""
    SAN_CERTIFICATE = ("san", 0.9)
    A_RECORD = ("a_record", 0.8)
    AAAA_RECORD = ("aaaa_record", 0.8)
    CNAME_RECORD = ("cname", 0.8)
    MX_RECORD = ("mx_record", 0.7)
    NS_RECORD = ("ns_record", 0.7)
    PTR_RECORD = ("ptr_record", 0.8)
    SOA_RECORD = ("soa_record", 0.7)
    PASSIVE_DNS = ("passive_dns", 0.6)
    ASN_MEMBERSHIP = ("asn", 0.7)
    CORRELATED_TO = ("correlated_to", 0.9)

    def __init__(self, relationship_name: str, default_confidence: float):
        # Each member's tuple value is unpacked into these two attributes.
        self.relationship_name = relationship_name
        self.default_confidence = default_confidence

    def __repr__(self):
        return self.relationship_name


class GraphManager:
    """
    Graph manager for DNSRecon infrastructure mapping.

    Uses a NetworkX ``DiGraph`` for in-memory storage. Edges carry a confidence
    score, and scalar attribute values shared between nodes are tracked in
    ``correlation_index`` (value -> {node_id: [attribute paths]}) so matching
    nodes can be linked to each other.

    NOTE: no internal locking is performed; callers sharing an instance across
    threads must synchronize access themselves.
    """

    def __init__(self):
        """Initialize an empty directed graph and an empty correlation index."""
        self.graph = nx.DiGraph()
        self.creation_time = datetime.now(timezone.utc).isoformat()
        self.last_modified = self.creation_time
        self.correlation_index = {}
        # Compiled once per instance; used to reject timestamp-like strings.
        self.date_pattern = re.compile(_DATE_REGEX)

    def __getstate__(self):
        """Return picklable state, leaving out the compiled date regex."""
        state = self.__dict__.copy()
        # The regex is cheap to rebuild, so it is kept out of the pickle and
        # recompiled in __setstate__.
        state.pop('date_pattern', None)
        return state

    def __setstate__(self, state):
        """Restore state and recompile the date regex."""
        self.__dict__.update(state)
        self.date_pattern = re.compile(_DATE_REGEX)

    def _is_correlatable(self, value: Any, path_str: str) -> bool:
        """
        Decide whether a leaf value at ``path_str`` may take part in correlation.

        Shared by indexing and lookup, so a value is correlated (or not) the same
        way regardless of which node was added first.

        Returns False for:
        - non-scalar values and booleans;
        - values under a noisy path (see ``_NOISY_PATH_KEYWORDS``);
        - timestamp-like, shorter-than-4-character, or low-entropy strings;
        - small integers.
        """
        # bool is a subclass of int, so it must be rejected explicitly.
        if isinstance(value, bool) or not isinstance(value, (str, int, float)):
            return False
        lowered_path = path_str.lower()
        if any(keyword in lowered_path for keyword in _NOISY_PATH_KEYWORDS):
            return False
        if isinstance(value, str):
            if self.date_pattern.match(value):
                return False
            return len(value) >= 4 and value.lower() not in _LOW_ENTROPY_VALUES
        if isinstance(value, int):
            return abs(value) >= _MIN_CORRELATABLE_INT
        return True  # floats are always eligible

    def _update_correlation_index(self, node_id: str, data: Any, path: Optional[List[str]] = None):
        """
        Recursively walk ``data`` and record every correlatable leaf value.

        Dict keys and list indices (as "[i]") form the dotted path stored for
        each value. Subtrees under a ``'source'`` key are skipped, matching
        ``_check_for_correlations``, so provider names are never correlated.
        """
        if path is None:
            path = []
        if isinstance(data, dict):
            for key, value in data.items():
                if key == 'source':
                    continue
                self._update_correlation_index(node_id, value, path + [str(key)])
        elif isinstance(data, list):
            for i, item in enumerate(data):
                self._update_correlation_index(node_id, item, path + [f"[{i}]"])
        else:
            self._add_to_correlation_index(node_id, data, ".".join(path))

    def _add_to_correlation_index(self, node_id: str, value: Any, path_str: str):
        """Record that ``node_id`` holds ``value`` at ``path_str``, unless it is noise."""
        if not self._is_correlatable(value, path_str):
            return
        paths = self.correlation_index.setdefault(value, {}).setdefault(node_id, [])
        if path_str not in paths:
            paths.append(path_str)

    def _check_for_correlations(self, new_node_id: str, data: Any,
                                path: Optional[List[str]] = None) -> List[Dict]:
        """
        Recursively walk ``data`` and report values already held by other nodes.

        Returns a list of dicts:
        - ``'value'``: the shared value;
        - ``'sources'``: a list of ``{'node_id', 'path'}`` entries, the new
          node's first;
        - ``'nodes'``: the distinct node ids involved (always at least two).

        Only values that pass ``_is_correlatable`` are looked up. This also
        keeps unhashable leaves (e.g. sets) away from the dict lookup.
        """
        if path is None:
            path = []
        if isinstance(data, dict):
            found = []
            for key, value in data.items():
                if key == 'source':  # Avoid correlating on the provider name
                    continue
                found.extend(self._check_for_correlations(new_node_id, value, path + [str(key)]))
            return found
        if isinstance(data, list):
            found = []
            for i, item in enumerate(data):
                found.extend(self._check_for_correlations(new_node_id, item, path + [f"[{i}]"]))
            return found

        path_str = ".".join(path)
        if not self._is_correlatable(data, path_str):
            return []
        existing_nodes_with_paths = self.correlation_index.get(data)
        if not existing_nodes_with_paths:
            return []
        unique_nodes = set(existing_nodes_with_paths) | {new_node_id}
        if len(unique_nodes) < 2:
            return []  # Correlation must involve at least two distinct nodes

        all_sources = [{'node_id': new_node_id, 'path': path_str}]
        all_sources.extend({'node_id': nid, 'path': p_str}
                           for nid, paths in existing_nodes_with_paths.items()
                           for p_str in paths)
        return [{'value': data, 'sources': all_sources, 'nodes': list(unique_nodes)}]

    def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[Dict[str, Any]] = None,
                 description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add a node or merge new data into an existing one, then process correlations.

        Args:
            node_id: Unique node identifier.
            node_type: Node category; stored as its string value.
            attributes: Attributes to store or shallow-merge. For
                non-correlation nodes they are also checked against, and
                added to, the correlation index.
            description: Replaces the existing description when non-empty.
            metadata: Metadata to store or shallow-merge.

        Returns:
            True if the node was newly created, False if it already existed.
        """
        is_new_node = not self.graph.has_node(node_id)
        if is_new_node:
            # Store copies so later merges never mutate the caller's dicts.
            self.graph.add_node(node_id, type=node_type.value,
                                added_timestamp=datetime.now(timezone.utc).isoformat(),
                                attributes=dict(attributes) if attributes else {},
                                description=description,
                                metadata=dict(metadata) if metadata else {})
        else:
            node_data = self.graph.nodes[node_id]
            if attributes:
                existing_attributes = node_data.get('attributes', {})
                existing_attributes.update(attributes)
                node_data['attributes'] = existing_attributes
            if description:
                node_data['description'] = description
            if metadata:
                existing_metadata = node_data.get('metadata', {})
                existing_metadata.update(metadata)
                node_data['metadata'] = existing_metadata

        if attributes and node_type != NodeType.CORRELATION_OBJECT:
            for corr in self._check_for_correlations(node_id, attributes):
                self._apply_correlation(corr)

        # Index after checking, so the node does not correlate with itself.
        self._update_correlation_index(node_id, attributes)
        self.last_modified = datetime.now(timezone.utc).isoformat()
        return is_new_node

    def _apply_correlation(self, corr: Dict[str, Any]) -> None:
        """Link the nodes of one correlation to an existing major node or a correlation node."""
        value = corr['value']
        correlated = set(corr['nodes'])

        major_node_id = self._find_major_node(value)
        if major_node_id is not None:
            # An existing major node is part of the value; link to it directly
            # instead of creating a redundant correlation node.
            for c_node_id in correlated:
                if self.graph.has_node(c_node_id) and c_node_id != major_node_id:
                    self.add_edge(c_node_id, major_node_id, RelationshipType.CORRELATED_TO)
            return

        correlation_node_id = f"{value}"
        if not self.graph.has_node(correlation_node_id):
            self.add_node(correlation_node_id, NodeType.CORRELATION_OBJECT,
                          metadata={'value': value, 'sources': corr['sources'],
                                    'correlated_nodes': list(correlated)})
        else:
            self._merge_correlation_metadata(correlation_node_id, corr)

        for c_node_id in correlated:
            self.add_edge(c_node_id, correlation_node_id, RelationshipType.CORRELATED_TO)

    def _find_major_node(self, value: Any) -> Optional[str]:
        """
        Return the first non-correlation node whose id is a substring of ``value``.

        Returns None if ``value`` is not a string or no node matches.
        Correlation-object nodes are excluded: their ids are previously seen
        values (the node for ``value`` itself would always match), and
        matching them would bypass the metadata merge for existing correlations.
        """
        if not isinstance(value, str):
            return None
        for existing_node, node_type in self.graph.nodes(data='type'):
            if node_type == NodeType.CORRELATION_OBJECT.value:
                continue
            if isinstance(existing_node, str) and existing_node in value:
                return existing_node
        return None

    def _merge_correlation_metadata(self, correlation_node_id: str, corr: Dict[str, Any]) -> None:
        """Union ``corr``'s nodes and sources into an existing correlation node's metadata."""
        existing_meta = self.graph.nodes[correlation_node_id]['metadata']
        existing_nodes = set(existing_meta.get('correlated_nodes', []))
        existing_meta['correlated_nodes'] = list(existing_nodes | set(corr['nodes']))

        # Deduplicate sources on (node_id, path).
        existing_sources = {(s['node_id'], s['path']) for s in existing_meta.get('sources', [])}
        existing_sources.update((s['node_id'], s['path']) for s in corr['sources'])
        existing_meta['sources'] = [{'node_id': nid, 'path': p} for nid, p in existing_sources]

    def add_edge(self, source_id: str, target_id: str, relationship_type: RelationshipType,
                 confidence_score: Optional[float] = None, source_provider: str = "unknown",
                 raw_data: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add an edge between two existing nodes, or raise an existing edge's confidence.

        Args:
            source_id: Edge origin; must already be in the graph.
            target_id: Edge destination; must already be in the graph.
            relationship_type: Relationship kind. Its default confidence is used
                when ``confidence_score`` is None.
            confidence_score: Explicit confidence. ``0.0`` is honored, not
                replaced by the default.
            source_provider: Name recorded on the edge (or as ``updated_by``).
            raw_data: Provider payload stored on new edges.

        Returns:
            True only when a new edge was created. False if either node is
            missing or the edge already existed; in the latter case its
            confidence may still have been raised.
        """
        if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
            return False

        new_confidence = (confidence_score if confidence_score is not None
                          else relationship_type.default_confidence)
        now = datetime.now(timezone.utc).isoformat()

        if self.graph.has_edge(source_id, target_id):
            edge_data = self.graph.edges[source_id, target_id]
            # Only ever raise confidence on an existing edge.
            if new_confidence > edge_data.get('confidence_score', 0):
                edge_data['confidence_score'] = new_confidence
                edge_data['updated_timestamp'] = now
                edge_data['updated_by'] = source_provider
                self.last_modified = now
            return False

        self.graph.add_edge(source_id, target_id,
                            relationship_type=relationship_type.relationship_name,
                            confidence_score=new_confidence,
                            source_provider=source_provider,
                            discovery_timestamp=now,
                            raw_data=raw_data or {})
        self.last_modified = now
        return True

    def get_node_count(self) -> int:
        """Get total number of nodes in the graph."""
        return self.graph.number_of_nodes()

    def get_edge_count(self) -> int:
        """Get total number of edges in the graph."""
        return self.graph.number_of_edges()

    def get_nodes_by_type(self, node_type: NodeType) -> List[str]:
        """Get all node ids whose stored type equals ``node_type.value``."""
        return [n for n, d in self.graph.nodes(data=True) if d.get('type') == node_type.value]

    def get_neighbors(self, node_id: str) -> List[str]:
        """Get all unique neighbors (predecessors and successors) of a node; [] if absent."""
        if not self.graph.has_node(node_id):
            return []
        return list(set(self.graph.predecessors(node_id)) | set(self.graph.successors(node_id)))

    def get_high_confidence_edges(self, min_confidence: float = 0.8) -> List[Tuple[str, str, Dict]]:
        """Get edges whose confidence score is at least ``min_confidence``."""
        return [(u, v, d) for u, v, d in self.graph.edges(data=True)
                if d.get('confidence_score', 0) >= min_confidence]

    def get_graph_data(self) -> Dict[str, Any]:
        """
        Export graph data formatted for frontend visualization.

        Returns a dict with:
        - ``'nodes'``: one entry per node, with its stored fields plus
          ``incoming_edges``/``outgoing_edges``. Domain nodes whose certificates
          report ``has_valid_cert is False`` also get a gray ``color``.
        - ``'edges'``: one entry per edge.
        - ``'statistics'``: the basic metrics from ``get_statistics``.
        """
        nodes = []
        for node_id, attrs in self.graph.nodes(data=True):
            attributes = attrs.get('attributes', {})
            node_data = {'id': node_id, 'label': node_id,
                         'type': attrs.get('type', 'unknown'),
                         'attributes': attributes,
                         'description': attrs.get('description', ''),
                         'metadata': attrs.get('metadata', {}),
                         'added_timestamp': attrs.get('added_timestamp')}

            # Customize node appearance; tolerate missing or non-dict certificate data.
            certificates = attributes.get('certificates') if isinstance(attributes, dict) else None
            if (node_data['type'] == 'domain' and isinstance(certificates, dict)
                    and certificates.get('has_valid_cert') is False):
                node_data['color'] = {'background': '#c7c7c7', 'border': '#999'}  # Gray for invalid cert

            node_data['incoming_edges'] = [{'from': u, 'data': d}
                                           for u, _, d in self.graph.in_edges(node_id, data=True)]
            node_data['outgoing_edges'] = [{'to': v, 'data': d}
                                           for _, v, d in self.graph.out_edges(node_id, data=True)]
            nodes.append(node_data)

        edges = [{'from': source, 'to': target,
                  'label': attrs.get('relationship_type', ''),
                  'confidence_score': attrs.get('confidence_score', 0),
                  'source_provider': attrs.get('source_provider', ''),
                  'discovery_timestamp': attrs.get('discovery_timestamp')}
                 for source, target, attrs in self.graph.edges(data=True)]

        return {
            'nodes': nodes,
            'edges': edges,
            'statistics': self.get_statistics()['basic_metrics']
        }

    def export_json(self) -> Dict[str, Any]:
        """Export complete graph data (node-link format plus metadata and statistics) as a dict."""
        graph_data = nx.node_link_data(self.graph)  # Use NetworkX's built-in serializer
        return {
            'export_metadata': {
                'export_timestamp': datetime.now(timezone.utc).isoformat(),
                'graph_creation_time': self.creation_time,
                'last_modified': self.last_modified,
                'total_nodes': self.get_node_count(),
                'total_edges': self.get_edge_count(),
                'graph_format': 'dnsrecon_v1_nodeling'
            },
            'graph': graph_data,
            'statistics': self.get_statistics()
        }

    def _get_confidence_distribution(self) -> Dict[str, int]:
        """Bucket edge confidence scores: high (>= 0.8), medium (>= 0.6), low (otherwise)."""
        distribution = {'high': 0, 'medium': 0, 'low': 0}
        for _, _, confidence in self.graph.edges(data='confidence_score', default=0):
            if confidence >= 0.8:
                distribution['high'] += 1
            elif confidence >= 0.6:
                distribution['medium'] += 1
            else:
                distribution['low'] += 1
        return distribution

    def get_statistics(self) -> Dict[str, Any]:
        """Get node/edge counts plus node-type, relationship, confidence and provider distributions."""
        stats = {'basic_metrics': {'total_nodes': self.get_node_count(),
                                   'total_edges': self.get_edge_count(),
                                   'creation_time': self.creation_time,
                                   'last_modified': self.last_modified},
                 'node_type_distribution': {},
                 'relationship_type_distribution': {},
                 'confidence_distribution': self._get_confidence_distribution(),
                 'provider_distribution': {}}

        for node_type in NodeType:
            stats['node_type_distribution'][node_type.value] = len(self.get_nodes_by_type(node_type))

        rel_dist = stats['relationship_type_distribution']
        for _, _, rel_type in self.graph.edges(data='relationship_type', default='unknown'):
            rel_dist[rel_type] = rel_dist.get(rel_type, 0) + 1

        provider_dist = stats['provider_distribution']
        for _, _, provider in self.graph.edges(data='source_provider', default='unknown'):
            provider_dist[provider] = provider_dist.get(provider, 0) + 1

        return stats

    def clear(self) -> None:
        """Clear all nodes, edges, and indices, and reset the creation timestamp."""
        self.graph.clear()
        self.correlation_index.clear()
        self.creation_time = datetime.now(timezone.utc).isoformat()
        self.last_modified = self.creation_time