""" Abstract base provider class for DNSRecon data sources. Defines the interface and common functionality for all providers. """ import time import requests import threading from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional, Tuple from datetime import datetime from core.logger import get_forensic_logger from core.graph_manager import NodeType, RelationshipType class RateLimiter: """Simple rate limiter for API calls.""" def __init__(self, requests_per_minute: int): """ Initialize rate limiter. Args: requests_per_minute: Maximum requests allowed per minute """ self.requests_per_minute = requests_per_minute self.min_interval = 60.0 / requests_per_minute self.last_request_time = 0 def wait_if_needed(self) -> None: """Wait if necessary to respect rate limits.""" current_time = time.time() time_since_last = current_time - self.last_request_time if time_since_last < self.min_interval: sleep_time = self.min_interval - time_since_last time.sleep(sleep_time) self.last_request_time = time.time() class BaseProvider(ABC): """ Abstract base class for all DNSRecon data providers. Provides common functionality and defines the provider interface. """ def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30): """ Initialize base provider. Args: name: Provider name for logging rate_limit: Requests per minute limit timeout: Request timeout in seconds """ self.name = name self.rate_limiter = RateLimiter(rate_limit) self.timeout = timeout self._local = threading.local() self.logger = get_forensic_logger() # Statistics self.total_requests = 0 self.successful_requests = 0 self.failed_requests = 0 self.total_relationships_found = 0 @property def session(self): if not hasattr(self._local, 'session'): self._local.session = requests.Session() self._local.session.headers.update({ 'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)' }) return self._local.session @abstractmethod def get_name(self) -> str: """Return the provider name.""" pass @abstractmethod def is_available(self) -> bool: """Check if the provider is available and properly configured.""" pass @abstractmethod def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ Query the provider for information about a domain. Args: domain: Domain to investigate Returns: List of tuples: (source_node, target_node, relationship_type, confidence, raw_data) """ pass @abstractmethod def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]: """ Query the provider for information about an IP address. Args: ip: IP address to investigate Returns: List of tuples: (source_node, target_node, relationship_type, confidence, raw_data) """ pass def make_request(self, url: str, method: str = "GET", params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, target_indicator: str = "") -> Optional[requests.Response]: """ Make a rate-limited HTTP request with forensic logging. Args: url: Request URL method: HTTP method params: Query parameters headers: Additional headers target_indicator: The indicator being investigated Returns: Response object or None if request failed """ # Apply rate limiting self.rate_limiter.wait_if_needed() start_time = time.time() response = None error = None try: self.total_requests += 1 # Prepare request request_headers = self.session.headers.copy() if headers: request_headers.update(headers) print(f"Making {method} request to: {url}") # Make request if method.upper() == "GET": response = self.session.get( url, params=params, headers=request_headers, timeout=self.timeout ) elif method.upper() == "POST": response = self.session.post( url, json=params, headers=request_headers, timeout=self.timeout ) else: raise ValueError(f"Unsupported HTTP method: {method}") print(f"Response status: {response.status_code}") response.raise_for_status() self.successful_requests += 1 except requests.exceptions.RequestException as e: error = str(e) self.failed_requests += 1 print(f"Request failed: {error}") except Exception as e: error = f"Unexpected error: {str(e)}" self.failed_requests += 1 print(f"Unexpected error: {error}") # Calculate duration and log request duration_ms = (time.time() - start_time) * 1000 self.logger.log_api_request( provider=self.name, url=url, method=method.upper(), status_code=response.status_code if response else None, response_size=len(response.content) if response else None, duration_ms=duration_ms, error=error, target_indicator=target_indicator ) return response if error is None else None def log_relationship_discovery(self, source_node: str, target_node: str, relationship_type: RelationshipType, confidence_score: float, raw_data: Dict[str, Any], discovery_method: str) -> None: """ Log discovery of a new relationship. Args: source_node: Source node identifier target_node: Target node identifier relationship_type: Type of relationship confidence_score: Confidence score raw_data: Raw data from provider discovery_method: Method used for discovery """ self.total_relationships_found += 1 self.logger.log_relationship_discovery( source_node=source_node, target_node=target_node, relationship_type=relationship_type.relationship_name, confidence_score=confidence_score, provider=self.name, raw_data=raw_data, discovery_method=discovery_method ) def get_statistics(self) -> Dict[str, Any]: """ Get provider statistics. Returns: Dictionary containing provider performance metrics """ return { 'name': self.name, 'total_requests': self.total_requests, 'successful_requests': self.successful_requests, 'failed_requests': self.failed_requests, 'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0, 'relationships_found': self.total_relationships_found, 'rate_limit': self.rate_limiter.requests_per_minute } def reset_statistics(self) -> None: """Reset provider statistics.""" self.total_requests = 0 self.successful_requests = 0 self.failed_requests = 0 self.total_relationships_found = 0 def _extract_domain_from_url(self, url: str) -> Optional[str]: """ Extract domain from URL. Args: url: URL string Returns: Domain name or None if extraction fails """ try: # Remove protocol if '://' in url: url = url.split('://', 1)[1] # Remove path if '/' in url: url = url.split('/', 1)[0] # Remove port if ':' in url: url = url.split(':', 1)[0] return url.lower() except Exception: return None def _is_valid_domain(self, domain: str) -> bool: """ Basic domain validation. Args: domain: Domain string to validate Returns: True if domain appears valid """ if not domain or len(domain) > 253: return False # Check for valid characters and structure parts = domain.split('.') if len(parts) < 2: return False for part in parts: if not part or len(part) > 63: return False if not part.replace('-', '').replace('_', '').isalnum(): return False return True def _is_valid_ip(self, ip: str) -> bool: """ Basic IP address validation. Args: ip: IP address string to validate Returns: True if IP appears valid """ try: parts = ip.split('.') if len(parts) != 4: return False for part in parts: num = int(part) if not 0 <= num <= 255: return False return True except (ValueError, AttributeError): return False