# dnsrecon/providers/base_provider.py
import time
import requests
import threading
import os
import json
import hashlib
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple

from core.logger import get_forensic_logger
from core.graph_manager import RelationshipType


class RateLimiter:
    """Simple rate limiter for API calls."""

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        self.last_request_time = 0.0

    def __getstate__(self):
        """RateLimiter is fully picklable; return full state."""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore RateLimiter state."""
        self.__dict__.update(state)

    def wait_if_needed(self) -> None:
        """Wait if necessary to respect rate limits."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.min_interval:
            sleep_time = self.min_interval - time_since_last
            time.sleep(sleep_time)

        self.last_request_time = time.time()
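
# A minimal usage sketch for RateLimiter (illustrative only; in practice the
# limiter is constructed by BaseProvider below). With requests_per_minute=120,
# min_interval is 0.5s, so back-to-back calls are spaced at least 0.5s apart:
#
#   limiter = RateLimiter(requests_per_minute=120)
#   for url in urls_to_fetch:          # 'urls_to_fetch' is hypothetical
#       limiter.wait_if_needed()       # sleeps if the last call was < 0.5s ago
#       fetch(url)                     # 'fetch' is a placeholder for any I/O call
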
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.
    Supports session-specific configuration.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit (default override)
            timeout: Request timeout in seconds
            session_config: Session-specific configuration
        """
        # Use the session config if provided, otherwise fall back to the global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None

        # Caching configuration (per session)
        self.cache_dir = f'.cache/{id(self.config)}'  # Unique cache per session config
        self.cache_expiry = 12 * 3600  # 12 hours in seconds
        os.makedirs(self.cache_dir, exist_ok=True)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

        print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")

    def __getstate__(self):
        state = self.__dict__.copy()
        # Exclude the unpicklable attributes: thread-local state and the stop
        # event (threading.Event holds a lock and cannot be pickled).
        state.pop('_local', None)
        state.pop('_stop_event', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Re-initialize the unpicklable attributes
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with caching, forensic logging, and
        retry logic. Supports cancellation via the stop_event set by the scanner.
        """
        # Check for cancellation before starting
        if self._stop_event and self._stop_event.is_set():
            print(f"Request cancelled before start: {url}")
            return None

        # Create a stable cache key. hashlib is used instead of the built-in
        # hash(), which is salted per process and would silently invalidate
        # the on-disk cache on every restart.
        request_fingerprint = hashlib.sha256(
            f"{method}:{url}:{json.dumps(params, sort_keys=True)}".encode('utf-8')
        ).hexdigest()
        cache_key = f"{self.name}_{request_fingerprint}.json"
        cache_path = os.path.join(self.cache_dir, cache_key)

        # Serve from cache if a fresh entry exists
        if os.path.exists(cache_path):
            cache_age = time.time() - os.path.getmtime(cache_path)
            if cache_age < self.cache_expiry:
                print(f"Returning cached response for: {url}")
                with open(cache_path, 'r') as f:
                    cached_data = json.load(f)
                response = requests.Response()
                response.status_code = cached_data['status_code']
                response._content = cached_data['content'].encode('utf-8')
                response.headers = requests.structures.CaseInsensitiveDict(cached_data['headers'])
                return response

        # Initialize before the retry loop so the final logging below cannot
        # reference unbound variables if the loop exits on its first pass
        start_time = time.time()
        response = None
        error = None

        for attempt in range(max_retries + 1):
            # Check for cancellation before each attempt
            if self._stop_event and self._stop_event.is_set():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None

            # Apply rate limiting
            self.rate_limiter.wait_if_needed()

            # Check again after rate limiting
            if self._stop_event and self._stop_event.is_set():
                print(f"Request cancelled after rate limiting: {url}")
                return None

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                # Prepare request headers
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # Use a shorter timeout if termination has been requested
                request_timeout = self.timeout
                if self._stop_event and self._stop_event.is_set():
                    request_timeout = min(5, self.timeout)  # Max 5 seconds if termination requested

                # Make the request
                if method.upper() == "GET":
                    response = self.session.get(
                        url,
                        params=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url,
                        json=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success: log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )

                # Cache the successful response to disk
                with open(cache_path, 'w') as f:
                    json.dump({
                        'status_code': response.status_code,
                        'content': response.text,
                        'headers': dict(response.headers)
                    }, f)

                return response

            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")

                # Check for cancellation before retrying
                if self._stop_event and self._stop_event.is_set():
                    print(f"Request cancelled, not retrying: {url}")
                    break

                # Decide whether to retry
                if attempt < max_retries and self._should_retry(e):
                    backoff_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s
                    print(f"Retrying in {backoff_time} seconds...")

                    # Shorter backoff if termination has been requested
                    if self._stop_event and self._stop_event.is_set():
                        backoff_time = min(0.5, backoff_time)

                    # Sleep in small increments so cancellation stays responsive
                    sleep_start = time.time()
                    while time.time() - sleep_start < backoff_time:
                        if self._stop_event and self._stop_event.is_set():
                            print(f"Request cancelled during backoff: {url}")
                            return None
                        time.sleep(0.1)  # Check every 100ms
                    continue
                else:
                    break

            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                break

        # All attempts failed: log and return None. Compare against None
        # explicitly, because requests.Response is falsy for 4xx/5xx statuses.
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        return None
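
    # A minimal sketch of how a concrete provider might call make_request()
    # (illustrative only; the endpoint and JSON shape are assumptions, not a
    # real API):
    #
    #   response = self.make_request(
    #       f"https://api.example.com/v1/domain/{domain}",  # hypothetical endpoint
    #       params={"fields": "dns"},
    #       target_indicator=domain,
    #   )
    #   if response is not None:            # None means failure or cancellation
    #       data = response.json()          # cached responses behave the same way
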
    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event that signals cancellation
        """
        self._stop_event = stop_event

    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine whether a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True if the request should be retried
        """
        # Retry on connection errors, timeouts, and 5xx server errors
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True

        if isinstance(exception, requests.exceptions.HTTPError):
            # Compare against None explicitly: requests.Response is falsy for
            # 4xx/5xx status codes, so a plain truthiness check would wrongly
            # skip the retry decision for exactly the responses we care about.
            if exception.response is not None:
                # Retry on server errors (5xx) but not client errors (4xx)
                return exception.response.status_code >= 500

        return False

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: RelationshipType,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type.relationship_name,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )
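
    # How the scanner side is expected to wire up cancellation (a sketch; the
    # scanner class itself lives outside this module):
    #
    #   stop_event = threading.Event()
    #   provider.set_stop_event(stop_event)
    #   ...
    #   stop_event.set()   # any in-flight make_request() call returns None soon after
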
    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute
        }
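

# ---------------------------------------------------------------------------
# Minimal sketch of a concrete provider, assuming a hypothetical JSON API at
# api.example.com. The endpoint, response shape, and the RelationshipType
# member used below are illustrative assumptions, not part of DNSRecon.
# Guarded by __main__ so nothing runs on import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class ExampleProvider(BaseProvider):
        """Hypothetical provider used only to demonstrate the interface."""

        def get_name(self) -> str:
            return "example"

        def is_available(self) -> bool:
            return True  # a real provider would check API keys / connectivity

        def query_domain(self, domain: str):
            results = []
            response = self.make_request(
                f"https://api.example.com/v1/domain/{domain}",  # hypothetical endpoint
                target_indicator=domain,
            )
            if response is not None:
                for record in response.json().get("a_records", []):  # assumed shape
                    # RelationshipType.A_RECORD is assumed; use whichever
                    # member your graph_manager actually defines.
                    results.append((domain, record["ip"], RelationshipType.A_RECORD, 0.8, record))
            return results

        def query_ip(self, ip: str):
            return []  # this sketch only handles domains

    provider = ExampleProvider("example")
    print(provider.get_statistics())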