# dnsrecon/providers/base_provider.py
import time
import requests
import threading
import os
import json
import hashlib
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from core.logger import get_forensic_logger


class RateLimiter:
    """Simple rate limiter for API calls."""

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        self.last_request_time = 0.0

    def __getstate__(self):
        """RateLimiter is fully picklable; return full state."""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore RateLimiter state."""
        self.__dict__.update(state)

    def wait_if_needed(self) -> None:
        """Wait if necessary to respect rate limits."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.min_interval:
            sleep_time = self.min_interval - time_since_last
            time.sleep(sleep_time)
        self.last_request_time = time.time()
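
# Example RateLimiter usage (an illustrative sketch; `urls` and `fetch` are
# hypothetical names, not part of this module). wait_if_needed() blocks the
# calling thread until the minimum interval between requests has elapsed:
#
#     limiter = RateLimiter(requests_per_minute=30)  # one request every 2s
#     for url in urls:
#         limiter.wait_if_needed()
#         fetch(url)
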

class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.
    Now supports session-specific configuration.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit (default override)
            timeout: Request timeout in seconds
            session_config: Session-specific configuration
        """
        # Use session config if provided, otherwise fall back to global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fall back to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None

        # Caching configuration (per session)
        self.cache_dir = f'.cache/{id(self.config)}'  # Unique cache per session config
        self.cache_expiry = self.config.cache_expiry_hours * 3600
        os.makedirs(self.cache_dir, exist_ok=True)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpicklable '_local' attribute and stop event
        unpicklable_attrs = ['_local', '_stop_event']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_local' attribute and stop event
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Lazily create one requests.Session per thread (thread-local)."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with aggressive stop signal handling.
        Terminates immediately when stop is requested, including during retries.
        """
        # Check for cancellation before starting
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        # Create a unique, process-stable cache key. The built-in hash() is
        # randomized per process for strings, which would silently defeat an
        # on-disk cache, so use a cryptographic digest instead.
        request_fingerprint = hashlib.sha256(
            f'{method}:{url}:{json.dumps(params, sort_keys=True)}'.encode('utf-8')
        ).hexdigest()
        cache_key = f"{self.name}_{request_fingerprint}.json"
        cache_path = os.path.join(self.cache_dir, cache_key)

        # Check cache
        if os.path.exists(cache_path):
            cache_age = time.time() - os.path.getmtime(cache_path)
            if cache_age < self.cache_expiry:
                print(f"Returning cached response for: {url}")
                with open(cache_path, 'r') as f:
                    cached_data = json.load(f)
                response = requests.Response()
                response.status_code = cached_data['status_code']
                response._content = cached_data['content'].encode('utf-8')
                response.headers.update(cached_data['headers'])
                return response
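
        # The loop below makes up to max_retries + 1 attempts. Every blocking
        # step (rate limiting, the HTTP call itself, retry backoff) re-checks
        # the stop event, so a termination request is honored within roughly
        # one 50ms polling interval of _sleep_with_cancellation_check().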
        # Determine effective max_retries based on stop signal
        effective_max_retries = 0 if self._is_stop_requested() else max_retries
        last_exception = None

        for attempt in range(effective_max_retries + 1):
            # AGGRESSIVE: Check for cancellation before each attempt
            if self._is_stop_requested():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None

            # Apply rate limiting with cancellation awareness
            if not self._wait_with_cancellation_check():
                print(f"Request cancelled during rate limiting: {url}")
                return None

            # AGGRESSIVE: Final check before making HTTP request
            if self._is_stop_requested():
                print(f"Request cancelled before HTTP call: {url}")
                return None

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                # Prepare request
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # AGGRESSIVE: Use a much shorter timeout if termination is requested
                request_timeout = self.timeout
                if self._is_stop_requested():
                    request_timeout = 2  # Max 2 seconds if termination requested
                    print(f"Stop requested - using short timeout: {request_timeout}s")

                # Make request
                if method.upper() == "GET":
                    response = self.session.get(
                        url, params=params, headers=request_headers, timeout=request_timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url, json=params, headers=request_headers, timeout=request_timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )

                # Cache the successful response to disk
                with open(cache_path, 'w') as f:
                    json.dump({
                        'status_code': response.status_code,
                        'content': response.text,
                        'headers': dict(response.headers)
                    }, f)

                return response

            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")
                last_exception = e

                # AGGRESSIVE: Immediately abort retries if stop requested
                if self._is_stop_requested():
                    print(f"Stop requested - aborting retries for: {url}")
                    break

                # Check if we should retry (but only if stop not requested)
                if attempt < effective_max_retries and self._should_retry(e):
                    # Use a longer, more respectful backoff for 429 errors.
                    # Note: compare the response against None explicitly - a
                    # Response with an error status is falsy, so a bare
                    # `if e.response` would never match here.
                    if (isinstance(e, requests.exceptions.HTTPError)
                            and e.response is not None
                            and e.response.status_code == 429):
                        # Start with a 10-second backoff and increase exponentially
                        backoff_time = 10 * (2 ** attempt)
                        print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
                    else:
                        # Shorter, capped backoff for other errors
                        backoff_time = min(1.0, (2 ** attempt) * 0.5)
                        print(f"Retrying in {backoff_time} seconds...")

                    # AGGRESSIVE: Backoff sleeps are interruptible as well
                    if not self._sleep_with_cancellation_check(backoff_time):
                        print(f"Stop requested during backoff - aborting: {url}")
                        return None
                    continue
                else:
                    break

            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                last_exception = e
                break

        # All attempts failed - log and return None. Compare the response
        # against None: error-status responses are falsy in requests.
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        if error and last_exception:
            raise last_exception
        return None

    def _is_stop_requested(self) -> bool:
        """
        Check whether cancellation has been signalled. The stop event may be a
        local threading.Event or a Redis-backed equivalent; anything exposing
        is_set() works.
        """
        if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
            return True
        return False

    def _wait_with_cancellation_check(self) -> bool:
        """
        Wait for rate limiting while aggressively checking for cancellation.
        Returns False if cancelled during wait.
        """
        current_time = time.time()
        time_since_last = current_time - self.rate_limiter.last_request_time

        if time_since_last < self.rate_limiter.min_interval:
            sleep_time = self.rate_limiter.min_interval - time_since_last
            if not self._sleep_with_cancellation_check(sleep_time):
                return False

        self.rate_limiter.last_request_time = time.time()
        return True

    def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
        """
        Sleep for the specified time while aggressively checking for cancellation.

        Args:
            sleep_time: Time to sleep in seconds

        Returns:
            bool: True if sleep completed, False if cancelled
        """
        sleep_start = time.time()
        check_interval = 0.05  # Check every 50ms for aggressive responsiveness

        while time.time() - sleep_start < sleep_time:
            if self._is_stop_requested():
                return False
            remaining_time = sleep_time - (time.time() - sleep_start)
            # Clamp to zero: time.sleep() raises ValueError on negative input
            time.sleep(min(check_interval, max(remaining_time, 0.0)))
        return True
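
    # Illustrative wiring of the stop event from orchestration code (a sketch;
    # `scanner` and the worker thread are assumptions, not part of this class):
    #
    #     stop_event = threading.Event()
    #     provider.set_stop_event(stop_event)
    #     worker = threading.Thread(target=scanner.run, args=(provider,))
    #     worker.start()
    #     ...
    #     stop_event.set()  # in-flight waits and retries abort within ~50ms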

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True if the request should be retried
        """
        # Retry on connection errors and timeouts
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True

        if isinstance(exception, requests.exceptions.HTTPError):
            # Compare against None: a Response with an error status is falsy,
            # so a bare truthiness check would always skip this branch
            if exception.response is not None:
                # Retry on server errors (5xx) AND on rate-limiting errors (429)
                return (exception.response.status_code >= 500
                        or exception.response.status_code == 429)
        return False

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str, confidence_score: float,
                                   raw_data: Dict[str, Any], discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1
        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100)
                            if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute
        }
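

if __name__ == "__main__":
    # Minimal illustrative subclass (a demo sketch only; real providers live
    # in their own modules and return actual relationships). It shows the
    # abstract contract a concrete provider must fulfil.
    class _DemoProvider(BaseProvider):
        def get_name(self) -> str:
            return "demo"

        def get_display_name(self) -> str:
            return "Demo Provider"

        def requires_api_key(self) -> bool:
            return False

        def get_eligibility(self) -> Dict[str, bool]:
            return {'domains': True, 'ips': False}

        def is_available(self) -> bool:
            return True

        def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
            return []  # a real provider would call self.make_request() here

        def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
            return []

    demo = _DemoProvider(name="demo", rate_limit=60, timeout=10)
    demo.set_stop_event(threading.Event())
    print(demo.get_statistics())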