dnsrecon/providers/base_provider.py
overcuriousity 4378146d0c it
2025-09-14 13:14:02 +02:00

562 lines
21 KiB
Python

# dnsrecon/providers/base_provider.py
import time
import requests
import threading
import os
import json
import hashlib
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timezone
from core.logger import get_forensic_logger
class RateLimiter:
    """Thread-safe rate limiter that enforces a minimum gap between API calls."""

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute (must be > 0)

        Raises:
            ValueError: If requests_per_minute is zero or negative
        """
        # Fail early with a clear message instead of a bare ZeroDivisionError (0)
        # or a negative interval that would silently disable throttling (< 0).
        if requests_per_minute <= 0:
            raise ValueError(
                f"requests_per_minute must be positive, got {requests_per_minute!r}"
            )
        self.requests_per_minute = requests_per_minute
        # Minimum number of seconds that must separate two consecutive requests
        self.min_interval = 60.0 / requests_per_minute
        # Wall-clock timestamp (time.time()) of the most recent request; 0 = none yet
        self.last_request_time = 0
        self._lock = threading.Lock()

    def __getstate__(self):
        """Return the picklable state: every attribute except the lock."""
        state = self.__dict__.copy()
        # threading.Lock objects cannot be pickled; a new one is created on restore
        state.pop('_lock', None)
        return state

    def __setstate__(self, state):
        """Restore RateLimiter state and create a fresh lock."""
        self.__dict__.update(state)
        self._lock = threading.Lock()

    def wait_if_needed(self) -> None:
        """
        Block until at least min_interval seconds have passed since the last request.

        The lock is held while sleeping, so concurrent callers are released one
        at a time, each min_interval apart.
        """
        with self._lock:
            time_since_last = time.time() - self.last_request_time
            if time_since_last < self.min_interval:
                time.sleep(self.min_interval - time_since_last)
            self.last_request_time = time.time()
class ProviderCache:
    """
    Disk-backed cache for provider HTTP responses.

    Entries live as JSON files under ``.cache/<provider_name>/``. The lock
    serializes access made through one instance; the directory itself may be
    shared by several instances, so file operations tolerate files
    disappearing underneath them.
    """

    def __init__(self, provider_name: str, cache_expiry_hours: int = 12):
        """
        Initialize provider-specific cache.

        Args:
            provider_name: Name of the provider for cache directory
            cache_expiry_hours: Cache expiry time in hours
        """
        self.provider_name = provider_name
        self.cache_expiry = cache_expiry_hours * 3600  # Convert to seconds
        self.cache_dir = os.path.join('.cache', provider_name)
        self._lock = threading.Lock()
        # exist_ok makes concurrent creation by several instances harmless
        os.makedirs(self.cache_dir, exist_ok=True)

    def __getstate__(self):
        """Return the picklable state: every attribute except the lock."""
        state = self.__dict__.copy()
        # threading.Lock objects cannot be pickled; a new one is created on restore
        state.pop('_lock', None)
        return state

    def __setstate__(self, state):
        """Restore ProviderCache state and create a fresh lock."""
        self.__dict__.update(state)
        self._lock = threading.Lock()

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete a file, ignoring the case where it is already gone or locked."""
        try:
            os.remove(path)
        except OSError:
            pass  # File might have been removed by another thread

    def _generate_cache_key(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> str:
        """
        Generate the cache file name for a request.

        The key is the MD5 hex digest of method, URL and the JSON-encoded
        params (keys sorted, so dict ordering does not matter) plus ".json".
        MD5 is used only to derive a file name, not for security.
        """
        cache_data = f"{method}:{url}:{json.dumps(params or {}, sort_keys=True)}"
        return hashlib.md5(cache_data.encode()).hexdigest() + ".json"

    def get_cached_response(self, method: str, url: str,
                            params: Optional[Dict[str, Any]]) -> Optional["requests.Response"]:
        """
        Retrieve cached response if available and not expired.

        Expired or unreadable cache files are deleted as a side effect.

        Returns:
            Cached Response object or None if cache miss/expired
        """
        cache_key = self._generate_cache_key(method, url, params)
        cache_path = os.path.join(self.cache_dir, cache_key)
        with self._lock:
            # Ask for the mtime directly instead of exists()+getmtime(): the
            # file can vanish between the two calls when another instance
            # expires the same entry.
            try:
                cache_age = time.time() - os.path.getmtime(cache_path)
            except OSError:
                return None
            if cache_age >= self.cache_expiry:
                self._remove_quietly(cache_path)
                return None
            try:
                with open(cache_path, 'r', encoding='utf-8') as f:
                    cached_data = json.load(f)
                # Reconstruct Response object
                response = requests.Response()
                response.status_code = cached_data['status_code']
                response._content = cached_data['content'].encode('utf-8')
                # The body was re-encoded as UTF-8 above, so pin the encoding;
                # otherwise .text would decode using the charset found in the
                # cached headers and garble non-ASCII content.
                response.encoding = 'utf-8'
                response.headers.update(cached_data['headers'])
                return response
            except (OSError, ValueError, KeyError, TypeError, AttributeError):
                # Cache file unreadable or malformed (invalid JSON, bad UTF-8,
                # missing keys, wrong value types): drop it and report a miss.
                self._remove_quietly(cache_path)
                return None

    def cache_response(self, method: str, url: str, params: Optional[Dict[str, Any]],
                       response: "requests.Response") -> bool:
        """
        Cache successful response to disk.

        Only responses with status code 200 are stored.

        Returns:
            True if cached successfully, False otherwise
        """
        if response.status_code != 200:
            return False
        cache_key = self._generate_cache_key(method, url, params)
        cache_path = os.path.join(self.cache_dir, cache_key)
        # Unique per process and thread so that concurrent writers of the same
        # key never share (and truncate) one temporary file.
        temp_path = f"{cache_path}.{os.getpid()}.{threading.get_ident()}.tmp"
        with self._lock:
            try:
                cache_data = {
                    'status_code': response.status_code,
                    'content': response.text,
                    'headers': dict(response.headers),
                    'cached_at': datetime.now(timezone.utc).isoformat()
                }
                # Write to temporary file first, then swap it into place so
                # readers never observe a partially written cache file.
                with open(temp_path, 'w', encoding='utf-8') as f:
                    json.dump(cache_data, f)
                # os.replace overwrites an existing entry on every platform;
                # os.rename fails on Windows when the destination exists.
                os.replace(temp_path, cache_path)
                return True
            except (OSError, TypeError, ValueError):
                # Caching is best-effort: clean up the temp file and report failure
                self._remove_quietly(temp_path)
                return False
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.

    Provides rate-limited, retried and disk-cached HTTP requests with
    cooperative cancellation, plus per-instance request statistics.
    The response cache is keyed by provider name (shared across sessions),
    while rate limit and timeout may come from a session-specific config.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with global caching and session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit (used only without session_config)
            timeout: Request timeout in seconds (used only without session_config)
            session_config: Session-specific configuration
        """
        # Use session config if provided, otherwise fall back to global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout
        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        # Holds one requests.Session per thread (see the `session` property)
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None
        # GLOBAL provider-specific caching (not session-based)
        self.cache = ProviderCache(name, cache_expiry_hours=12)
        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0
        self.cache_hits = 0
        self.cache_misses = 0
        print(f"Initialized {name} provider with global cache and session config (rate: {actual_rate_limit}/min)")

    def __getstate__(self):
        """Prepare BaseProvider for pickling by blanking the thread-local store and stop event."""
        state = self.__dict__.copy()
        # threading.local and the externally supplied stop event are not pickled
        state['_local'] = None
        state['_stop_event'] = None
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Fresh thread-local store; the stop event must be re-attached by the caller
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Return the calling thread's requests.Session, creating it on first use."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/2.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional["requests.Response"]:
        """
        Make a rate-limited HTTP request with global caching and aggressive stop signal handling.

        Args:
            url: Target URL
            method: "GET" (params sent as query string) or "POST" (params sent as JSON body)
            params: Request parameters; also part of the cache key
            headers: Extra headers merged over the session defaults
            target_indicator: Value forwarded to the forensic logger
            max_retries: Retries after the first attempt; negative values count as 0

        Returns:
            The response (possibly served from cache), or None if the request
            was cancelled through the stop event.

        Raises:
            Exception: The last error is re-raised once all attempts failed
                (requests.exceptions.RequestException for HTTP/network
                failures, ValueError for an unsupported method).
        """
        # Check for cancellation before starting
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        # Check global cache first
        cached_response = self.cache.get_cached_response(method, url, params)
        if cached_response is not None:
            print(f"Cache hit for {self.name}: {url}")
            self.cache_hits += 1
            return cached_response
        self.cache_misses += 1

        # No retries once a stop is pending; clamp so at least one attempt
        # always runs and the failure log below never sees unbound locals.
        effective_max_retries = 0 if self._is_stop_requested() else max(0, max_retries)
        last_exception = None
        start_time = time.time()
        response = None
        error = None

        for attempt in range(effective_max_retries + 1):
            # Check for cancellation before each attempt
            if self._is_stop_requested():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None
            # Apply rate limiting with cancellation awareness
            if not self._wait_with_cancellation_check():
                print(f"Request cancelled during rate limiting: {url}")
                return None
            # Final check before making HTTP request
            if self._is_stop_requested():
                print(f"Request cancelled before HTTP call: {url}")
                return None

            start_time = time.time()
            response = None
            error = None
            try:
                self.total_requests += 1
                # Prepare request
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)
                print(f"Making {method} request to: {url} (attempt {attempt + 1})")
                # Use shorter timeout if termination is requested
                request_timeout = 2 if self._is_stop_requested() else self.timeout
                # Make request
                if method.upper() == "GET":
                    response = self.session.get(
                        url,
                        params=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url,
                        json=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")
                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1
                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )
                # Cache the successful response globally
                self.cache.cache_response(method, url, params, response)
                return response
            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")
                last_exception = e
                # Immediately abort retries if stop requested
                if self._is_stop_requested():
                    print(f"Stop requested - aborting retries for: {url}")
                    break
                # Give up when out of attempts or the error is not retryable
                if attempt >= effective_max_retries or not self._should_retry(e):
                    break
                # Compare against None explicitly: a requests.Response is
                # falsy for 4xx/5xx, so a plain truth test never matches 429.
                error_response = getattr(e, 'response', None)
                rate_limited = (
                    isinstance(e, requests.exceptions.HTTPError)
                    and error_response is not None
                    and error_response.status_code == 429
                )
                if rate_limited:
                    # Exponential backoff for 429: 10s, 20s, 40s, capped at 60s
                    backoff_time = min(60, 10 * (2 ** attempt))
                    print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
                else:
                    # Short exponential backoff: 0.5s, 1s, then capped at 2s
                    backoff_time = min(2.0, (2 ** attempt) * 0.5)
                    print(f"Retrying in {backoff_time} seconds...")
                if not self._sleep_with_cancellation_check(backoff_time):
                    print(f"Stop requested during backoff - aborting: {url}")
                    return None
                continue
            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                last_exception = e
                break

        # All attempts failed - log and re-raise (or return None)
        duration_ms = (time.time() - start_time) * 1000
        # `is not None` rather than truthiness, so the status and size of a
        # 4xx/5xx response are actually recorded.
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )
        if error and last_exception:
            raise last_exception
        return None

    def _is_stop_requested(self) -> bool:
        """Return True if a stop event has been attached and is set."""
        stop_event = getattr(self, '_stop_event', None)
        return bool(stop_event and stop_event.is_set())

    def _wait_with_cancellation_check(self) -> bool:
        """
        Wait for rate limiting while aggressively checking for cancellation.

        Returns False if cancelled during wait.
        """
        # NOTE(review): this reads and writes rate_limiter.last_request_time
        # without taking the limiter's lock, so concurrent callers may wake
        # together - confirm whether that is acceptable.
        current_time = time.time()
        time_since_last = current_time - self.rate_limiter.last_request_time
        if time_since_last < self.rate_limiter.min_interval:
            sleep_time = self.rate_limiter.min_interval - time_since_last
            if not self._sleep_with_cancellation_check(sleep_time):
                return False
        self.rate_limiter.last_request_time = time.time()
        return True

    def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
        """
        Sleep for the specified time while aggressively checking for cancellation.

        Args:
            sleep_time: Time to sleep in seconds (zero or negative returns at once)

        Returns:
            bool: True if sleep completed, False if cancelled
        """
        check_interval = 0.05  # Check every 50ms for aggressive responsiveness
        # A monotonic deadline is immune to wall-clock adjustments
        deadline = time.monotonic() + sleep_time
        while True:
            remaining_time = deadline - time.monotonic()
            # Test the remaining time before sleeping: time.sleep() raises
            # ValueError when handed a negative duration.
            if remaining_time <= 0:
                return True
            if self._is_stop_requested():
                return False
            time.sleep(min(check_interval, remaining_time))

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def _should_retry(self, exception: "requests.exceptions.RequestException") -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True for connection errors, timeouts, HTTP 5xx and HTTP 429
        """
        # Retry on connection errors and timeouts
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True
        if isinstance(exception, requests.exceptions.HTTPError):
            # Explicit None check: an error Response evaluates as False
            error_response = getattr(exception, 'response', None)
            if error_response is not None:
                # Retry on server errors (5xx) AND on rate-limiting errors (429)
                status_code = error_response.status_code
                return status_code >= 500 or status_code == 429
        return False

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship and count it in the statistics.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1
        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics including cache performance.

        Returns:
            Dictionary containing provider performance metrics; the two rates
            are percentages and are 0 when there is nothing to divide by
        """
        cache_lookups = self.cache_hits + self.cache_misses
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'cache_hit_rate': (self.cache_hits / cache_lookups * 100) if cache_lookups > 0 else 0
        }