303 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			303 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# dnsrecon-reduced/core/graph_manager.py
 | 
						|
 | 
						|
"""
 | 
						|
Graph data model for DNSRecon using NetworkX.
 | 
						|
Manages in-memory graph storage with confidence scoring and forensic metadata.
 | 
						|
Now fully compatible with the unified ProviderResult data model.
 | 
						|
FIXED: Added proper pickle support to prevent weakref serialization errors.
 | 
						|
"""
 | 
						|
import re
 | 
						|
from datetime import datetime, timezone
 | 
						|
from enum import Enum
 | 
						|
from typing import Dict, List, Any, Optional, Tuple
 | 
						|
 | 
						|
import networkx as nx
 | 
						|
 | 
						|
 | 
						|
class NodeType(Enum):
    """Closed set of node categories a graph node can be tagged with.

    Each member's value is the plain string stored under a node's ``type``
    attribute in the graph.
    """

    DOMAIN = "domain"
    IP = "ip"
    ISP = "isp"
    CA = "ca"
    LARGE_ENTITY = "large_entity"
    CORRELATION_OBJECT = "correlation_object"

    def __repr__(self) -> str:
        # Show only the raw string value (e.g. ``domain``) rather than the
        # default ``<NodeType.DOMAIN: 'domain'>`` form.
        raw_value = self.value
        return raw_value
 | 
						|
 | 
						|
 | 
						|
class GraphManager:
    """
    Graph manager for DNSRecon infrastructure mapping.

    Wraps a ``networkx.DiGraph`` held in ``self.graph``. Nodes carry the
    attributes ``type``, ``added_timestamp``, ``attributes`` (a list of
    dicts), ``description`` and ``metadata``; edges carry
    ``relationship_type``, ``confidence_score``, ``source_provider``,
    ``discovery_timestamp`` and ``raw_data``.

    Pickling is customised: ``__getstate__`` flattens the graph into plain
    dict/list structures and ``__setstate__`` rebuilds the ``DiGraph`` from
    them, so the NetworkX object itself is never pickled.

    NOTE(review): this class was previously described as thread-safe, but it
    performs no locking itself. Callers sharing one instance across threads
    presumably need to synchronise externally - confirm.
    """

    def __init__(self):
        """Initialize an empty directed graph and its bookkeeping timestamps."""
        self.graph = nx.DiGraph()
        self.creation_time = self._utc_now_iso()
        self.last_modified = self.creation_time

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def _basic_metrics(self) -> Dict[str, Any]:
        """Return node/edge totals plus the creation and modification times."""
        return {
            'total_nodes': self.get_node_count(),
            'total_edges': self.get_edge_count(),
            'creation_time': self.creation_time,
            'last_modified': self.last_modified
        }

    def __getstate__(self) -> Dict[str, Any]:
        """
        Build the picklable state of this manager.

        The ``graph`` entry is replaced by ``_graph_nodes`` (node id ->
        attribute dict) and ``_graph_edges`` (list of dicts with ``source``,
        ``target`` and ``attributes``). Attribute dicts are shallow copies.
        """
        state = self.__dict__.copy()

        # Compare against None rather than relying on truthiness: an empty
        # DiGraph has length 0 and is falsy, which previously left the raw
        # NetworkX object inside the pickled state.
        graph = state.pop('graph', None)
        if graph is not None:
            state['_graph_nodes'] = {
                node_id: dict(attrs)
                for node_id, attrs in graph.nodes(data=True)
            }
            state['_graph_edges'] = [
                {'source': source, 'target': target, 'attributes': dict(attrs)}
                for source, target, attrs in graph.edges(data=True)
            ]

        return state

    def __setstate__(self, state):
        """
        Restore the manager from a state produced by ``__getstate__``.

        A fresh ``DiGraph`` is always created; nodes are restored before
        edges. States that still contain a raw ``graph`` entry (written when
        the graph was empty by the earlier implementation) are accepted and
        that entry is discarded, matching the earlier behaviour.
        """
        # Work on a copy so the caller's state dict is not modified.
        restored = dict(state)
        nodes_data = restored.pop('_graph_nodes', None) or {}
        edges_data = restored.pop('_graph_edges', None) or []
        restored.pop('graph', None)

        self.__dict__.update(restored)

        self.graph = nx.DiGraph()
        # (node, attr_dict) and (u, v, attr_dict) tuples are accepted by the
        # bulk adders and avoid ``**`` expansion of the attribute dicts.
        self.graph.add_nodes_from(nodes_data.items())
        self.graph.add_edges_from(
            (edge['source'], edge['target'], edge['attributes'])
            for edge in edges_data
        )

    def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None,
                description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add a node to the graph, or merge new data into an existing node.

        Args:
            node_id: Identifier of the node.
            node_type: Node category; only used when the node is created.
            attributes: List of attribute dicts. On an existing node, entries
                whose ``'name'`` is not already present are appended; entries
                with a known name are ignored. Every entry is expected to
                have a ``'name'`` key (a missing key raises ``KeyError``).
            description: Replaces the stored description when non-empty.
            metadata: Merged key-by-key into the stored metadata.

        Returns:
            True if the node was newly created, False if it already existed.
        """
        is_new_node = not self.graph.has_node(node_id)

        if is_new_node:
            # Store copies of the containers: later merges append to / update
            # the stored objects, which must not leak back into the caller's
            # own list or dict.
            self.graph.add_node(node_id, type=node_type.value,
                                added_timestamp=self._utc_now_iso(),
                                attributes=list(attributes) if attributes else [],
                                description=description,
                                metadata=dict(metadata) if metadata else {})
        else:
            node = self.graph.nodes[node_id]

            if attributes:
                existing_attributes = node.get('attributes', [])

                # Older data may have stored attributes in a non-list
                # (dictionary) form; such data is replaced by a fresh list.
                if not isinstance(existing_attributes, list):
                    existing_attributes = []

                # Names already present, for O(1) duplicate checks.
                seen_names = {attr['name'] for attr in existing_attributes}

                for new_attr in attributes:
                    attr_name = new_attr['name']
                    if attr_name not in seen_names:
                        existing_attributes.append(new_attr)
                        seen_names.add(attr_name)

                node['attributes'] = existing_attributes

            if description:
                node['description'] = description

            if metadata:
                merged_metadata = node.get('metadata') or {}
                merged_metadata.update(metadata)
                node['metadata'] = merged_metadata

        self.last_modified = self._utc_now_iso()
        return is_new_node

    def add_edge(self, source_id: str, target_id: str, relationship_type: str,
                confidence_score: float = 0.5, source_provider: str = "unknown",
                raw_data: Optional[Dict[str, Any]] = None) -> bool:
        """
        Add an edge between two existing nodes, or raise an existing edge's
        confidence.

        The relationship type is stored as given, without any formatting.
        If the edge already exists, only its ``confidence_score`` (together
        with ``updated_timestamp`` and ``updated_by``) is changed, and only
        when the new score is strictly higher; the other stored attributes
        are left untouched.

        Returns:
            True if a new edge was created. False if either endpoint is
            missing or the edge already existed (whether or not its
            confidence was raised).
        """
        if not (self.graph.has_node(source_id) and self.graph.has_node(target_id)):
            return False

        if self.graph.has_edge(source_id, target_id):
            edge = self.graph.edges[source_id, target_id]
            if confidence_score > edge.get('confidence_score', 0):
                now = self._utc_now_iso()
                edge['confidence_score'] = confidence_score
                edge['updated_timestamp'] = now
                edge['updated_by'] = source_provider
                # The graph content changed, so record it like every other
                # mutating method does.
                self.last_modified = now
            return False

        now = self._utc_now_iso()
        self.graph.add_edge(source_id, target_id,
                            relationship_type=relationship_type,
                            confidence_score=confidence_score,
                            source_provider=source_provider,
                            discovery_timestamp=now,
                            raw_data=raw_data or {})
        self.last_modified = now
        return True

    def remove_node(self, node_id: str) -> bool:
        """
        Remove a node and its connected edges from the graph.

        Returns:
            True if the node existed and was removed, False otherwise.
        """
        if not self.graph.has_node(node_id):
            return False

        # NetworkX drops the node's incident edges together with the node.
        self.graph.remove_node(node_id)

        self.last_modified = self._utc_now_iso()
        return True

    def get_node_count(self) -> int:
        """Get total number of nodes in the graph."""
        return self.graph.number_of_nodes()

    def get_edge_count(self) -> int:
        """Get total number of edges in the graph."""
        return self.graph.number_of_edges()

    def get_nodes_by_type(self, node_type: NodeType) -> List[str]:
        """Get the ids of all nodes whose ``type`` equals ``node_type.value``."""
        return [n for n, d in self.graph.nodes(data=True) if d.get('type') == node_type.value]

    def get_high_confidence_edges(self, min_confidence: float = 0.8) -> List[Tuple[str, str, Dict]]:
        """
        Get edges whose confidence score is at or above ``min_confidence``.

        Edges without a ``confidence_score`` are treated as 0.

        Returns:
            List of ``(source, target, attribute_dict)`` tuples.
        """
        return [(u, v, d) for u, v, d in self.graph.edges(data=True)
                if d.get('confidence_score', 0) >= min_confidence]

    def get_graph_data(self) -> Dict[str, Any]:
        """
        Export graph data formatted for frontend visualization.

        No visual styling is applied here; nodes and edges are exported with
        their raw stored values.

        Returns:
            Dict with ``nodes`` (one dict per node, including its incoming
            and outgoing edges), ``edges`` (one dict per edge) and
            ``statistics`` (the basic metrics: totals and timestamps).
        """
        nodes = []
        for node_id, attrs in self.graph.nodes(data=True):
            metadata = attrs.get('metadata') or {}
            nodes.append({
                'id': node_id,
                'label': node_id,
                'type': attrs.get('type', 'unknown'),
                'attributes': attrs.get('attributes', []),  # Raw attributes list
                'description': attrs.get('description', ''),
                'metadata': metadata,
                'added_timestamp': attrs.get('added_timestamp'),
                'max_depth_reached': metadata.get('max_depth_reached', False),
                # Edge attribute dicts are passed through as stored.
                'incoming_edges': [
                    {'from': u, 'data': d}
                    for u, _, d in self.graph.in_edges(node_id, data=True)
                ],
                'outgoing_edges': [
                    {'to': v, 'data': d}
                    for _, v, d in self.graph.out_edges(node_id, data=True)
                ]
            })

        edges = [
            {
                'from': source,
                'to': target,
                'label': attrs.get('relationship_type', ''),
                'confidence_score': attrs.get('confidence_score', 0),
                'source_provider': attrs.get('source_provider', ''),
                'discovery_timestamp': attrs.get('discovery_timestamp')
            }
            for source, target, attrs in self.graph.edges(data=True)
        ]

        return {
            'nodes': nodes,
            'edges': edges,
            # Only the basic metrics are exported, so the full distribution
            # statistics are not computed here.
            'statistics': self._basic_metrics()
        }

    def _get_confidence_distribution(self) -> Dict[str, int]:
        """
        Count edges per confidence bucket.

        Buckets: ``high`` for scores >= 0.8, ``medium`` for scores >= 0.6,
        ``low`` for everything else (a missing score counts as 0). An empty
        graph yields all-zero counts.
        """
        distribution = {'high': 0, 'medium': 0, 'low': 0}

        for _, _, data in self.graph.edges(data=True):
            confidence = data.get('confidence_score', 0)
            if confidence >= 0.8:
                bucket = 'high'
            elif confidence >= 0.6:
                bucket = 'medium'
            else:
                bucket = 'low'
            distribution[bucket] += 1

        return distribution

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics about the graph.

        Returns:
            Dict with ``basic_metrics``, ``node_type_distribution`` (only
            ``NodeType`` values that occur at least once, in enum order),
            ``relationship_type_distribution``, ``confidence_distribution``
            and ``provider_distribution``. All distributions are empty (or
            all-zero for confidence) on an empty graph.
        """
        # Single pass over the nodes instead of one scan per NodeType member.
        type_counts: Dict[Any, int] = {}
        for _, data in self.graph.nodes(data=True):
            node_type_value = data.get('type')
            type_counts[node_type_value] = type_counts.get(node_type_value, 0) + 1

        relationship_counts: Dict[Any, int] = {}
        provider_counts: Dict[Any, int] = {}
        for _, _, data in self.graph.edges(data=True):
            rel_type = data.get('relationship_type', 'unknown')
            relationship_counts[rel_type] = relationship_counts.get(rel_type, 0) + 1

            provider = data.get('source_provider', 'unknown')
            provider_counts[provider] = provider_counts.get(provider, 0) + 1

        return {
            'basic_metrics': self._basic_metrics(),
            # Types outside the NodeType enum are deliberately not reported,
            # as before.
            'node_type_distribution': {
                node_type.value: type_counts[node_type.value]
                for node_type in NodeType
                if type_counts.get(node_type.value, 0) > 0
            },
            'relationship_type_distribution': relationship_counts,
            'confidence_distribution': self._get_confidence_distribution(),
            'provider_distribution': provider_counts
        }

    def clear(self) -> None:
        """Clear all nodes and edges and reset both timestamps to now."""
        self.graph.clear()
        self.creation_time = self._utc_now_iso()
        self.last_modified = self.creation_time