# dnsrecon/providers/base_provider.py

import hashlib
import json
import os
import threading
import time
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional, Tuple

import requests

from core.logger import get_forensic_logger


class RateLimiter:
    """Thread-safe rate limiter for API calls.

    Enforces a minimum interval of ``60 / requests_per_minute`` seconds between
    consecutive acquisitions. The check, the wait and the timestamp update all
    happen under one lock, so concurrent callers are spaced out correctly.
    """

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        self.last_request_time = 0
        self._lock = threading.Lock()

    def __getstate__(self):
        """Return picklable state; the lock is dropped and rebuilt on unpickle."""
        state = self.__dict__.copy()
        state.pop('_lock', None)
        return state

    def __setstate__(self, state):
        """Restore RateLimiter state and recreate the lock."""
        self.__dict__.update(state)
        self._lock = threading.Lock()

    def acquire(self, sleep_func: Optional[Callable[[float], bool]] = None) -> bool:
        """
        Wait (if needed) for the next request slot and claim it.

        Args:
            sleep_func: Optional interruptible sleep. It is called with the
                number of seconds to wait and must return False if the wait
                was cancelled. Defaults to an uninterruptible ``time.sleep``.

        Returns:
            True if the slot was claimed, False if ``sleep_func`` reported a
            cancellation (the slot is then not consumed).
        """
        with self._lock:
            wait_time = self.min_interval - (time.time() - self.last_request_time)
            if wait_time > 0:
                if sleep_func is None:
                    time.sleep(wait_time)
                elif not sleep_func(wait_time):
                    return False
            self.last_request_time = time.time()
            return True

    def wait_if_needed(self) -> None:
        """Wait if necessary to respect rate limits."""
        self.acquire()


class ProviderCache:
    """Disk-backed cache of successful provider responses.

    Entries live under ``.cache/<provider_name>/`` and are shared by every
    ProviderCache instance created for the same provider name. The lock only
    serialises access within a single instance, so writes go through a
    uniquely named temporary file plus an atomic ``os.replace``. Readers
    tolerate entries disappearing underneath them.
    """

    def __init__(self, provider_name: str, cache_expiry_hours: int = 12):
        """
        Initialize provider-specific cache.

        Args:
            provider_name: Name of the provider for cache directory
            cache_expiry_hours: Cache expiry time in hours
        """
        self.provider_name = provider_name
        self.cache_expiry = cache_expiry_hours * 3600  # Convert to seconds
        self.cache_dir = os.path.join('.cache', provider_name)
        self._lock = threading.Lock()

        # exist_ok makes concurrent creation safe
        os.makedirs(self.cache_dir, exist_ok=True)

    def __getstate__(self):
        """Return picklable state; the lock is dropped and rebuilt on unpickle."""
        state = self.__dict__.copy()
        state.pop('_lock', None)
        return state

    def __setstate__(self, state):
        """Restore ProviderCache state and recreate the lock."""
        self.__dict__.update(state)
        self._lock = threading.Lock()

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete ``path``, ignoring errors (e.g. it was already removed elsewhere)."""
        try:
            os.remove(path)
        except OSError:
            pass

    def _generate_cache_key(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> str:
        """Generate unique cache key for request."""
        cache_data = f"{method}:{url}:{json.dumps(params or {}, sort_keys=True)}"
        return hashlib.md5(cache_data.encode()).hexdigest() + ".json"

    def get_cached_response(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> Optional[requests.Response]:
        """
        Retrieve cached response if available and not expired.

        Expired or unreadable entries are deleted.

        Returns:
            Cached Response object or None if cache miss/expired
        """
        cache_key = self._generate_cache_key(method, url, params)
        cache_path = os.path.join(self.cache_dir, cache_key)

        with self._lock:
            try:
                cache_age = time.time() - os.path.getmtime(cache_path)
            except OSError:
                # Missing entry, or removed concurrently by another instance/process
                return None

            if cache_age >= self.cache_expiry:
                self._remove_quietly(cache_path)
                return None

            try:
                with open(cache_path, 'r', encoding='utf-8') as f:
                    cached_data = json.load(f)

                # Reconstruct Response object
                response = requests.Response()
                response.status_code = cached_data['status_code']
                response._content = cached_data['content'].encode('utf-8')
                response.headers.update(cached_data['headers'])
                # Content was re-encoded as UTF-8 above, so decode it the same way
                response.encoding = 'utf-8'
                return response
            except (ValueError, KeyError, TypeError, AttributeError, OSError):
                # Corrupted/unreadable cache file (ValueError covers JSON and Unicode decode errors)
                self._remove_quietly(cache_path)
                return None

    def cache_response(self, method: str, url: str, params: Optional[Dict[str, Any]],
                       response: requests.Response) -> bool:
        """
        Cache successful response to disk (best effort).

        Only responses with status 200 are cached.

        Returns:
            True if cached successfully, False otherwise
        """
        if response.status_code != 200:
            return False

        cache_key = self._generate_cache_key(method, url, params)
        cache_path = os.path.join(self.cache_dir, cache_key)
        # Unique per process and thread: other cache instances for the same
        # provider do not share our lock, so a fixed temp name could collide.
        temp_path = f"{cache_path}.{os.getpid()}.{threading.get_ident()}.tmp"

        with self._lock:
            try:
                cache_data = {
                    'status_code': response.status_code,
                    'content': response.text,
                    'headers': dict(response.headers),
                    'cached_at': datetime.now(timezone.utc).isoformat()
                }

                with open(temp_path, 'w', encoding='utf-8') as f:
                    json.dump(cache_data, f)

                # os.replace is atomic and, unlike os.rename, overwrites on Windows too
                os.replace(temp_path, cache_path)
                return True
            except (OSError, TypeError, ValueError):
                self._remove_quietly(temp_path)
                return False


class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.
    Now supports global provider-specific caching and session-specific configuration.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with global caching and session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit (default override)
            timeout: Request timeout in seconds
            session_config: Session-specific configuration
        """
        # Use session config if provided, otherwise fall back to global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None

        # GLOBAL provider-specific caching (not session-based)
        self.cache = ProviderCache(name, cache_expiry_hours=12)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0
        self.cache_hits = 0
        self.cache_misses = 0

        print(f"Initialized {name} provider with global cache and session config (rate: {actual_rate_limit}/min)")

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpickleable '_local' attribute and stop event
        state['_local'] = None
        state['_stop_event'] = None
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_local' attribute and stop event
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Per-thread requests.Session, created lazily with the DNSRecon User-Agent."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/2.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with global caching and aggressive stop signal handling.

        Args:
            url: Request URL
            method: "GET" (params sent as query string) or "POST" (params sent as JSON body)
            params: Query parameters / JSON body; also part of the cache key
            headers: Extra headers merged over the session defaults
            target_indicator: Indicator passed through to the forensic logger
            max_retries: Retries after the first attempt (negative values mean none)

        Returns:
            A cached or successful response, or None if the request was
            cancelled via the stop event.

        Raises:
            requests.exceptions.RequestException: the last request error once
                retries are exhausted or aborted by a stop request.
            Exception: any other error raised while sending (e.g. ValueError
                for an unsupported HTTP method).
        """
        # Check for cancellation before starting
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        # Check global cache first
        cached_response = self.cache.get_cached_response(method, url, params)
        if cached_response is not None:
            print(f"Cache hit for {self.name}: {url}")
            self.cache_hits += 1
            return cached_response

        self.cache_misses += 1

        # No retries once a stop was requested; always at least one attempt
        effective_max_retries = 0 if self._is_stop_requested() else max(0, max_retries)
        last_exception = None
        response = None
        error = None
        start_time = time.time()

        for attempt in range(effective_max_retries + 1):
            # Check for cancellation before each attempt
            if self._is_stop_requested():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None

            # Apply rate limiting with cancellation awareness
            if not self._wait_with_cancellation_check():
                print(f"Request cancelled during rate limiting: {url}")
                return None

            # Final check before making HTTP request
            if self._is_stop_requested():
                print(f"Request cancelled before HTTP call: {url}")
                return None

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # Use shorter timeout if termination is requested
                request_timeout = 2 if self._is_stop_requested() else self.timeout

                response = self._send_request(method, url, params, request_headers, request_timeout)

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )

                # Cache the successful response globally
                self.cache.cache_response(method, url, params, response)
                return response

            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")
                last_exception = e

                # Immediately abort retries if stop requested
                if self._is_stop_requested():
                    print(f"Stop requested - aborting retries for: {url}")
                    break

                if attempt < effective_max_retries and self._should_retry(e):
                    # Exponential backoff (no jitter): long waits for 429, short otherwise
                    if (isinstance(e, requests.exceptions.HTTPError)
                            and self._error_status_code(e) == 429):
                        backoff_time = min(60, 10 * (2 ** attempt))
                        print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
                    else:
                        backoff_time = min(2.0, (2 ** attempt) * 0.5)
                        print(f"Retrying in {backoff_time} seconds...")

                    if not self._sleep_with_cancellation_check(backoff_time):
                        print(f"Stop requested during backoff - aborting: {url}")
                        return None
                    continue

                break

            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                last_exception = e
                break

        # All attempts failed - log the last attempt, then surface its error.
        # An error Response is falsy (Response.__bool__ returns .ok), so test for None.
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        if error and last_exception:
            raise last_exception
        return None

    def _send_request(self, method: str, url: str, params: Optional[Dict[str, Any]],
                      headers: Dict[str, str], timeout: float) -> requests.Response:
        """
        Dispatch a single HTTP call on this thread's session.

        GET sends ``params`` as the query string; POST sends them as a JSON body.

        Raises:
            ValueError: for any method other than GET or POST.
        """
        verb = method.upper()
        if verb == "GET":
            return self.session.get(url, params=params, headers=headers, timeout=timeout)
        if verb == "POST":
            return self.session.post(url, json=params, headers=headers, timeout=timeout)
        raise ValueError(f"Unsupported HTTP method: {method}")

    @staticmethod
    def _error_status_code(exception: BaseException) -> Optional[int]:
        """
        Return the HTTP status code attached to a requests exception, if any.

        ``requests.Response.__bool__`` returns ``Response.ok``, so an error
        response (status >= 400) is falsy and must be tested with ``is not None``.
        """
        response = getattr(exception, 'response', None)
        return response.status_code if response is not None else None

    def _is_stop_requested(self) -> bool:
        """Return True if a stop event has been set via set_stop_event() and is signalled."""
        stop_event = getattr(self, '_stop_event', None)
        return stop_event is not None and stop_event.is_set()

    def _wait_with_cancellation_check(self) -> bool:
        """
        Wait for rate limiting while aggressively checking for cancellation.

        Delegates to the shared RateLimiter so the check-wait-update sequence is
        done under its lock (safe when several threads share this provider).

        Returns:
            False if cancelled during the wait, True once a request slot is claimed.
        """
        return self.rate_limiter.acquire(self._sleep_with_cancellation_check)

    def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
        """
        Sleep for the specified time while aggressively checking for cancellation.

        Args:
            sleep_time: Time to sleep in seconds (non-positive returns immediately)

        Returns:
            bool: True if sleep completed, False if cancelled
        """
        check_interval = 0.05  # Check every 50ms for aggressive responsiveness
        # Monotonic clock: immune to wall-clock adjustments during the wait
        deadline = time.monotonic() + sleep_time

        while True:
            # Computed once per iteration so the sleep argument can never be negative
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return True
            if self._is_stop_requested():
                return False
            time.sleep(min(check_interval, remaining))

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True for connection errors, timeouts, HTTP 5xx and HTTP 429
        """
        if isinstance(exception, (requests.exceptions.ConnectionError, requests.exceptions.Timeout)):
            return True

        if isinstance(exception, requests.exceptions.HTTPError):
            status = self._error_status_code(exception)
            return status is not None and (status >= 500 or status == 429)

        return False

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics including cache performance.

        Returns:
            Dictionary containing provider performance metrics
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'cache_hit_rate': (self.cache_hits / (self.cache_hits + self.cache_misses) * 100) if (self.cache_hits + self.cache_misses) > 0 else 0
        }