diff --git a/core/graph_manager.py b/core/graph_manager.py index 6086ea8..69dbda8 100644 --- a/core/graph_manager.py +++ b/core/graph_manager.py @@ -1,8 +1,9 @@ -# core/graph_manager.py +# dnsrecon-reduced/core/graph_manager.py """ Graph data model for DNSRecon using NetworkX. Manages in-memory graph storage with confidence scoring and forensic metadata. +Now fully compatible with the unified ProviderResult data model. """ import re from datetime import datetime, timezone @@ -28,6 +29,7 @@ class GraphManager: """ Thread-safe graph manager for DNSRecon infrastructure mapping. Uses NetworkX for in-memory graph storage with confidence scoring. + Compatible with unified ProviderResult data model. """ def __init__(self): @@ -192,21 +194,36 @@ class GraphManager: }) return all_correlations - def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[Dict[str, Any]] = None, + def add_node(self, node_id: str, node_type: NodeType, attributes: Optional[List[Dict[str, Any]]] = None, description: str = "", metadata: Optional[Dict[str, Any]] = None) -> bool: - """Add a node to the graph, update attributes, and process correlations.""" + """ + Add a node to the graph, update attributes, and process correlations. + Now compatible with unified data model - attributes are dictionaries from converted StandardAttribute objects. + """ is_new_node = not self.graph.has_node(node_id) if is_new_node: self.graph.add_node(node_id, type=node_type.value, added_timestamp=datetime.now(timezone.utc).isoformat(), - attributes=attributes or {}, + attributes=attributes or [], # Store as a list from the start description=description, metadata=metadata or {}) else: - # Safely merge new attributes into existing attributes + # Safely merge new attributes into the existing list of attributes if attributes: - existing_attributes = self.graph.nodes[node_id].get('attributes', {}) - existing_attributes.update(attributes) + existing_attributes = self.graph.nodes[node_id].get('attributes', []) + + # Handle cases where old data might still be in dictionary format + if not isinstance(existing_attributes, list): + existing_attributes = [] + + # Create a set of existing attribute names for efficient duplicate checking + existing_attr_names = {attr['name'] for attr in existing_attributes} + + for new_attr in attributes: + if new_attr['name'] not in existing_attr_names: + existing_attributes.append(new_attr) + existing_attr_names.add(new_attr['name']) + self.graph.nodes[node_id]['attributes'] = existing_attributes if description: self.graph.nodes[node_id]['description'] = description @@ -485,19 +502,28 @@ class GraphManager: if d.get('confidence_score', 0) >= min_confidence] def get_graph_data(self) -> Dict[str, Any]: - """Export graph data formatted for frontend visualization.""" + """ + Export graph data formatted for frontend visualization. + Compatible with unified data model - preserves all attribute information for frontend display. 
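To make the list-based attribute merge in add_node concrete, here is a standalone sketch of the de-duplication by attribute name; the sample attribute dicts are illustrative, not output from a real provider.

    existing_attributes = [
        {'name': 'dns_records', 'value': ['A: 93.184.216.34'], 'provider': 'dns', 'confidence': 0.8},
    ]
    incoming = [
        {'name': 'dns_records', 'value': ['A: 93.184.216.34'], 'provider': 'dns', 'confidence': 0.8},   # same name: skipped
        {'name': 'certificates', 'value': {'has_valid_cert': True}, 'provider': 'crtsh', 'confidence': 0.9},
    ]
    existing_attr_names = {attr['name'] for attr in existing_attributes}
    for new_attr in incoming:
        if new_attr['name'] not in existing_attr_names:
            existing_attributes.append(new_attr)
            existing_attr_names.add(new_attr['name'])
    # existing_attributes now holds exactly one 'dns_records' and one 'certificates' entry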
+ """ nodes = [] for node_id, attrs in self.graph.nodes(data=True): node_data = {'id': node_id, 'label': node_id, 'type': attrs.get('type', 'unknown'), - 'attributes': attrs.get('attributes', {}), + 'attributes': attrs.get('attributes', []), # Ensure attributes is a list 'description': attrs.get('description', ''), 'metadata': attrs.get('metadata', {}), 'added_timestamp': attrs.get('added_timestamp')} + # Customize node appearance based on type and attributes node_type = node_data['type'] - attributes = node_data['attributes'] - if node_type == 'domain' and attributes.get('certificates', {}).get('has_valid_cert') is False: - node_data['color'] = {'background': '#c7c7c7', 'border': '#999'} # Gray for invalid cert + attributes_list = node_data['attributes'] + + # CORRECTED LOGIC: Handle certificate validity styling + if node_type == 'domain' and isinstance(attributes_list, list): + # Find the certificates attribute in the list + cert_attr = next((attr for attr in attributes_list if attr.get('name') == 'certificates'), None) + if cert_attr and cert_attr.get('value', {}).get('has_valid_cert') is False: + node_data['color'] = {'background': '#c7c7c7', 'border': '#999'} # Gray for invalid cert # Add incoming and outgoing edges to node data if self.graph.has_node(node_id): @@ -528,7 +554,7 @@ class GraphManager: 'last_modified': self.last_modified, 'total_nodes': self.get_node_count(), 'total_edges': self.get_edge_count(), - 'graph_format': 'dnsrecon_v1_nodeling' + 'graph_format': 'dnsrecon_v1_unified_model' }, 'graph': graph_data, 'statistics': self.get_statistics() diff --git a/core/provider_result.py b/core/provider_result.py new file mode 100644 index 0000000..df2f6f1 --- /dev/null +++ b/core/provider_result.py @@ -0,0 +1,106 @@ +# dnsrecon-reduced/core/provider_result.py + +""" +Unified data model for DNSRecon passive reconnaissance. +Standardizes the data structure across all providers to ensure consistent processing. 
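A short usage sketch of the model defined below, showing how a provider populates a result and what the scanner can read back; the domain, IP, and relationship type values are placeholders.

    result = ProviderResult()
    result.add_relationship(source_node='example.com', target_node='93.184.216.34',
                            relationship_type='a_record', provider='dns', confidence=0.9,
                            raw_data={'query_type': 'A'})
    result.add_attribute(target_node='example.com', name='dns_records',
                         value=['A: 93.184.216.34'], attr_type='dns_record',
                         provider='dns', confidence=0.9)
    assert result.get_discovered_nodes() == {'example.com', '93.184.216.34'}
    assert result.get_relationship_count() == 1
    assert not result.is_large_entity(threshold=100)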
+""" + +from typing import Any, Optional, List, Dict +from dataclasses import dataclass, field +from datetime import datetime, timezone + + +@dataclass +class StandardAttribute: + """A unified data structure for a single piece of information about a node.""" + target_node: str + name: str + value: Any + type: str + provider: str + confidence: float + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + metadata: Optional[Dict[str, Any]] = field(default_factory=dict) + + def __post_init__(self): + """Validate the attribute after initialization.""" + if not isinstance(self.confidence, (int, float)) or not 0.0 <= self.confidence <= 1.0: + raise ValueError(f"Confidence must be between 0.0 and 1.0, got {self.confidence}") + + +@dataclass +class Relationship: + """A unified data structure for a directional link between two nodes.""" + source_node: str + target_node: str + relationship_type: str + confidence: float + provider: str + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + raw_data: Optional[Dict[str, Any]] = field(default_factory=dict) + + def __post_init__(self): + """Validate the relationship after initialization.""" + if not isinstance(self.confidence, (int, float)) or not 0.0 <= self.confidence <= 1.0: + raise ValueError(f"Confidence must be between 0.0 and 1.0, got {self.confidence}") + + +@dataclass +class ProviderResult: + """A container for all data returned by a provider from a single query.""" + attributes: List[StandardAttribute] = field(default_factory=list) + relationships: List[Relationship] = field(default_factory=list) + + def add_attribute(self, target_node: str, name: str, value: Any, attr_type: str, + provider: str, confidence: float = 0.8, + metadata: Optional[Dict[str, Any]] = None) -> None: + """Helper method to add an attribute to the result.""" + self.attributes.append(StandardAttribute( + target_node=target_node, + name=name, + value=value, + type=attr_type, + provider=provider, + confidence=confidence, + metadata=metadata or {} + )) + + def add_relationship(self, source_node: str, target_node: str, relationship_type: str, + provider: str, confidence: float = 0.8, + raw_data: Optional[Dict[str, Any]] = None) -> None: + """Helper method to add a relationship to the result.""" + self.relationships.append(Relationship( + source_node=source_node, + target_node=target_node, + relationship_type=relationship_type, + confidence=confidence, + provider=provider, + raw_data=raw_data or {} + )) + + def get_discovered_nodes(self) -> set: + """Get all unique node identifiers discovered in this result.""" + nodes = set() + + # Add nodes from relationships + for rel in self.relationships: + nodes.add(rel.source_node) + nodes.add(rel.target_node) + + # Add nodes from attributes + for attr in self.attributes: + nodes.add(attr.target_node) + + return nodes + + def get_relationship_count(self) -> int: + """Get the total number of relationships in this result.""" + return len(self.relationships) + + def get_attribute_count(self) -> int: + """Get the total number of attributes in this result.""" + return len(self.attributes) + + def is_large_entity(self, threshold: int) -> bool: + """Check if this result qualifies as a large entity based on relationship count.""" + return self.get_relationship_count() > threshold \ No newline at end of file diff --git a/core/rate_limiter.py b/core/rate_limiter.py index 7fadff4..d5a11d6 100644 --- a/core/rate_limiter.py +++ b/core/rate_limiter.py @@ -1,7 +1,6 @@ # 
dnsrecon-reduced/core/rate_limiter.py import time -import redis class GlobalRateLimiter: def __init__(self, redis_client): diff --git a/core/scanner.py b/core/scanner.py index 6ce05f3..11b5493 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -2,18 +2,18 @@ import threading import traceback -import time import os import importlib import redis from typing import List, Set, Dict, Any, Tuple, Optional -from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future +from concurrent.futures import ThreadPoolExecutor from collections import defaultdict from queue import PriorityQueue from datetime import datetime, timezone from core.graph_manager import GraphManager, NodeType from core.logger import get_forensic_logger, new_session +from core.provider_result import ProviderResult from utils.helpers import _is_valid_ip, _is_valid_domain from providers.base_provider import BaseProvider from core.rate_limiter import GlobalRateLimiter @@ -30,6 +30,7 @@ class ScanStatus: class Scanner: """ Main scanning orchestrator for DNSRecon passive reconnaissance. + Now provider-agnostic, consuming standardized ProviderResult objects. """ def __init__(self, session_config=None): @@ -470,6 +471,10 @@ class Scanner: print(f" - Tasks processed: {len(processed_tasks)}") def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: + """ + Query a single provider and process the unified ProviderResult. + Now provider-agnostic - handles any provider that returns ProviderResult. + """ if self._is_stop_requested(): print(f"Stop requested before querying {provider.get_name()} for {target}") return set(), set(), False @@ -478,21 +483,24 @@ class Scanner: target_type = NodeType.IP if is_ip else NodeType.DOMAIN print(f"Querying {provider.get_name()} for {target_type.value}: {target} at depth {depth}") + # Ensure target node exists in graph self.graph.add_node(target, target_type) self._initialize_provider_states(target) new_targets = set() large_entity_members = set() - node_attributes = defaultdict(lambda: defaultdict(list)) provider_successful = True try: - provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth) - if provider_results is None: + # Query provider - now returns unified ProviderResult + provider_result = self._query_single_provider_unified(provider, target, is_ip, depth) + + if provider_result is None: provider_successful = False elif not self._is_stop_requested(): - discovered, is_large_entity = self._process_provider_results( - target, provider, provider_results, node_attributes, depth + # Process the unified result + discovered, is_large_entity = self._process_provider_result_unified( + target, provider, provider_result, depth ) if is_large_entity: large_entity_members.update(discovered) @@ -504,15 +512,177 @@ class Scanner: provider_successful = False self._log_provider_error(target, provider.get_name(), str(e)) - if not self._is_stop_requested(): - for node_id, attributes in node_attributes.items(): - if self.graph.graph.has_node(node_id): - node_is_ip = _is_valid_ip(node_id) - node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN - self.graph.add_node(node_id, node_type_to_add, attributes=attributes) - return new_targets, large_entity_members, provider_successful + def _query_single_provider_unified(self, provider: BaseProvider, target: str, is_ip: bool, current_depth: int) -> Optional[ProviderResult]: + """ + Query a single provider with stop signal checking, now 
returns ProviderResult. + """ + provider_name = provider.get_name() + start_time = datetime.now(timezone.utc) + + if self._is_stop_requested(): + print(f"Stop requested before querying {provider_name} for {target}") + return None + + print(f"Querying {provider_name} for {target}") + + self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}") + + try: + # Query the provider - returns unified ProviderResult + if is_ip: + result = provider.query_ip(target) + else: + result = provider.query_domain(target) + + if self._is_stop_requested(): + print(f"Stop requested after querying {provider_name} for {target}") + return None + + # Update provider state with relationship count (more meaningful than raw result count) + relationship_count = result.get_relationship_count() if result else 0 + self._update_provider_state(target, provider_name, 'success', relationship_count, None, start_time) + + print(f"✓ {provider_name} returned {relationship_count} relationships for {target}") + return result + + except Exception as e: + self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time) + print(f"✗ {provider_name} failed for {target}: {e}") + return None + + def _process_provider_result_unified(self, target: str, provider: BaseProvider, + provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]: + """ + Process a unified ProviderResult object to update the graph. + Returns (discovered_targets, is_large_entity). + """ + provider_name = provider.get_name() + discovered_targets = set() + + if self._is_stop_requested(): + print(f"Stop requested before processing results from {provider_name} for {target}") + return discovered_targets, False + + # Check for large entity based on relationship count + if provider_result.get_relationship_count() > self.config.large_entity_threshold: + print(f"Large entity detected: {provider_name} returned {provider_result.get_relationship_count()} relationships for {target}") + members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth) + return members, True + + # Process relationships + for i, relationship in enumerate(provider_result.relationships): + if i % 5 == 0 and self._is_stop_requested(): # Check periodically for stop + print(f"Stop requested while processing relationships from {provider_name} for {target}") + break + + # Add nodes for relationship endpoints + source_node = relationship.source_node + target_node = relationship.target_node + + # Determine node types + source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN + if target_node.startswith('AS') and target_node[2:].isdigit(): + target_type = NodeType.ASN + elif _is_valid_ip(target_node): + target_type = NodeType.IP + else: + target_type = NodeType.DOMAIN + + # Add nodes to graph + self.graph.add_node(source_node, source_type) + self.graph.add_node(target_node, target_type) + + # Add edge to graph + if self.graph.add_edge( + source_node, target_node, + relationship.relationship_type, + relationship.confidence, + provider_name, + relationship.raw_data + ): + print(f"Added relationship: {source_node} -> {target_node} ({relationship.relationship_type})") + + # Track discovered targets for further processing + if _is_valid_domain(target_node) or _is_valid_ip(target_node): + discovered_targets.add(target_node) + + # Process attributes, preserving them as a list of objects + attributes_by_node = defaultdict(list) + for attribute in provider_result.attributes: + # 
Convert the StandardAttribute object to a dictionary that the frontend can use + attr_dict = { + "name": attribute.name, + "value": attribute.value, + "type": attribute.type, + "provider": attribute.provider, + "confidence": attribute.confidence, + "metadata": attribute.metadata + } + attributes_by_node[attribute.target_node].append(attr_dict) + + # Add attributes to nodes + for node_id, node_attributes_list in attributes_by_node.items(): + if self.graph.graph.has_node(node_id): + # Determine node type + if _is_valid_ip(node_id): + node_type = NodeType.IP + elif node_id.startswith('AS') and node_id[2:].isdigit(): + node_type = NodeType.ASN + else: + node_type = NodeType.DOMAIN + + # Add node with the list of attributes + self.graph.add_node(node_id, node_type, attributes=node_attributes_list) + + return discovered_targets, False + + def _create_large_entity_from_provider_result(self, source: str, provider_name: str, + provider_result: ProviderResult, current_depth: int) -> Set[str]: + """ + Create a large entity node from a ProviderResult and return the members for DNS processing. + """ + entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}" + + # Extract target nodes from relationships + targets = [rel.target_node for rel in provider_result.relationships] + node_type = 'unknown' + + if targets: + if _is_valid_domain(targets[0]): + node_type = 'domain' + elif _is_valid_ip(targets[0]): + node_type = 'ip' + + # Create nodes in graph (they exist but are grouped) + for target in targets: + target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP + self.graph.add_node(target, target_node_type) + + attributes = { + 'count': len(targets), + 'nodes': targets, + 'node_type': node_type, + 'source_provider': provider_name, + 'discovery_depth': current_depth, + 'threshold_exceeded': self.config.large_entity_threshold, + } + description = f'Large entity created due to {len(targets)} relationships from {provider_name}' + + self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes, description=description) + + # Create edge from source to large entity + if provider_result.relationships: + rel_type = provider_result.relationships[0].relationship_type + self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, + {'large_entity_info': f'Contains {len(targets)} {node_type}s'}) + + self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}") + print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}") + + return set(targets) + def stop_scan(self) -> bool: """Request immediate scan termination with proper cleanup.""" try: @@ -558,6 +728,73 @@ class Scanner: traceback.print_exc() return False + def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool: + """ + Extracts a node from a large entity, re-creates its original edge, and + re-queues it for full scanning. + """ + if not self.graph.graph.has_node(large_entity_id): + print(f"ERROR: Large entity {large_entity_id} not found.") + return False + + # 1. 
Get the original source node that discovered the large entity + predecessors = list(self.graph.graph.predecessors(large_entity_id)) + if not predecessors: + print(f"ERROR: No source node found for large entity {large_entity_id}.") + return False + source_node_id = predecessors[0] + + # Get the original edge data to replicate it for the extracted node + original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id) + if not original_edge_data: + print(f"ERROR: Could not find original edge data from {source_node_id} to {large_entity_id}.") + return False + + # 2. Modify the graph data structure first + success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract) + if not success: + print(f"ERROR: Node {node_id_to_extract} could not be removed from {large_entity_id}'s attributes.") + return False + + # 3. Create the direct edge from the original source to the newly extracted node + print(f"Re-creating direct edge from {source_node_id} to extracted node {node_id_to_extract}") + self.graph.add_edge( + source_id=source_node_id, + target_id=node_id_to_extract, + relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'), + confidence_score=original_edge_data.get('confidence_score', 0.85), # Slightly lower confidence + source_provider=original_edge_data.get('source_provider', 'unknown'), + raw_data={'context': f'Extracted from large entity {large_entity_id}'} + ) + + # 4. Re-queue the extracted node for full processing by all eligible providers + print(f"Re-queueing extracted node {node_id_to_extract} for full reconnaissance...") + is_ip = _is_valid_ip(node_id_to_extract) + current_depth = self.graph.graph.nodes[large_entity_id].get('attributes', {}).get('discovery_depth', 0) + + eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False) + for provider in eligible_providers: + provider_name = provider.get_name() + self.task_queue.put((self._get_priority(provider_name), (provider_name, node_id_to_extract, current_depth))) + self.total_tasks_ever_enqueued += 1 + + # 5. If the scanner is not running, we need to kickstart it to process this one item. + if self.status != ScanStatus.RUNNING: + print("Scanner is idle. Starting a mini-scan to process the extracted node.") + self.status = ScanStatus.RUNNING + self._update_session_state() + + if not self.scan_thread or not self.scan_thread.is_alive(): + self.scan_thread = threading.Thread( + target=self._execute_scan, + args=(self.current_target, self.max_depth), + daemon=True + ) + self.scan_thread.start() + + print(f"Successfully extracted and re-queued {node_id_to_extract} from {large_entity_id}.") + return True + def _update_session_state(self) -> None: """ Update the scanner state in Redis for GUI updates. 
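The re-queueing step in extract_node_from_large_entity relies on the (priority, (provider_name, target, depth)) tuple layout already used by the scanner's task queue; a minimal sketch of the ordering, with priority numbers that are assumptions rather than the values _get_priority actually returns.

    from queue import PriorityQueue

    task_queue = PriorityQueue()
    task_queue.put((2, ('crtsh', 'extracted.example.com', 1)))   # assumed priority values
    task_queue.put((1, ('dns', 'extracted.example.com', 1)))
    priority, (provider_name, target, depth) = task_queue.get()
    # lowest priority number is dequeued first -> ('dns', 'extracted.example.com', 1)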
@@ -656,39 +893,6 @@ class Scanner: provider_state = provider_states.get(provider_name) return provider_state is not None and provider_state.get('status') == 'success' - def _query_single_provider_forensic(self, provider, target: str, is_ip: bool, current_depth: int) -> Optional[List]: - """Query a single provider with stop signal checking.""" - provider_name = provider.get_name() - start_time = datetime.now(timezone.utc) - - if self._is_stop_requested(): - print(f"Stop requested before querying {provider_name} for {target}") - return None - - print(f"Querying {provider_name} for {target}") - - self.logger.logger.info(f"Attempting {provider_name} query for {target} at depth {current_depth}") - - try: - if is_ip: - results = provider.query_ip(target) - else: - results = provider.query_domain(target) - - if self._is_stop_requested(): - print(f"Stop requested after querying {provider_name} for {target}") - return None - - self._update_provider_state(target, provider_name, 'success', len(results), None, start_time) - - print(f"✓ {provider_name} returned {len(results)} results for {target}") - return results - - except Exception as e: - self._update_provider_state(target, provider_name, 'failed', 0, str(e), start_time) - print(f"✗ {provider_name} failed for {target}: {e}") - return None - def _update_provider_state(self, target: str, provider_name: str, status: str, results_count: int, error: Optional[str], start_time: datetime) -> None: """Update provider state in node metadata for forensic tracking.""" @@ -711,237 +915,6 @@ class Scanner: self.logger.logger.info(f"Provider state updated: {target} -> {provider_name} -> {status} ({results_count} results)") - def _process_provider_results(self, target: str, provider, results: List, - node_attributes: Dict, current_depth: int) -> Tuple[Set[str], bool]: - """Process provider results, returns (discovered_targets, is_large_entity).""" - provider_name = provider.get_name() - discovered_targets = set() - - if self._is_stop_requested(): - print(f"Stop requested before processing results from {provider_name} for {target}") - return discovered_targets, False - - if len(results) > self.config.large_entity_threshold: - print(f"Large entity detected: {provider_name} returned {len(results)} results for {target}") - members = self._create_large_entity(target, provider_name, results, current_depth) - return members, True - - for i, (source, rel_target, rel_type, confidence, raw_data) in enumerate(results): - if i % 5 == 0 and self._is_stop_requested(): # Check more frequently - print(f"Stop requested while processing results from {provider_name} for {target}") - break - - self.logger.log_relationship_discovery( - source_node=source, - target_node=rel_target, - relationship_type=rel_type, - confidence_score=confidence, - provider=provider_name, - raw_data=raw_data, - discovery_method=f"{provider_name}_query_depth_{current_depth}" - ) - - # Collect attributes for the source node - self._collect_node_attributes(source, provider_name, rel_type, rel_target, raw_data, node_attributes[source]) - - # If the relationship is asn_membership, collect attributes for the target ASN node - if rel_type == 'asn_membership': - self._collect_node_attributes(rel_target, provider_name, rel_type, source, raw_data, node_attributes[rel_target]) - - - if isinstance(rel_target, list): - # If the target is a list, iterate and process each item - for single_target in rel_target: - if _is_valid_ip(single_target): - self.graph.add_node(single_target, NodeType.IP) - if 
self.graph.add_edge(source, single_target, rel_type, confidence, provider_name, raw_data): - print(f"Added IP relationship: {source} -> {single_target} ({rel_type})") - discovered_targets.add(single_target) - elif _is_valid_domain(single_target): - self.graph.add_node(single_target, NodeType.DOMAIN) - if self.graph.add_edge(source, single_target, rel_type, confidence, provider_name, raw_data): - print(f"Added domain relationship: {source} -> {single_target} ({rel_type})") - discovered_targets.add(single_target) - self._collect_node_attributes(single_target, provider_name, rel_type, source, raw_data, node_attributes[single_target]) - - elif _is_valid_ip(rel_target): - self.graph.add_node(rel_target, NodeType.IP) - if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): - print(f"Added IP relationship: {source} -> {rel_target} ({rel_type})") - discovered_targets.add(rel_target) - - elif rel_target.startswith('AS') and rel_target[2:].isdigit(): - self.graph.add_node(rel_target, NodeType.ASN) - if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): - print(f"Added ASN relationship: {source} -> {rel_target} ({rel_type})") - - elif _is_valid_domain(rel_target): - self.graph.add_node(rel_target, NodeType.DOMAIN) - if self.graph.add_edge(source, rel_target, rel_type, confidence, provider_name, raw_data): - print(f"Added domain relationship: {source} -> {rel_target} ({rel_type})") - discovered_targets.add(rel_target) - self._collect_node_attributes(rel_target, provider_name, rel_type, source, raw_data, node_attributes[rel_target]) - - else: - self._collect_node_attributes(source, provider_name, rel_type, rel_target, raw_data, node_attributes[source]) - - return discovered_targets, False - - def _create_large_entity(self, source: str, provider_name: str, results: List, current_depth: int) -> Set[str]: - """Create a large entity node and returns the members for DNS processing.""" - entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}" - - targets = [rel[1] for rel in results if len(rel) > 1] - node_type = 'unknown' - - if targets: - if _is_valid_domain(targets[0]): - node_type = 'domain' - elif _is_valid_ip(targets[0]): - node_type = 'ip' - - # We still create the nodes so they exist in the graph, they are just not processed for edges yet. - for target in targets: - self.graph.add_node(target, NodeType.DOMAIN if node_type == 'domain' else NodeType.IP) - - attributes = { - 'count': len(targets), - 'nodes': targets, - 'node_type': node_type, - 'source_provider': provider_name, - 'discovery_depth': current_depth, - 'threshold_exceeded': self.config.large_entity_threshold, - } - description = f'Large entity created due to {len(targets)} results from {provider_name}' - - self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes, description=description) - - if results: - rel_type = results[0][2] - self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, - {'large_entity_info': f'Contains {len(targets)} {node_type}s'}) - - self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}") - print(f"Created large entity {entity_id} for {len(targets)} {node_type}s from {provider_name}") - - return set(targets) - - def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool: - """ - Extracts a node from a large entity, re-creates its original edge, and - re-queues it for full scanning. 
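For contrast with the tuple-based result handling removed above, this is roughly how one legacy result row maps onto the unified model; the row values and the provider name are illustrative.

    legacy_row = ('example.com', '93.184.216.34', 'a_record', 0.9, {'query_type': 'A'})
    source, rel_target, rel_type, confidence, raw_data = legacy_row

    migrated = ProviderResult()
    migrated.add_relationship(source_node=source, target_node=rel_target,
                              relationship_type=rel_type, provider='dns',
                              confidence=confidence, raw_data=raw_data)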
- """ - if not self.graph.graph.has_node(large_entity_id): - print(f"ERROR: Large entity {large_entity_id} not found.") - return False - - # 1. Get the original source node that discovered the large entity - predecessors = list(self.graph.graph.predecessors(large_entity_id)) - if not predecessors: - print(f"ERROR: No source node found for large entity {large_entity_id}.") - return False - source_node_id = predecessors[0] - - # Get the original edge data to replicate it for the extracted node - original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id) - if not original_edge_data: - print(f"ERROR: Could not find original edge data from {source_node_id} to {large_entity_id}.") - return False - - # 2. Modify the graph data structure first - success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract) - if not success: - print(f"ERROR: Node {node_id_to_extract} could not be removed from {large_entity_id}'s attributes.") - return False - - # 3. Create the direct edge from the original source to the newly extracted node - print(f"Re-creating direct edge from {source_node_id} to extracted node {node_id_to_extract}") - self.graph.add_edge( - source_id=source_node_id, - target_id=node_id_to_extract, - relationship_type=original_edge_data.get('relationship_type', 'extracted_from_large_entity'), - confidence_score=original_edge_data.get('confidence_score', 0.85), # Slightly lower confidence - source_provider=original_edge_data.get('source_provider', 'unknown'), - raw_data={'context': f'Extracted from large entity {large_entity_id}'} - ) - - # 4. Re-queue the extracted node for full processing by all eligible providers - print(f"Re-queueing extracted node {node_id_to_extract} for full reconnaissance...") - is_ip = _is_valid_ip(node_id_to_extract) - current_depth = self.graph.graph.nodes[large_entity_id].get('attributes', {}).get('discovery_depth', 0) - - eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False) - for provider in eligible_providers: - provider_name = provider.get_name() - self.task_queue.put((self._get_priority(provider_name), (provider_name, node_id_to_extract, current_depth))) - self.total_tasks_ever_enqueued += 1 - - # 5. If the scanner is not running, we need to kickstart it to process this one item. - if self.status != ScanStatus.RUNNING: - print("Scanner is idle. 
Starting a mini-scan to process the extracted node.") - self.status = ScanStatus.RUNNING - self._update_session_state() - - if not self.scan_thread or not self.scan_thread.is_alive(): - self.scan_thread = threading.Thread( - target=self._execute_scan, - args=(self.current_target, self.max_depth), - daemon=True - ) - self.scan_thread.start() - - print(f"Successfully extracted and re-queued {node_id_to_extract} from {large_entity_id}.") - return True - - def _collect_node_attributes(self, node_id: str, provider_name: str, rel_type: str, - target: str, raw_data: Dict[str, Any], attributes: Dict[str, Any]) -> None: - """Collect and organize attributes for a node.""" - self.logger.logger.debug(f"Collecting attributes for {node_id} from {provider_name}: {rel_type}") - - if provider_name == 'dns': - record_type = raw_data.get('query_type', 'UNKNOWN') - value = raw_data.get('value', target) - dns_entry = f"{record_type}: {value}" - if dns_entry not in attributes.get('dns_records', []): - attributes.setdefault('dns_records', []).append(dns_entry) - - elif provider_name == 'crtsh': - if rel_type == "san_certificate": - domain_certs = raw_data.get('domain_certificates', {}) - if node_id in domain_certs: - cert_summary = domain_certs[node_id] - attributes['certificates'] = cert_summary - if target not in attributes.get('related_domains_san', []): - attributes.setdefault('related_domains_san', []).append(target) - - elif provider_name == 'shodan': - # This logic will now apply to the correct node (ASN or IP) - shodan_attributes = attributes.setdefault('shodan', {}) - for key, value in raw_data.items(): - if key not in shodan_attributes or not shodan_attributes.get(key): - shodan_attributes[key] = value - - if _is_valid_ip(node_id): - if 'ports' in raw_data: - attributes['ports'] = raw_data['ports'] - if 'os' in raw_data and raw_data['os']: - attributes['os'] = raw_data['os'] - - if rel_type == "asn_membership": - # This is the key change: these attributes are for the target (the ASN), - # not the source (the IP). We will add them to the ASN node later. - pass - - record_type_name = rel_type - if record_type_name not in attributes: - attributes[record_type_name] = [] - - if isinstance(target, list): - attributes[record_type_name].extend(target) - else: - if target not in attributes[record_type_name]: - attributes[record_type_name].append(target) - def _log_target_processing_error(self, target: str, error: str) -> None: """Log target processing errors for forensic trail.""" self.logger.logger.error(f"Target processing failed for {target}: {error}") diff --git a/core/session_manager.py b/core/session_manager.py index 7631db9..4662c65 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -5,15 +5,11 @@ import time import uuid import redis import pickle -from typing import Dict, Optional, Any, List +from typing import Dict, Optional, Any from core.scanner import Scanner from config import config -# WARNING: Using pickle can be a security risk if the data source is not trusted. -# In this case, we are only serializing/deserializing our own trusted Scanner objects, -# which is generally safe. Do not unpickle data from untrusted sources. - class SessionManager: """ Manages multiple scanner instances for concurrent user sessions using Redis. 
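The work the removed _collect_node_attributes used to do now belongs in the providers themselves; as a sketch (the DNS provider's actual implementation is not part of this diff), the same record would be reported as:

    # Previously the scanner assembled strings like "A: 93.184.216.34" into
    # node_attributes['dns_records']; under the unified model the provider emits:
    result = ProviderResult()
    result.add_attribute(target_node='example.com', name='dns_records',
                         value='A: 93.184.216.34', attr_type='dns_record',
                         provider='dns', confidence=0.8,
                         metadata={'query_type': 'A'})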
diff --git a/providers/base_provider.py b/providers/base_provider.py index 7941fb6..9337658 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -4,16 +4,17 @@ import time import requests import threading from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, Tuple +from typing import Dict, Any, Optional from core.logger import get_forensic_logger from core.rate_limiter import GlobalRateLimiter +from core.provider_result import ProviderResult class BaseProvider(ABC): """ Abstract base class for all DNSRecon data providers. - Now supports session-specific configuration. + Now supports session-specific configuration and returns standardized ProviderResult objects. """ def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None): @@ -101,7 +102,7 @@ class BaseProvider(ABC): pass @abstractmethod - def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def query_domain(self, domain: str) -> ProviderResult: """ Query the provider for information about a domain. @@ -109,12 +110,12 @@ class BaseProvider(ABC): domain: Domain to investigate Returns: - List of tuples: (source_node, target_node, relationship_type, confidence, raw_data) + ProviderResult containing standardized attributes and relationships """ pass @abstractmethod - def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def query_ip(self, ip: str) -> ProviderResult: """ Query the provider for information about an IP address. @@ -122,7 +123,7 @@ class BaseProvider(ABC): ip: IP address to investigate Returns: - List of tuples: (source_node, target_node, relationship_type, confidence, raw_data) + ProviderResult containing standardized attributes and relationships """ pass diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index bfa2c51..4ccce97 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -2,21 +2,21 @@ import json import re -import os from pathlib import Path -from typing import List, Dict, Any, Tuple, Set +from typing import List, Dict, Any, Set from urllib.parse import quote from datetime import datetime, timezone import requests from .base_provider import BaseProvider +from core.provider_result import ProviderResult from utils.helpers import _is_valid_domain class CrtShProvider(BaseProvider): """ Provider for querying crt.sh certificate transparency database. - Now uses session-specific configuration and caching with accumulative behavior. + Now returns standardized ProviderResult objects with caching support. """ def __init__(self, name=None, session_config=None): @@ -33,6 +33,9 @@ class CrtShProvider(BaseProvider): # Initialize cache directory self.cache_dir = Path('cache') / 'crtsh' self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Compile regex for date filtering for efficiency + self.date_pattern = re.compile(r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}') def get_name(self) -> str: """Return the provider name.""" @@ -51,16 +54,11 @@ class CrtShProvider(BaseProvider): return {'domains': True, 'ips': False} def is_available(self) -> bool: - """ - Check if the provider is configured to be used. - This method is intentionally simple and does not perform a network request - to avoid blocking application startup. 
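Under the revised contract every provider returns a ProviderResult from both query methods; a bare-bones sketch of a conforming provider (ExampleProvider is hypothetical, and the other abstract members of BaseProvider are omitted here):

    class ExampleProvider(BaseProvider):
        def get_name(self) -> str:
            return 'example'

        def query_domain(self, domain: str) -> ProviderResult:
            result = ProviderResult()
            result.add_relationship(source_node=domain, target_node=f"ns1.{domain}",
                                    relationship_type='ns_record', provider=self.get_name(),
                                    confidence=0.8, raw_data={'query_type': 'NS'})
            return result

        def query_ip(self, ip: str) -> ProviderResult:
            return ProviderResult()   # this illustrative provider only handles domains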
- """ + """Check if the provider is configured to be used.""" return True def _get_cache_file_path(self, domain: str) -> Path: """Generate cache file path for a domain.""" - # Sanitize domain for filename safety safe_domain = domain.replace('.', '_').replace('/', '_').replace('\\', '_') return self.cache_dir / f"{safe_domain}.json" @@ -78,7 +76,7 @@ class CrtShProvider(BaseProvider): last_query_str = cache_data.get("last_upstream_query") if not last_query_str: - return "stale" # Invalid cache format + return "stale" last_query = datetime.fromisoformat(last_query_str.replace('Z', '+00:00')) hours_since_query = (datetime.now(timezone.utc) - last_query).total_seconds() / 3600 @@ -92,27 +90,175 @@ class CrtShProvider(BaseProvider): except (json.JSONDecodeError, ValueError, KeyError) as e: self.logger.logger.warning(f"Invalid cache file format for {cache_file_path}: {e}") return "stale" - - def _load_cached_certificates(self, cache_file_path: Path) -> List[Dict[str, Any]]: - """Load certificates from cache file.""" + + def query_domain(self, domain: str) -> ProviderResult: + """ + Query crt.sh for certificates containing the domain with caching support. + + Args: + domain: Domain to investigate + + Returns: + ProviderResult containing discovered relationships and attributes + """ + if not _is_valid_domain(domain): + return ProviderResult() + + if self._stop_event and self._stop_event.is_set(): + return ProviderResult() + + cache_file = self._get_cache_file_path(domain) + cache_status = self._get_cache_status(cache_file) + + processed_certificates = [] + result = ProviderResult() + + try: + if cache_status == "fresh": + result = self._load_from_cache(cache_file) + self.logger.logger.info(f"Using cached crt.sh data for {domain}") + + else: # "stale" or "not_found" + raw_certificates = self._query_crtsh_api(domain) + + if self._stop_event and self._stop_event.is_set(): + return ProviderResult() + + # Process raw data into the application's expected format + current_processed_certs = [self._extract_certificate_metadata(cert) for cert in raw_certificates] + + if cache_status == "stale": + # Load existing and append new processed certs + existing_result = self._load_from_cache(cache_file) + result = self._merge_results(existing_result, current_processed_certs, domain) + self.logger.logger.info(f"Refreshed and merged cache for {domain}") + else: # "not_found" + # Create new result from processed certs + result = self._process_certificates_to_result(domain, current_processed_certs) + self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)") + + # Save the result to cache + self._save_result_to_cache(cache_file, result, domain) + + except requests.exceptions.RequestException as e: + self.logger.logger.error(f"API query failed for {domain}: {e}") + if cache_status != "not_found": + result = self._load_from_cache(cache_file) + self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.") + else: + raise e # Re-raise if there's no cache to fall back on + + return result + + def query_ip(self, ip: str) -> ProviderResult: + """ + Query crt.sh for certificates containing the IP address. + Note: crt.sh doesn't typically index by IP, so this returns empty results. 
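The freshness decision in _get_cache_status reduces to a single timestamp comparison; a standalone sketch, where the 12-hour expiry and the fresh/stale cut-off are assumptions (the real threshold comes from provider configuration):

    from datetime import datetime, timezone

    cache_expiry_hours = 12   # assumed; the provider reads this from its configuration
    last_query = datetime.fromisoformat('2025-01-01T00:00:00+00:00')
    hours_since_query = (datetime.now(timezone.utc) - last_query).total_seconds() / 3600
    cache_status = 'fresh' if hours_since_query < cache_expiry_hours else 'stale'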
+ + Args: + ip: IP address to investigate + + Returns: + Empty ProviderResult (crt.sh doesn't support IP-based certificate queries effectively) + """ + return ProviderResult() + + def _load_from_cache(self, cache_file_path: Path) -> ProviderResult: + """Load processed crt.sh data from a cache file.""" try: with open(cache_file_path, 'r') as f: - cache_data = json.load(f) - return cache_data.get('certificates', []) + cache_content = json.load(f) + + result = ProviderResult() + + # Reconstruct relationships + for rel_data in cache_content.get("relationships", []): + result.add_relationship( + source_node=rel_data["source_node"], + target_node=rel_data["target_node"], + relationship_type=rel_data["relationship_type"], + provider=rel_data["provider"], + confidence=rel_data["confidence"], + raw_data=rel_data.get("raw_data", {}) + ) + + # Reconstruct attributes + for attr_data in cache_content.get("attributes", []): + result.add_attribute( + target_node=attr_data["target_node"], + name=attr_data["name"], + value=attr_data["value"], + attr_type=attr_data["type"], + provider=attr_data["provider"], + confidence=attr_data["confidence"], + metadata=attr_data.get("metadata", {}) + ) + + return result + except (json.JSONDecodeError, FileNotFoundError, KeyError) as e: self.logger.logger.error(f"Failed to load cached certificates from {cache_file_path}: {e}") - return [] - + return ProviderResult() + + def _save_result_to_cache(self, cache_file_path: Path, result: ProviderResult, domain: str) -> None: + """Save processed crt.sh result to a cache file.""" + try: + cache_data = { + "domain": domain, + "last_upstream_query": datetime.now(timezone.utc).isoformat(), + "relationships": [ + { + "source_node": rel.source_node, + "target_node": rel.target_node, + "relationship_type": rel.relationship_type, + "confidence": rel.confidence, + "provider": rel.provider, + "raw_data": rel.raw_data + } for rel in result.relationships + ], + "attributes": [ + { + "target_node": attr.target_node, + "name": attr.name, + "value": attr.value, + "type": attr.type, + "provider": attr.provider, + "confidence": attr.confidence, + "metadata": attr.metadata + } for attr in result.attributes + ] + } + cache_file_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_file_path, 'w') as f: + json.dump(cache_data, f, separators=(',', ':'), default=str) + except Exception as e: + self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}") + + def _merge_results(self, existing_result: ProviderResult, new_certificates: List[Dict[str, Any]], domain: str) -> ProviderResult: + """Merge new certificate data with existing cached result.""" + # Create a fresh result from the new certificates + new_result = self._process_certificates_to_result(domain, new_certificates) + + # Simple merge strategy: combine all relationships and attributes + # In practice, you might want more sophisticated deduplication + merged_result = ProviderResult() + + # Add existing relationships and attributes + merged_result.relationships.extend(existing_result.relationships) + merged_result.attributes.extend(existing_result.attributes) + + # Add new relationships and attributes + merged_result.relationships.extend(new_result.relationships) + merged_result.attributes.extend(new_result.attributes) + + return merged_result + def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]: - """ - Query crt.sh API for raw certificate data. - Raises exceptions for network errors to allow core logic to retry. 
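The cache file written by _save_result_to_cache is plain JSON with roughly this shape, shown here as a Python literal; all values are illustrative.

    cache_data = {
        'domain': 'example.com',
        'last_upstream_query': '2025-01-01T00:00:00+00:00',
        'relationships': [
            {'source_node': 'example.com', 'target_node': 'www.example.com',
             'relationship_type': 'san_certificate', 'confidence': 0.9,
             'provider': 'crtsh', 'raw_data': {}},
        ],
        'attributes': [
            {'target_node': 'example.com', 'name': 'certificates',
             'value': {'total_certificates': 1}, 'type': 'certificate_data',
             'provider': 'crtsh', 'confidence': 0.9, 'metadata': {'total_certificates': 1}},
        ],
    }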
- """ + """Query crt.sh API for raw certificate data.""" url = f"{self.base_url}?q={quote(domain)}&output=json" response = self.make_request(url, target_indicator=domain) if not response or response.status_code != 200: - # This could be a temporary error - raise exception so core can retry raise requests.exceptions.RequestException(f"crt.sh API returned status {response.status_code if response else 'None'}") certificates = response.json() @@ -120,126 +266,126 @@ class CrtShProvider(BaseProvider): return [] return certificates - - def _parse_issuer_organization(self, issuer_dn: str) -> str: + + def _process_certificates_to_result(self, domain: str, certificates: List[Dict[str, Any]]) -> ProviderResult: """ - Parse the issuer Distinguished Name to extract just the organization name. + Process certificates to create ProviderResult with relationships and attributes. + """ + result = ProviderResult() - Args: - issuer_dn: Full issuer DN string (e.g., "C=US, O=Let's Encrypt, CN=R11") - - Returns: - Organization name (e.g., "Let's Encrypt") or original string if parsing fails - """ - if not issuer_dn: - return issuer_dn + if self._stop_event and self._stop_event.is_set(): + print(f"CrtSh processing cancelled before processing for domain: {domain}") + return result + + # Aggregate certificate data by domain + domain_certificates = {} + all_discovered_domains = set() - try: - # Split by comma and look for O= component - components = [comp.strip() for comp in issuer_dn.split(',')] + # Process certificates with cancellation checking + for i, cert_data in enumerate(certificates): + if i % 5 == 0 and self._stop_event and self._stop_event.is_set(): + print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}") + break + + cert_metadata = self._extract_certificate_metadata(cert_data) + cert_domains = self._extract_domains_from_certificate(cert_data) - for component in components: - if component.startswith('O='): - # Extract the value after O= - org_name = component[2:].strip() - # Remove quotes if present - if org_name.startswith('"') and org_name.endswith('"'): - org_name = org_name[1:-1] - return org_name + all_discovered_domains.update(cert_domains) + for cert_domain in cert_domains: + if not _is_valid_domain(cert_domain): + continue + + if cert_domain not in domain_certificates: + domain_certificates[cert_domain] = [] + + domain_certificates[cert_domain].append(cert_metadata) + + if self._stop_event and self._stop_event.is_set(): + print(f"CrtSh query cancelled before relationship creation for domain: {domain}") + return result + + # Create relationships from query domain to ALL discovered domains + for i, discovered_domain in enumerate(all_discovered_domains): + if discovered_domain == domain: + continue # Skip self-relationships - # If no O= component found, return the original string - return issuer_dn + if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): + print(f"CrtSh relationship creation cancelled for domain: {domain}") + break + + if not _is_valid_domain(discovered_domain): + continue - except Exception as e: - self.logger.logger.debug(f"Failed to parse issuer DN '{issuer_dn}': {e}") - return issuer_dn - - def _parse_certificate_date(self, date_string: str) -> datetime: - """ - Parse certificate date from crt.sh format. 
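The issuer parsing keeps only the O= component; applied to the DN format quoted in the removed docstring above, the effect is:

    issuer_dn = "C=US, O=Let's Encrypt, CN=R11"
    org_name = issuer_dn
    for component in (c.strip() for c in issuer_dn.split(',')):
        if component.startswith('O='):
            org_name = component[2:].strip().strip('"')
            break
    # org_name == "Let's Encrypt"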
+ # Get certificates for both domains + query_domain_certs = domain_certificates.get(domain, []) + discovered_domain_certs = domain_certificates.get(discovered_domain, []) + + # Find shared certificates + shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs) + + # Calculate confidence + confidence = self._calculate_domain_relationship_confidence( + domain, discovered_domain, shared_certificates, all_discovered_domains + ) + + # Create comprehensive raw data for the relationship + relationship_raw_data = { + 'relationship_type': 'certificate_discovery', + 'shared_certificates': shared_certificates, + 'total_shared_certs': len(shared_certificates), + 'discovery_context': self._determine_relationship_context(discovered_domain, domain), + 'domain_certificates': { + domain: self._summarize_certificates(query_domain_certs), + discovered_domain: self._summarize_certificates(discovered_domain_certs) + } + } + + # Add relationship + result.add_relationship( + source_node=domain, + target_node=discovered_domain, + relationship_type='san_certificate', + provider=self.name, + confidence=confidence, + raw_data=relationship_raw_data + ) + + # Log the relationship discovery + self.log_relationship_discovery( + source_node=domain, + target_node=discovered_domain, + relationship_type='san_certificate', + confidence_score=confidence, + raw_data=relationship_raw_data, + discovery_method="certificate_transparency_analysis" + ) - Args: - date_string: Date string from crt.sh API + # Add certificate summary as attributes for all domains that have certificates + for cert_domain, cert_list in domain_certificates.items(): + if cert_list: + cert_summary = self._summarize_certificates(cert_list) + + result.add_attribute( + target_node=cert_domain, + name='certificates', + value=cert_summary, + attr_type='certificate_data', + provider=self.name, + confidence=0.9, + metadata={'total_certificates': len(cert_list)} + ) - Returns: - Parsed datetime object in UTC - """ - if not date_string: - raise ValueError("Empty date string") - - try: - # Handle various possible formats from crt.sh - if date_string.endswith('Z'): - return datetime.fromisoformat(date_string[:-1]).replace(tzinfo=timezone.utc) - elif '+' in date_string or date_string.endswith('UTC'): - # Handle timezone-aware strings - date_string = date_string.replace('UTC', '').strip() - if '+' in date_string: - date_string = date_string.split('+')[0] - return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc) - else: - # Assume UTC if no timezone specified - return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc) - except Exception as e: - # Fallback: try parsing without timezone info and assume UTC - try: - return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc) - except Exception: - raise ValueError(f"Unable to parse date: {date_string}") from e - - def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool: - """ - Check if a certificate is currently valid based on its expiry date. 
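The validity check described here comes down to two comparisons against the current time, alongside the expires_soon flag set in _extract_certificate_metadata; in isolation (the certificate window is illustrative):

    from datetime import datetime, timezone, timedelta

    now = datetime.now(timezone.utc)
    not_before = now - timedelta(days=30)    # illustrative certificate window
    not_after = now + timedelta(days=20)

    is_currently_valid = not_before <= now and not_after > now
    expires_soon = (not_after - now).days <= 30   # mirrors the metadata flag
    # -> is_currently_valid True, expires_soon True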
- - Args: - cert_data: Certificate data from crt.sh - - Returns: - True if certificate is currently valid (not expired) - """ - try: - not_after_str = cert_data.get('not_after') - if not not_after_str: - return False - - not_after_date = self._parse_certificate_date(not_after_str) - not_before_str = cert_data.get('not_before') - - now = datetime.now(timezone.utc) - - # Check if certificate is within valid date range - is_not_expired = not_after_date > now - - if not_before_str: - not_before_date = self._parse_certificate_date(not_before_str) - is_not_before_valid = not_before_date <= now - return is_not_expired and is_not_before_valid - - return is_not_expired - - except Exception as e: - self.logger.logger.debug(f"Certificate validity check failed: {e}") - return False + return result def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]: - """ - Extract comprehensive metadata from certificate data. - - Args: - cert_data: Raw certificate data from crt.sh - - Returns: - Comprehensive certificate metadata dictionary - """ - # Parse the issuer name to get just the organization + """Extract comprehensive metadata from certificate data.""" raw_issuer_name = cert_data.get('issuer_name', '') parsed_issuer_name = self._parse_issuer_organization(raw_issuer_name) metadata = { 'certificate_id': cert_data.get('id'), 'serial_number': cert_data.get('serial_number'), - 'issuer_name': parsed_issuer_name, # Use parsed organization name - #'issuer_name_full': raw_issuer_name, # deliberately left out, because its not useful in most cases + 'issuer_name': parsed_issuer_name, 'issuer_ca_id': cert_data.get('issuer_ca_id'), 'common_name': cert_data.get('common_name'), 'not_before': cert_data.get('not_before'), @@ -257,7 +403,6 @@ class CrtShProvider(BaseProvider): metadata['is_currently_valid'] = self._is_cert_valid(cert_data) metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30 - # Add human-readable dates metadata['not_before'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC') metadata['not_after'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC') @@ -268,220 +413,134 @@ class CrtShProvider(BaseProvider): return metadata - def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: - """ - Query crt.sh for certificates containing the domain with caching support. - Properly raises exceptions for network errors to allow core logic retries. 
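Put together, one entry in a domain's certificate list ends up looking roughly like this; the keys follow _extract_certificate_metadata, the values are illustrative.

    cert_metadata = {
        'certificate_id': 1234567890,
        'serial_number': '04a1b2c3...',
        'issuer_name': "Let's Encrypt",
        'issuer_ca_id': 295810,
        'common_name': 'example.com',
        'not_before': '2025-01-01 00:00:00 UTC',
        'not_after': '2025-03-31 23:59:59 UTC',
        'is_currently_valid': True,
        'expires_soon': False,
    }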
- """ - if not _is_valid_domain(domain): - return [] + def _parse_issuer_organization(self, issuer_dn: str) -> str: + """Parse the issuer Distinguished Name to extract just the organization name.""" + if not issuer_dn: + return issuer_dn - if self._stop_event and self._stop_event.is_set(): - return [] + try: + components = [comp.strip() for comp in issuer_dn.split(',')] + + for component in components: + if component.startswith('O='): + org_name = component[2:].strip() + if org_name.startswith('"') and org_name.endswith('"'): + org_name = org_name[1:-1] + return org_name + + return issuer_dn + + except Exception as e: + self.logger.logger.debug(f"Failed to parse issuer DN '{issuer_dn}': {e}") + return issuer_dn - cache_file = self._get_cache_file_path(domain) - cache_status = self._get_cache_status(cache_file) - - processed_certificates = [] + def _parse_certificate_date(self, date_string: str) -> datetime: + """Parse certificate date from crt.sh format.""" + if not date_string: + raise ValueError("Empty date string") try: - if cache_status == "fresh": - processed_certificates = self._load_cached_certificates(cache_file) - self.logger.logger.info(f"Using cached processed data for {domain} ({len(processed_certificates)} certificates)") - - else: # "stale" or "not_found" - raw_certificates = self._query_crtsh_api(domain) - - if self._stop_event and self._stop_event.is_set(): - return [] - - # Process raw data into the application's expected format - current_processed_certs = [self._extract_certificate_metadata(cert) for cert in raw_certificates] - - if cache_status == "stale": - # Append new processed certs to existing ones - processed_certificates = self._append_to_cache(cache_file, current_processed_certs) - self.logger.logger.info(f"Refreshed and appended cache for {domain}") - else: # "not_found" - # Create a new cache file with the processed certs, even if empty - self._create_cache_file(cache_file, domain, current_processed_certs) - processed_certificates = current_processed_certs - self.logger.logger.info(f"Cached fresh data for {domain} ({len(processed_certificates)} certificates)") - - - except requests.exceptions.RequestException as e: - self.logger.logger.error(f"API query failed for {domain}: {e}") - if cache_status != "not_found": - processed_certificates = self._load_cached_certificates(cache_file) - self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.") + if date_string.endswith('Z'): + return datetime.fromisoformat(date_string[:-1]).replace(tzinfo=timezone.utc) + elif '+' in date_string or date_string.endswith('UTC'): + date_string = date_string.replace('UTC', '').strip() + if '+' in date_string: + date_string = date_string.split('+')[0] + return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc) else: - raise e # Re-raise if there's no cache to fall back on - - if not processed_certificates: - return [] - - return self._process_certificates_to_relationships(domain, processed_certificates) - - def _create_cache_file(self, cache_file_path: Path, domain: str, processed_certificates: List[Dict[str, Any]]) -> None: - """Create new cache file with processed certificates.""" - try: - cache_data = { - "domain": domain, - "last_upstream_query": datetime.now(timezone.utc).isoformat(), - "certificates": processed_certificates # Store processed data - } - cache_file_path.parent.mkdir(parents=True, exist_ok=True) - with open(cache_file_path, 'w') as f: - json.dump(cache_data, f, separators=(',', ':')) + return 
datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc) except Exception as e: - self.logger.logger.warning(f"Failed to create cache file for {domain}: {e}") + try: + return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc) + except Exception: + raise ValueError(f"Unable to parse date: {date_string}") from e - def _append_to_cache(self, cache_file_path: Path, new_processed_certificates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Append new processed certificates to existing cache and return all certificates.""" + def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool: + """Check if a certificate is currently valid based on its expiry date.""" try: - with open(cache_file_path, 'r') as f: - cache_data = json.load(f) - - existing_ids = {cert.get('certificate_id') for cert in cache_data.get('certificates', [])} - - for cert in new_processed_certificates: - if cert.get('certificate_id') not in existing_ids: - cache_data['certificates'].append(cert) - - cache_data['last_upstream_query'] = datetime.now(timezone.utc).isoformat() - - with open(cache_file_path, 'w') as f: - json.dump(cache_data, f, separators=(',', ':')) - - return cache_data['certificates'] + not_after_str = cert_data.get('not_after') + if not not_after_str: + return False + + not_after_date = self._parse_certificate_date(not_after_str) + not_before_str = cert_data.get('not_before') + + now = datetime.now(timezone.utc) + is_not_expired = not_after_date > now + + if not_before_str: + not_before_date = self._parse_certificate_date(not_before_str) + is_not_before_valid = not_before_date <= now + return is_not_expired and is_not_before_valid + + return is_not_expired + except Exception as e: - self.logger.logger.warning(f"Failed to append to cache: {e}") - return new_processed_certificates - - def _process_certificates_to_relationships(self, domain: str, certificates: List[Dict[str, Any]]) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: - """ - Process certificates to relationships using existing logic. - This method contains the original processing logic from query_domain. 
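_merge_results concatenates old and new entries and its comment notes that deduplication is left simple; the removed _append_to_cache deduplicated on certificate_id, and an equivalent pass over merged relationships could look like this (a sketch, not part of the diff):

    seen = set()
    deduped = []
    for rel in merged_result.relationships:
        key = (rel.source_node, rel.target_node, rel.relationship_type, rel.provider)
        if key not in seen:
            seen.add(key)
            deduped.append(rel)
    merged_result.relationships = deduped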
- """ - relationships = [] + self.logger.logger.debug(f"Certificate validity check failed: {e}") + return False + + def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]: + """Extract all domains from certificate data.""" + domains = set() - # Check for cancellation before processing - if self._stop_event and self._stop_event.is_set(): - print(f"CrtSh processing cancelled before processing for domain: {domain}") + # Extract from common name + common_name = cert_data.get('common_name', '') + if common_name: + cleaned_cn = self._clean_domain_name(common_name) + if cleaned_cn: + domains.update(cleaned_cn) + + # Extract from name_value field (contains SANs) + name_value = cert_data.get('name_value', '') + if name_value: + for line in name_value.split('\n'): + cleaned_domains = self._clean_domain_name(line.strip()) + if cleaned_domains: + domains.update(cleaned_domains) + + return domains + + def _clean_domain_name(self, domain_name: str) -> List[str]: + """Clean and normalize domain name from certificate data.""" + if not domain_name: return [] - # Aggregate certificate data by domain - domain_certificates = {} - all_discovered_domains = set() - - # Process certificates with cancellation checking - for i, cert_data in enumerate(certificates): - # Check for cancellation every 5 certificates for faster response - if i % 5 == 0 and self._stop_event and self._stop_event.is_set(): - print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}") - break - - cert_metadata = self._extract_certificate_metadata(cert_data) - cert_domains = self._extract_domains_from_certificate(cert_data) - - # Add all domains from this certificate to our tracking - all_discovered_domains.update(cert_domains) - for cert_domain in cert_domains: - if not _is_valid_domain(cert_domain): - continue - - # Initialize domain certificate list if needed - if cert_domain not in domain_certificates: - domain_certificates[cert_domain] = [] - - # Add this certificate to the domain's certificate list - domain_certificates[cert_domain].append(cert_metadata) - - # Final cancellation check before creating relationships - if self._stop_event and self._stop_event.is_set(): - print(f"CrtSh query cancelled before relationship creation for domain: {domain}") - return [] + domain = domain_name.strip().lower() - # Create relationships from query domain to ALL discovered domains with stop checking - for i, discovered_domain in enumerate(all_discovered_domains): - if discovered_domain == domain: - continue # Skip self-relationships - - # Check for cancellation every 10 relationships - if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): - print(f"CrtSh relationship creation cancelled for domain: {domain}") - break + if domain.startswith(('http://', 'https://')): + domain = domain.split('://', 1)[1] - if not _is_valid_domain(discovered_domain): - continue - - # Get certificates for both domains - query_domain_certs = domain_certificates.get(domain, []) - discovered_domain_certs = domain_certificates.get(discovered_domain, []) - - # Find shared certificates (for metadata purposes) - shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs) - - # Calculate confidence based on relationship type and shared certificates - confidence = self._calculate_domain_relationship_confidence( - domain, discovered_domain, shared_certificates, all_discovered_domains - ) - - # Create comprehensive raw data for the relationship - relationship_raw_data = { - 
'relationship_type': 'certificate_discovery', - 'shared_certificates': shared_certificates, - 'total_shared_certs': len(shared_certificates), - 'discovery_context': self._determine_relationship_context(discovered_domain, domain), - 'domain_certificates': { - domain: self._summarize_certificates(query_domain_certs), - discovered_domain: self._summarize_certificates(discovered_domain_certs) - } - } - - # Create domain -> domain relationship - relationships.append(( - domain, - discovered_domain, - 'san_certificate', - confidence, - relationship_raw_data - )) - - # Log the relationship discovery - self.log_relationship_discovery( - source_node=domain, - target_node=discovered_domain, - relationship_type='san_certificate', - confidence_score=confidence, - raw_data=relationship_raw_data, - discovery_method="certificate_transparency_analysis" - ) + if '/' in domain: + domain = domain.split('/', 1)[0] - return relationships + if ':' in domain and not domain.count(':') > 1: + domain = domain.split(':', 1)[0] + + cleaned_domains = [] + if domain.startswith('*.'): + cleaned_domains.append(domain) + cleaned_domains.append(domain[2:]) + else: + cleaned_domains.append(domain) + + final_domains = [] + for d in cleaned_domains: + d = re.sub(r'[^\w\-\.]', '', d) + if d and not d.startswith(('.', '-')) and not d.endswith(('.', '-')): + final_domains.append(d) + + return [d for d in final_domains if _is_valid_domain(d)] def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Find certificates that are shared between two domain certificate lists. - - Args: - certs1: First domain's certificates - certs2: Second domain's certificates - - Returns: - List of shared certificate metadata - """ + """Find certificates that are shared between two domain certificate lists.""" shared = [] - # Create a set of certificate IDs from the first list for quick lookup cert1_ids = set() for cert in certs1: cert_id = cert.get('certificate_id') - # Ensure the ID is not None and is a hashable type before adding to the set if cert_id and isinstance(cert_id, (int, str, float, bool, tuple)): cert1_ids.add(cert_id) - # Find certificates in the second list that match for cert in certs2: cert_id = cert.get('certificate_id') if cert_id and isinstance(cert_id, (int, str, float, bool, tuple)): @@ -491,15 +550,7 @@ class CrtShProvider(BaseProvider): return shared def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]: - """ - Create a summary of certificates for a domain. 
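
Shared-certificate detection boils down to an intersection on certificate_id. A small sketch of that matching, under the assumption that each certificate dict carries a hashable certificate_id (the sample data is made up):

from typing import Any, Dict, List

def find_shared_certificates(certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return entries from certs2 whose certificate_id also appears in certs1."""
    hashable = (int, str, float, bool, tuple)
    cert1_ids = {c.get('certificate_id') for c in certs1 if isinstance(c.get('certificate_id'), hashable)}
    return [c for c in certs2
            if isinstance(c.get('certificate_id'), hashable) and c.get('certificate_id') in cert1_ids]

a = [{'certificate_id': 101}, {'certificate_id': 102}]
b = [{'certificate_id': 102}, {'certificate_id': 203}]
print(find_shared_certificates(a, b))  # [{'certificate_id': 102}]
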
- - Args: - certificates: List of certificate metadata - - Returns: - Summary dictionary with aggregate statistics - """ + """Create a summary of certificates for a domain.""" if not certificates: return { 'total_certificates': 0, @@ -509,14 +560,13 @@ class CrtShProvider(BaseProvider): 'unique_issuers': [], 'latest_certificate': None, 'has_valid_cert': False, - 'certificate_details': [] # Always include empty list + 'certificate_details': [] } valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid')) expired_count = len(certificates) - valid_count expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon')) - # Get unique issuers (using parsed organization names) unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name'))) # Find the most recent certificate @@ -548,63 +598,40 @@ class CrtShProvider(BaseProvider): 'unique_issuers': unique_issuers, 'latest_certificate': latest_cert, 'has_valid_cert': valid_count > 0, - 'certificate_details': sorted_certificates # Include full certificate details + 'certificate_details': sorted_certificates } def _get_certificate_sort_date(self, cert: Dict[str, Any]) -> datetime: - """ - Get a sortable date from certificate data for chronological ordering. - - Args: - cert: Certificate metadata dictionary - - Returns: - Datetime object for sorting (falls back to epoch if parsing fails) - """ + """Get a sortable date from certificate data for chronological ordering.""" try: - # Try not_before first (issue date) if cert.get('not_before'): return self._parse_certificate_date(cert['not_before']) - # Fall back to entry_timestamp if available if cert.get('entry_timestamp'): return self._parse_certificate_date(cert['entry_timestamp']) - # Last resort - return a very old date for certificates without dates return datetime(1970, 1, 1, tzinfo=timezone.utc) except Exception: - # If all parsing fails, return epoch return datetime(1970, 1, 1, tzinfo=timezone.utc) def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str, shared_certificates: List[Dict[str, Any]], all_discovered_domains: Set[str]) -> float: - """ - Calculate confidence score for domain relationship based on various factors. 
- - Args: - domain1: Source domain (query domain) - domain2: Target domain (discovered domain) - shared_certificates: List of shared certificate metadata - all_discovered_domains: All domains discovered in this query - - Returns: - Confidence score between 0.0 and 1.0 - """ + """Calculate confidence score for domain relationship based on various factors.""" base_confidence = 0.9 # Adjust confidence based on domain relationship context relationship_context = self._determine_relationship_context(domain2, domain1) if relationship_context == 'exact_match': - context_bonus = 0.0 # This shouldn't happen, but just in case + context_bonus = 0.0 elif relationship_context == 'subdomain': - context_bonus = 0.1 # High confidence for subdomains + context_bonus = 0.1 elif relationship_context == 'parent_domain': - context_bonus = 0.05 # Medium confidence for parent domains + context_bonus = 0.05 else: - context_bonus = 0.0 # Related domains get base confidence + context_bonus = 0.0 # Adjust confidence based on shared certificates if shared_certificates: @@ -616,18 +643,16 @@ class CrtShProvider(BaseProvider): else: shared_bonus = 0.02 - # Additional bonus for valid shared certificates valid_shared = sum(1 for cert in shared_certificates if cert.get('is_currently_valid')) if valid_shared > 0: validity_bonus = 0.05 else: validity_bonus = 0.0 else: - # Even without shared certificates, domains found in the same query have some relationship shared_bonus = 0.0 validity_bonus = 0.0 - # Adjust confidence based on certificate issuer reputation (if shared certificates exist) + # Adjust confidence based on certificate issuer reputation issuer_bonus = 0.0 if shared_certificates: for cert in shared_certificates: @@ -636,21 +661,11 @@ class CrtShProvider(BaseProvider): issuer_bonus = max(issuer_bonus, 0.03) break - # Calculate final confidence final_confidence = base_confidence + context_bonus + shared_bonus + validity_bonus + issuer_bonus - return max(0.1, min(1.0, final_confidence)) # Clamp between 0.1 and 1.0 + return max(0.1, min(1.0, final_confidence)) def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str: - """ - Determine the context of the relationship between certificate domain and query domain. - - Args: - cert_domain: Domain found in certificate - query_domain: Original query domain - - Returns: - String describing the relationship context - """ + """Determine the context of the relationship between certificate domain and query domain.""" if cert_domain == query_domain: return 'exact_match' elif cert_domain.endswith(f'.{query_domain}'): @@ -658,88 +673,4 @@ class CrtShProvider(BaseProvider): elif query_domain.endswith(f'.{cert_domain}'): return 'parent_domain' else: - return 'related_domain' - - def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: - """ - Query crt.sh for certificates containing the IP address. - Note: crt.sh doesn't typically index by IP, so this returns empty results. - - Args: - ip: IP address to investigate - - Returns: - Empty list (crt.sh doesn't support IP-based certificate queries effectively) - """ - # crt.sh doesn't effectively support IP-based certificate queries - return [] - - def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]: - """ - Extract all domains from certificate data. 
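
Taken together, the context check and the confidence scoring behave roughly as sketched below. Only the bonuses visible in this diff are modelled (subdomain 0.1, parent 0.05, a flat 0.02 shared-certificate bonus, 0.05 for a currently valid shared certificate, 0.03 for a recognised issuer), so treat this as an approximation of the scoring above rather than the exact provider output:

def relationship_context(cert_domain: str, query_domain: str) -> str:
    if cert_domain == query_domain:
        return 'exact_match'
    if cert_domain.endswith(f'.{query_domain}'):
        return 'subdomain'
    if query_domain.endswith(f'.{cert_domain}'):
        return 'parent_domain'
    return 'related_domain'

def approx_confidence(query_domain: str, discovered: str, shared_certs: int = 0,
                      any_valid_shared: bool = False, reputable_issuer: bool = False) -> float:
    base = 0.9
    context_bonus = {'subdomain': 0.1, 'parent_domain': 0.05}.get(
        relationship_context(discovered, query_domain), 0.0)
    shared_bonus = 0.02 if shared_certs else 0.0          # simplified tiering
    validity_bonus = 0.05 if (shared_certs and any_valid_shared) else 0.0
    issuer_bonus = 0.03 if (shared_certs and reputable_issuer) else 0.0
    return max(0.1, min(1.0, base + context_bonus + shared_bonus + validity_bonus + issuer_bonus))

print(relationship_context('api.example.com', 'example.com'))                         # subdomain
print(approx_confidence('example.com', 'api.example.com', 2, any_valid_shared=True))  # 1.0 (clamped from 1.07)
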
- - Args: - cert_data: Certificate data from crt.sh API - - Returns: - Set of unique domain names found in the certificate - """ - domains = set() - - # Extract from common name - common_name = cert_data.get('common_name', '') - if common_name: - cleaned_cn = self._clean_domain_name(common_name) - if cleaned_cn: - domains.update(cleaned_cn) - - # Extract from name_value field (contains SANs) - name_value = cert_data.get('name_value', '') - if name_value: - # Split by newlines and clean each domain - for line in name_value.split('\n'): - cleaned_domains = self._clean_domain_name(line.strip()) - if cleaned_domains: - domains.update(cleaned_domains) - - return domains - - def _clean_domain_name(self, domain_name: str) -> List[str]: - """ - Clean and normalize domain name from certificate data. - Now returns a list to handle wildcards correctly. - """ - if not domain_name: - return [] - - domain = domain_name.strip().lower() - - # Remove protocol if present - if domain.startswith(('http://', 'https://')): - domain = domain.split('://', 1)[1] - - # Remove path if present - if '/' in domain: - domain = domain.split('/', 1)[0] - - # Remove port if present - if ':' in domain and not domain.count(':') > 1: # Avoid breaking IPv6 - domain = domain.split(':', 1)[0] - - # Handle wildcard domains - cleaned_domains = [] - if domain.startswith('*.'): - # Add both the wildcard and the base domain - cleaned_domains.append(domain) - cleaned_domains.append(domain[2:]) - else: - cleaned_domains.append(domain) - - # Remove any remaining invalid characters and validate - final_domains = [] - for d in cleaned_domains: - d = re.sub(r'[^\w\-\.]', '', d) - if d and not d.startswith(('.', '-')) and not d.endswith(('.', '-')): - final_domains.append(d) - - return [d for d in final_domains if _is_valid_domain(d)] \ No newline at end of file + return 'related_domain' \ No newline at end of file diff --git a/providers/dns_provider.py b/providers/dns_provider.py index d73ef6c..5d972d1 100644 --- a/providers/dns_provider.py +++ b/providers/dns_provider.py @@ -1,15 +1,16 @@ # dnsrecon/providers/dns_provider.py from dns import resolver, reversename -from typing import List, Dict, Any, Tuple +from typing import Dict from .base_provider import BaseProvider +from core.provider_result import ProviderResult from utils.helpers import _is_valid_ip, _is_valid_domain class DNSProvider(BaseProvider): """ Provider for standard DNS resolution and reverse DNS lookups. - Now uses session-specific configuration. + Now returns standardized ProviderResult objects. """ def __init__(self, name=None, session_config=None): @@ -25,7 +26,6 @@ class DNSProvider(BaseProvider): self.resolver = resolver.Resolver() self.resolver.timeout = 5 self.resolver.lifetime = 10 - #self.resolver.nameservers = ['127.0.0.1'] def get_name(self) -> str: """Return the provider name.""" @@ -47,31 +47,35 @@ class DNSProvider(BaseProvider): """DNS is always available - no API key required.""" return True - def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def query_domain(self, domain: str) -> ProviderResult: """ - Query DNS records for the domain to discover relationships. - ... + Query DNS records for the domain to discover relationships and attributes. 
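
Every provider in this diff now returns the same container instead of a list of tuples. Since core/provider_result.py itself is not reproduced here, the following is a hypothetical minimal stand-in, consistent only with the add_relationship/add_attribute calls and field names that appear in this diff; the real class may carry additional fields or validation:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class Relationship:
    source_node: str
    target_node: str
    relationship_type: str
    confidence: float
    provider: str
    raw_data: Dict[str, Any] = field(default_factory=dict)

@dataclass
class StandardAttribute:
    target_node: str
    name: str
    value: Any
    type: str
    provider: str
    confidence: float
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class ProviderResult:
    relationships: List[Relationship] = field(default_factory=list)
    attributes: List[StandardAttribute] = field(default_factory=list)

    def add_relationship(self, source_node: str, target_node: str, relationship_type: str,
                         provider: str, confidence: float,
                         raw_data: Optional[Dict[str, Any]] = None) -> None:
        self.relationships.append(Relationship(source_node, target_node, relationship_type,
                                               confidence, provider, raw_data or {}))

    def add_attribute(self, target_node: str, name: str, value: Any, attr_type: str,
                      provider: str, confidence: float,
                      metadata: Optional[Dict[str, Any]] = None) -> None:
        self.attributes.append(StandardAttribute(target_node, name, value, attr_type,
                                                 provider=provider, confidence=confidence,
                                                 metadata=metadata or {}))

# Illustrative usage; node names and provider label are made up.
result = ProviderResult()
result.add_relationship('example.com', '203.0.113.10', 'a_record', 'dns', 0.8, {'ttl': 300})
result.add_attribute('example.com', 'dns_records', ['A: 203.0.113.10'], 'dns_record_list', 'dns', 0.8)
print(len(result.relationships), len(result.attributes))  # 1 1
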
+ + Args: + domain: Domain to investigate + + Returns: + ProviderResult containing discovered relationships and attributes """ if not _is_valid_domain(domain): - return [] + return ProviderResult() - relationships = [] + result = ProviderResult() # Query all record types for record_type in ['A', 'AAAA', 'CNAME', 'MX', 'NS', 'SOA', 'TXT', 'SRV', 'CAA']: try: - relationships.extend(self._query_record(domain, record_type)) + self._query_record(domain, record_type, result) except resolver.NoAnswer: # This is not an error, just a confirmation that the record doesn't exist. self.logger.logger.debug(f"No {record_type} record found for {domain}") except Exception as e: self.failed_requests += 1 self.logger.logger.debug(f"{record_type} record query failed for {domain}: {e}") - # Optionally, you might want to re-raise other, more serious exceptions. - return relationships + return result - def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def query_ip(self, ip: str) -> ProviderResult: """ Query reverse DNS for the IP address. @@ -79,12 +83,12 @@ class DNSProvider(BaseProvider): ip: IP address to investigate Returns: - List of relationships discovered from reverse DNS + ProviderResult containing discovered relationships and attributes """ if not _is_valid_ip(ip): - return [] + return ProviderResult() - relationships = [] + result = ProviderResult() try: # Perform reverse DNS lookup @@ -97,27 +101,44 @@ class DNSProvider(BaseProvider): hostname = str(ptr_record).rstrip('.') if _is_valid_domain(hostname): - raw_data = { - 'query_type': 'PTR', - 'ip_address': ip, - 'hostname': hostname, - 'ttl': response.ttl - } + # Add the relationship + result.add_relationship( + source_node=ip, + target_node=hostname, + relationship_type='ptr_record', + provider=self.name, + confidence=0.8, + raw_data={ + 'query_type': 'PTR', + 'ip_address': ip, + 'hostname': hostname, + 'ttl': response.ttl + } + ) - relationships.append(( - ip, - hostname, - 'ptr_record', - 0.8, - raw_data - )) + # Add PTR record as attribute to the IP + result.add_attribute( + target_node=ip, + name='ptr_record', + value=hostname, + attr_type='dns_record', + provider=self.name, + confidence=0.8, + metadata={'ttl': response.ttl} + ) + # Log the relationship discovery self.log_relationship_discovery( source_node=ip, target_node=hostname, relationship_type='ptr_record', confidence_score=0.8, - raw_data=raw_data, + raw_data={ + 'query_type': 'PTR', + 'ip_address': ip, + 'hostname': hostname, + 'ttl': response.ttl + }, discovery_method="reverse_dns_lookup" ) @@ -130,18 +151,24 @@ class DNSProvider(BaseProvider): # Re-raise the exception so the scanner can handle the failure raise e - return relationships + return result - def _query_record(self, domain: str, record_type: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def _query_record(self, domain: str, record_type: str, result: ProviderResult) -> None: """ - Query a specific type of DNS record for the domain. + Query a specific type of DNS record for the domain and add results to ProviderResult. + + Args: + domain: Domain to query + record_type: DNS record type (A, AAAA, CNAME, etc.) 
+ result: ProviderResult to populate """ - relationships = [] try: self.total_requests += 1 response = self.resolver.resolve(domain, record_type) self.successful_requests += 1 + dns_records = [] + for record in response: target = "" if record_type in ['A', 'AAAA']: @@ -153,12 +180,16 @@ class DNSProvider(BaseProvider): elif record_type == 'SOA': target = str(record.mname).rstrip('.') elif record_type in ['TXT']: - # TXT records are treated as metadata, not relationships. + # TXT records are treated as attributes, not relationships + txt_value = str(record).strip('"') + dns_records.append(f"TXT: {txt_value}") continue elif record_type == 'SRV': target = str(record.target).rstrip('.') elif record_type == 'CAA': - target = f"{record.flags} {record.tag.decode('utf-8')} \"{record.value.decode('utf-8')}\"" + caa_value = f"{record.flags} {record.tag.decode('utf-8')} \"{record.value.decode('utf-8')}\"" + dns_records.append(f"CAA: {caa_value}") + continue else: target = str(record) @@ -170,16 +201,22 @@ class DNSProvider(BaseProvider): 'ttl': response.ttl } relationship_type = f"{record_type.lower()}_record" - confidence = 0.8 # Default confidence for DNS records + confidence = 0.8 # Standard confidence for DNS records - relationships.append(( - domain, - target, - relationship_type, - confidence, - raw_data - )) + # Add relationship + result.add_relationship( + source_node=domain, + target_node=target, + relationship_type=relationship_type, + provider=self.name, + confidence=confidence, + raw_data=raw_data + ) + # Add DNS record as attribute to the source domain + dns_records.append(f"{record_type}: {target}") + + # Log relationship discovery self.log_relationship_discovery( source_node=domain, target_node=target, @@ -189,10 +226,20 @@ class DNSProvider(BaseProvider): discovery_method=f"dns_{record_type.lower()}_record" ) + # Add DNS records as a consolidated attribute + if dns_records: + result.add_attribute( + target_node=domain, + name='dns_records', + value=dns_records, + attr_type='dns_record_list', + provider=self.name, + confidence=0.8, + metadata={'record_types': [record_type]} + ) + except Exception as e: self.failed_requests += 1 self.logger.logger.debug(f"{record_type} record query failed for {domain}: {e}") # Re-raise the exception so the scanner can handle it - raise e - - return relationships \ No newline at end of file + raise e \ No newline at end of file diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index 30c48f5..21b530c 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -1,20 +1,20 @@ # dnsrecon/providers/shodan_provider.py import json -import os from pathlib import Path -from typing import List, Dict, Any, Tuple +from typing import Dict, Any from datetime import datetime, timezone import requests from .base_provider import BaseProvider +from core.provider_result import ProviderResult from utils.helpers import _is_valid_ip, _is_valid_domain class ShodanProvider(BaseProvider): """ Provider for querying Shodan API for IP address information. - Now uses session-specific API keys, is limited to IP-only queries, and includes caching. + Now returns standardized ProviderResult objects with caching support. 
""" def __init__(self, name=None, session_config=None): @@ -85,88 +85,156 @@ class ShodanProvider(BaseProvider): except (json.JSONDecodeError, ValueError, KeyError): return "stale" - def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def query_domain(self, domain: str) -> ProviderResult: """ Domain queries are no longer supported for the Shodan provider. + + Args: + domain: Domain to investigate + + Returns: + Empty ProviderResult """ - return [] + return ProviderResult() - def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def query_ip(self, ip: str) -> ProviderResult: """ - Query Shodan for information about an IP address, with caching of processed relationships. + Query Shodan for information about an IP address, with caching of processed data. + + Args: + ip: IP address to investigate + + Returns: + ProviderResult containing discovered relationships and attributes """ if not _is_valid_ip(ip) or not self.is_available(): - return [] + return ProviderResult() cache_file = self._get_cache_file_path(ip) cache_status = self._get_cache_status(cache_file) - relationships = [] + result = ProviderResult() try: if cache_status == "fresh": - relationships = self._load_from_cache(cache_file) - self.logger.logger.info(f"Using cached Shodan relationships for {ip}") - else: # "stale" or "not_found" + result = self._load_from_cache(cache_file) + self.logger.logger.info(f"Using cached Shodan data for {ip}") + else: # "stale" or "not_found" url = f"{self.base_url}/shodan/host/{ip}" params = {'key': self.api_key} response = self.make_request(url, method="GET", params=params, target_indicator=ip) if response and response.status_code == 200: data = response.json() - # Process the data into relationships BEFORE caching - relationships = self._process_shodan_data(ip, data) - self._save_to_cache(cache_file, relationships) # Save the processed relationships + # Process the data into ProviderResult BEFORE caching + result = self._process_shodan_data(ip, data) + self._save_to_cache(cache_file, result, data) # Save both result and raw data elif cache_status == "stale": # If API fails on a stale cache, use the old data - relationships = self._load_from_cache(cache_file) + result = self._load_from_cache(cache_file) except requests.exceptions.RequestException as e: self.logger.logger.error(f"Shodan API query failed for {ip}: {e}") if cache_status == "stale": - relationships = self._load_from_cache(cache_file) + result = self._load_from_cache(cache_file) - return relationships + return result - def _load_from_cache(self, cache_file_path: Path) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: - """Load processed Shodan relationships from a cache file.""" + def _load_from_cache(self, cache_file_path: Path) -> ProviderResult: + """Load processed Shodan data from a cache file.""" try: with open(cache_file_path, 'r') as f: cache_content = json.load(f) - # The entire file content is the list of relationships - return cache_content.get("relationships", []) + + result = ProviderResult() + + # Reconstruct relationships + for rel_data in cache_content.get("relationships", []): + result.add_relationship( + source_node=rel_data["source_node"], + target_node=rel_data["target_node"], + relationship_type=rel_data["relationship_type"], + provider=rel_data["provider"], + confidence=rel_data["confidence"], + raw_data=rel_data.get("raw_data", {}) + ) + + # Reconstruct attributes + for attr_data in cache_content.get("attributes", []): + result.add_attribute( + 
target_node=attr_data["target_node"], + name=attr_data["name"], + value=attr_data["value"], + attr_type=attr_data["type"], + provider=attr_data["provider"], + confidence=attr_data["confidence"], + metadata=attr_data.get("metadata", {}) + ) + + return result + except (json.JSONDecodeError, FileNotFoundError, KeyError): - return [] + return ProviderResult() - def _save_to_cache(self, cache_file_path: Path, relationships: List[Tuple[str, str, str, float, Dict[str, Any]]]) -> None: - """Save processed Shodan relationships to a cache file.""" + def _save_to_cache(self, cache_file_path: Path, result: ProviderResult, raw_data: Dict[str, Any]) -> None: + """Save processed Shodan data to a cache file.""" try: cache_data = { "last_upstream_query": datetime.now(timezone.utc).isoformat(), - "relationships": relationships + "raw_data": raw_data, # Preserve original for forensic purposes + "relationships": [ + { + "source_node": rel.source_node, + "target_node": rel.target_node, + "relationship_type": rel.relationship_type, + "confidence": rel.confidence, + "provider": rel.provider, + "raw_data": rel.raw_data + } for rel in result.relationships + ], + "attributes": [ + { + "target_node": attr.target_node, + "name": attr.name, + "value": attr.value, + "type": attr.type, + "provider": attr.provider, + "confidence": attr.confidence, + "metadata": attr.metadata + } for attr in result.attributes + ] } with open(cache_file_path, 'w') as f: - json.dump(cache_data, f, separators=(',', ':')) + json.dump(cache_data, f, separators=(',', ':'), default=str) except Exception as e: self.logger.logger.warning(f"Failed to save Shodan cache for {cache_file_path.name}: {e}") - def _process_shodan_data(self, ip: str, data: Dict[str, Any]) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: + def _process_shodan_data(self, ip: str, data: Dict[str, Any]) -> ProviderResult: """ - Process Shodan data to extract relationships. + Process Shodan data to extract relationships and attributes. 
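
The cache written by _save_to_cache is plain JSON, with relationships and attributes flattened into lists of dicts so that _load_from_cache can rebuild a result object later. A rough sketch of that layout, using the field names from the code above and made-up values:

import json
from datetime import datetime, timezone

cache_data = {
    "last_upstream_query": datetime.now(timezone.utc).isoformat(),
    "raw_data": {"ports": [22, 443], "org": "ExampleOrg"},  # original Shodan response, kept for forensics
    "relationships": [{
        "source_node": "203.0.113.10",
        "target_node": "host.example.com",
        "relationship_type": "a_record",
        "confidence": 0.8,
        "provider": "shodan",
        "raw_data": {},
    }],
    "attributes": [{
        "target_node": "203.0.113.10",
        "name": "ports",
        "value": [22, 443],
        "type": "network_info",
        "provider": "shodan",
        "confidence": 0.9,
        "metadata": {},
    }],
}
print(json.dumps(cache_data, separators=(',', ':'), default=str)[:80] + '...')
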
+ + Args: + ip: IP address queried + data: Raw Shodan response data + + Returns: + ProviderResult with relationships and attributes """ - relationships = [] + result = ProviderResult() # Extract hostname relationships hostnames = data.get('hostnames', []) for hostname in hostnames: if _is_valid_domain(hostname): - relationships.append(( - ip, - hostname, - 'a_record', - 0.8, - data - )) + result.add_relationship( + source_node=ip, + target_node=hostname, + relationship_type='a_record', + provider=self.name, + confidence=0.8, + raw_data=data + ) + self.log_relationship_discovery( source_node=ip, target_node=hostname, @@ -180,13 +248,15 @@ class ShodanProvider(BaseProvider): asn = data.get('asn') if asn: asn_name = f"AS{asn[2:]}" if isinstance(asn, str) and asn.startswith('AS') else f"AS{asn}" - relationships.append(( - ip, - asn_name, - 'asn_membership', - 0.7, - data - )) + result.add_relationship( + source_node=ip, + target_node=asn_name, + relationship_type='asn_membership', + provider=self.name, + confidence=0.7, + raw_data=data + ) + self.log_relationship_discovery( source_node=ip, target_node=asn_name, @@ -195,5 +265,67 @@ class ShodanProvider(BaseProvider): raw_data=data, discovery_method="shodan_asn_lookup" ) - - return relationships \ No newline at end of file + + # Add comprehensive Shodan host information as attributes + if 'ports' in data: + result.add_attribute( + target_node=ip, + name='ports', + value=data['ports'], + attr_type='network_info', + provider=self.name, + confidence=0.9 + ) + + if 'os' in data and data['os']: + result.add_attribute( + target_node=ip, + name='operating_system', + value=data['os'], + attr_type='system_info', + provider=self.name, + confidence=0.8 + ) + + if 'org' in data: + result.add_attribute( + target_node=ip, + name='organization', + value=data['org'], + attr_type='network_info', + provider=self.name, + confidence=0.8 + ) + + if 'country_name' in data: + result.add_attribute( + target_node=ip, + name='country', + value=data['country_name'], + attr_type='location_info', + provider=self.name, + confidence=0.9 + ) + + if 'city' in data: + result.add_attribute( + target_node=ip, + name='city', + value=data['city'], + attr_type='location_info', + provider=self.name, + confidence=0.8 + ) + + # Store complete Shodan data as a comprehensive attribute + result.add_attribute( + target_node=ip, + name='shodan_host_info', + value=data, # Complete Shodan response for full forensic detail + attr_type='comprehensive_data', + provider=self.name, + confidence=0.9, + metadata={'data_source': 'shodan_api', 'query_type': 'host_lookup'} + ) + + return result \ No newline at end of file diff --git a/static/js/graph.js b/static/js/graph.js index f1c703f..6f8d6f3 100644 --- a/static/js/graph.js +++ b/static/js/graph.js @@ -1,6 +1,7 @@ /** * Graph visualization module for DNSRecon * Handles network graph rendering using vis.js with proper large entity node hiding + * UPDATED: Now compatible with unified data model (StandardAttribute objects) */ const contextMenuCSS = ` .graph-context-menu { @@ -380,11 +381,15 @@ class GraphManager { const largeEntityMap = new Map(); graphData.nodes.forEach(node => { - if (node.type === 'large_entity' && node.attributes && Array.isArray(node.attributes.nodes)) { - node.attributes.nodes.forEach(nodeId => { - largeEntityMap.set(nodeId, node.id); - this.largeEntityMembers.add(nodeId); - }); + if (node.type === 'large_entity' && node.attributes) { + // UPDATED: Handle unified data model - look for 'nodes' attribute in the attributes list + 
const nodesAttribute = this.findAttributeByName(node.attributes, 'nodes'); + if (nodesAttribute && Array.isArray(nodesAttribute.value)) { + nodesAttribute.value.forEach(nodeId => { + largeEntityMap.set(nodeId, node.id); + this.largeEntityMembers.add(nodeId); + }); + } } }); @@ -466,8 +471,21 @@ class GraphManager { } /** - * Process node data with styling and metadata - * @param {Object} node - Raw node data + * UPDATED: Helper method to find an attribute by name in the standardized attributes list + * @param {Array} attributes - List of StandardAttribute objects + * @param {string} name - Attribute name to find + * @returns {Object|null} The attribute object if found, null otherwise + */ + findAttributeByName(attributes, name) { + if (!Array.isArray(attributes)) { + return null; + } + return attributes.find(attr => attr.name === name) || null; + } + + /** + * UPDATED: Process node data with styling and metadata for unified data model + * @param {Object} node - Raw node data with standardized attributes * @returns {Object} Processed node data */ processNode(node) { @@ -478,7 +496,7 @@ class GraphManager { size: this.getNodeSize(node.type), borderColor: this.getNodeBorderColor(node.type), shape: this.getNodeShape(node.type), - attributes: node.attributes || {}, + attributes: node.attributes || [], // Keep as standardized attributes list description: node.description || '', metadata: node.metadata || {}, type: node.type, @@ -491,9 +509,10 @@ class GraphManager { processedNode.borderWidth = Math.max(2, Math.floor(node.confidence * 5)); } - // Style based on certificate validity + // UPDATED: Style based on certificate validity using unified data model if (node.type === 'domain') { - if (node.attributes && node.attributes.certificates && node.attributes.certificates.has_valid_cert === false) { + const certificatesAttr = this.findAttributeByName(node.attributes, 'certificates'); + if (certificatesAttr && certificatesAttr.value && certificatesAttr.value.has_valid_cert === false) { processedNode.color = { background: '#888888', border: '#666666' }; } } diff --git a/static/js/main.js b/static/js/main.js index 7c7b920..5ce2477 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -1,6 +1,7 @@ /** * Main application logic for DNSRecon web interface * Handles UI interactions, API communication, and data flow + * UPDATED: Now compatible with unified data model (StandardAttribute objects) */ class DNSReconApp { @@ -808,10 +809,22 @@ class DNSReconApp { } /** - * Enhanced node details HTML generation with better visual hierarchy - * File: static/js/main.js (replace generateNodeDetailsHtml method) + * UPDATED: Helper method to find an attribute by name in the standardized attributes list + * @param {Array} attributes - List of StandardAttribute objects + * @param {string} name - Attribute name to find + * @returns {Object|null} The attribute object if found, null otherwise */ + findAttributeByName(attributes, name) { + if (!Array.isArray(attributes)) { + return null; + } + return attributes.find(attr => attr.name === name) || null; + } + /** + * UPDATED: Enhanced node details HTML generation for unified data model + * Now processes StandardAttribute objects instead of simple key-value pairs + */ generateNodeDetailsHtml(node) { if (!node) return '
Details not available.
'; @@ -857,23 +870,28 @@ class DNSReconApp { return detailsHtml; } + /** + * UPDATED: Generate details for standard nodes using unified data model + */ generateStandardNodeDetails(node) { let html = ''; // Relationships sections html += this.generateRelationshipsSection(node); - // Enhanced attributes section with special certificate handling - if (node.attributes && Object.keys(node.attributes).length > 0) { - const { certificates, ...otherAttributes } = node.attributes; + // UPDATED: Enhanced attributes section with special certificate handling for unified model + if (node.attributes && Array.isArray(node.attributes) && node.attributes.length > 0) { + // Find certificate attribute separately + const certificatesAttr = this.findAttributeByName(node.attributes, 'certificates'); // Handle certificates separately with enhanced display - if (certificates) { - html += this.generateCertificateSection({ certificates }); + if (certificatesAttr) { + html += this.generateCertificateSection(certificatesAttr); } - // Handle other attributes normally - if (Object.keys(otherAttributes).length > 0) { + // Handle other attributes normally (excluding certificates to avoid duplication) + const otherAttributes = node.attributes.filter(attr => attr.name !== 'certificates'); + if (otherAttributes.length > 0) { html += this.generateAttributesSection(otherAttributes); } } @@ -888,10 +906,10 @@ class DNSReconApp { } /** - * Enhanced certificate section generation using existing styles + * UPDATED: Enhanced certificate section generation for unified data model */ - generateCertificateSection(attributes) { - const certificates = attributes.certificates; + generateCertificateSection(certificatesAttr) { + const certificates = certificatesAttr.value; if (!certificates || typeof certificates !== 'object') { return ''; } @@ -1094,10 +1112,22 @@ class DNSReconApp { return html; } + /** + * UPDATED: Generate large entity details using unified data model + */ generateLargeEntityDetails(node) { - const attributes = node.attributes || {}; - const nodes = attributes.nodes || []; - const nodeType = attributes.node_type || 'nodes'; + // UPDATED: Look for attributes in the unified model structure + const nodesAttribute = this.findAttributeByName(node.attributes, 'nodes'); + const countAttribute = this.findAttributeByName(node.attributes, 'count'); + const nodeTypeAttribute = this.findAttributeByName(node.attributes, 'node_type'); + const sourceProviderAttribute = this.findAttributeByName(node.attributes, 'source_provider'); + const discoveryDepthAttribute = this.findAttributeByName(node.attributes, 'discovery_depth'); + + const nodes = nodesAttribute ? nodesAttribute.value : []; + const count = countAttribute ? countAttribute.value : 0; + const nodeType = nodeTypeAttribute ? nodeTypeAttribute.value : 'nodes'; + const sourceProvider = sourceProviderAttribute ? sourceProviderAttribute.value : 'Unknown'; + const discoveryDepth = discoveryDepthAttribute ? discoveryDepthAttribute.value : 'Unknown'; let html = ` @@ -1132,17 +1162,19 @@ class DNSReconApp { // Use node.id for the large_entity_id const largeEntityId = node.id; - nodes.forEach(innerNodeId => { - html += ` -
- ${innerNodeId} - -
- `; - }); + if (Array.isArray(nodes)) { + nodes.forEach(innerNodeId => { + html += ` +
+ ${innerNodeId} + +
+ `; + }); + } html += ''; @@ -1255,151 +1287,6 @@ class DNSReconApp { return valueSourceMap; } - generateCorrelationObjectLayout(node) { - const metadata = node.metadata || {}; - const values = metadata.values || []; - const mergeCount = metadata.merge_count || 1; - - let html = '
'; - - if (mergeCount > 1) { - html += ` -
-
-

🔗Merged Correlations

-
${mergeCount} values
-
-
- `; - - values.forEach((value, index) => { - const displayValue = typeof value === 'string' && value.length > 50 ? - value.substring(0, 47) + '...' : value; - - html += ` -
-
${displayValue}
- - -
- `; - }); - - html += '
'; - } else { - const singleValue = values.length > 0 ? values[0] : (metadata.value || 'Unknown'); - html += ` -
-
-

🔗Correlation Value

-
-
${singleValue}
-
- `; - } - - // Show correlated nodes - const correlatedNodes = metadata.correlated_nodes || []; - if (correlatedNodes.length > 0) { - html += ` -
-
-

🌐Correlated Nodes

-
${correlatedNodes.length}
-
-
- `; - - correlatedNodes.forEach(nodeId => { - html += ` - - `; - }); - - html += '
'; - } - - html += '
'; - return html; - } - - generateLargeEntityLayout(node) { - const attributes = node.attributes || {}; - const nodes = attributes.nodes || []; - const nodeType = attributes.node_type || 'nodes'; - - let html = ` -
-
-
-

📦Large Entity Container

-
${attributes.count} ${nodeType}s
-
-
-
- Source Provider: - ${attributes.source_provider || 'Unknown'} -
-
- Discovery Depth: - ${attributes.discovery_depth || 'Unknown'} -
-
-
- -
-
-

📋Contained ${nodeType}s

- -
-
- `; - - nodes.forEach((innerNodeId, index) => { - const innerNode = this.graphManager.nodes.get(innerNodeId); - html += ` -
-
- - ${innerNodeId} - -
-
- ${innerNode ? this.generateStandardNodeLayout(innerNode) : '
No details available
'} -
-
- `; - }); - - html += '
'; - return html; - } - - generateStandardNodeLayout(node) { - let html = '
'; - - // Relationships section - html += this.generateRelationshipsSection(node); - - // Attributes section with smart categorization - html += this.generateAttributesSection(node); - - // Description section - html += this.generateDescriptionSection(node); - - // Metadata section (collapsed by default) - html += this.generateMetadataSection(node); - - html += '
'; - return html; - } - generateRelationshipsSection(node) { let html = ''; @@ -1468,12 +1355,20 @@ class DNSReconApp { return html; } + /** + * UPDATED: Generate attributes section for unified data model + * Now processes StandardAttribute objects instead of key-value pairs + */ generateAttributesSection(attributes) { - const categorized = this.categorizeAttributes(attributes); + if (!Array.isArray(attributes) || attributes.length === 0) { + return ''; + } + + const categorized = this.categorizeStandardAttributes(attributes); let html = ''; Object.entries(categorized).forEach(([category, attrs]) => { - if (Object.keys(attrs).length === 0) return; + if (attrs.length === 0) return; html += ` '; }); @@ -1503,47 +1394,41 @@ class DNSReconApp { return html; } - formatCertificateData(certData) { - if (!certData || typeof certData !== 'object') { - return '

No certificate data available

'; - } + /** + * UPDATED: Categorize StandardAttribute objects by type and content + */ + categorizeStandardAttributes(attributes) { + const categories = { + 'DNS Records': [], + 'Network Info': [], + 'Provider Data': [], + 'Other': [] + }; - let html = '
'; + attributes.forEach(attr => { + const lowerName = attr.name.toLowerCase(); + const attrType = attr.type ? attr.type.toLowerCase() : ''; + + if (lowerName.includes('dns') || lowerName.includes('record') || attrType.includes('dns')) { + categories['DNS Records'].push(attr); + } else if (lowerName.includes('ip') || lowerName.includes('asn') || lowerName.includes('network') || attrType.includes('network')) { + categories['Network Info'].push(attr); + } else if (lowerName.includes('shodan') || lowerName.includes('crtsh') || lowerName.includes('provider') || attrType.includes('provider')) { + categories['Provider Data'].push(attr); + } else { + categories['Other'].push(attr); + } + }); - // Handle certificate summary - if (certData.total_certificates) { - html += ` -
-
- Total Certificates: ${certData.total_certificates} - - ${certData.has_valid_cert ? 'Valid' : 'Invalid'} - -
-
- `; - } - - // Handle unique issuers - if (certData.unique_issuers && Array.isArray(certData.unique_issuers)) { - html += ` -
-
- Issuers: -
-
- `; - certData.unique_issuers.forEach(issuer => { - html += `
${this.escapeHtml(String(issuer))}
`; - }); - html += '
'; - } - - html += '
'; - return html; + return categories; } - formatAttributeValue(value) { + /** + * UPDATED: Format StandardAttribute value for display + */ + formatStandardAttributeValue(attr) { + const value = attr.value; + if (value === null || value === undefined) { return 'None'; } @@ -1567,35 +1452,6 @@ class DNSReconApp { return this.escapeHtml(String(value)); } - - categorizeAttributes(attributes) { - const categories = { - 'DNS Records': {}, - 'Certificates': {}, - 'Network Info': {}, - 'Provider Data': {}, - 'Other': {} - }; - - for (const [key, value] of Object.entries(attributes)) { - const lowerKey = key.toLowerCase(); - - if (lowerKey.includes('dns') || lowerKey.includes('record') || key.endsWith('_record')) { - categories['DNS Records'][key] = value; - } else if (lowerKey.includes('cert') || lowerKey.includes('ssl') || lowerKey.includes('tls')) { - categories['Certificates'][key] = value; - } else if (lowerKey.includes('ip') || lowerKey.includes('asn') || lowerKey.includes('network')) { - categories['Network Info'][key] = value; - } else if (lowerKey.includes('shodan') || lowerKey.includes('crtsh') || lowerKey.includes('provider')) { - categories['Provider Data'][key] = value; - } else { - categories['Other'][key] = value; - } - } - - return categories; - } - formatObjectCompact(obj) { if (!obj || typeof obj !== 'object') return ''; @@ -1625,7 +1481,7 @@ class DNSReconApp { return `
-                📝Description
+                📄Description
${this.escapeHtml(node.description)} @@ -1827,7 +1683,7 @@ class DNSReconApp { getNodeTypeIcon(nodeType) { const icons = { 'domain': '🌐', - 'ip': '📍', + 'ip': '🔍', 'asn': '🏢', 'large_entity': '📦', 'correlation_object': '🔗'