commit 4378146d0c
parent b26002eff9
Author: overcuriousity
Date:   2025-09-14 13:14:02 +02:00

8 changed files with 1765 additions and 683 deletions


@@ -5,14 +5,16 @@ import requests
import threading
import os
import json
import hashlib
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timezone
from core.logger import get_forensic_logger
class RateLimiter:
-    """Simple rate limiter for API calls."""
+    """Thread-safe rate limiter for API calls."""
def __init__(self, requests_per_minute: int):
"""
@@ -24,36 +26,152 @@ class RateLimiter:
self.requests_per_minute = requests_per_minute
self.min_interval = 60.0 / requests_per_minute
self.last_request_time = 0
self._lock = threading.Lock()
    def __getstate__(self):
-        """RateLimiter is fully picklable, return full state."""
-        return self.__dict__.copy()
+        """Prepare RateLimiter for pickling by excluding the unpicklable lock."""
+        state = self.__dict__.copy()
+        # Exclude unpicklable lock
+        if '_lock' in state:
+            del state['_lock']
+        return state
def __setstate__(self, state):
"""Restore RateLimiter state."""
self.__dict__.update(state)
self._lock = threading.Lock()
def wait_if_needed(self) -> None:
"""Wait if necessary to respect rate limits."""
-        current_time = time.time()
-        time_since_last = current_time - self.last_request_time
+        with self._lock:
+            current_time = time.time()
+            time_since_last = current_time - self.last_request_time

-        if time_since_last < self.min_interval:
-            sleep_time = self.min_interval - time_since_last
-            time.sleep(sleep_time)
+            if time_since_last < self.min_interval:
+                sleep_time = self.min_interval - time_since_last
+                time.sleep(sleep_time)

-        self.last_request_time = time.time()
+            self.last_request_time = time.time()
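
A minimal usage sketch (illustrative, not part of the commit; the 120 requests/minute figure is hypothetical). It shows the two behaviors added above: wait_if_needed() is safe to call from concurrent threads, and instances survive pickling because the lock is dropped and recreated:

    import pickle
    import threading

    limiter = RateLimiter(requests_per_minute=120)  # 60/120 => at least 0.5s between calls

    def worker():
        limiter.wait_if_needed()  # serialized by the internal lock
        # ... perform the actual API call here ...

    threads = [threading.Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    clone = pickle.loads(pickle.dumps(limiter))  # __getstate__/__setstate__ handle the lock
    assert clone.min_interval == limiter.min_interval
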
class ProviderCache:
"""Thread-safe global cache for provider queries."""
def __init__(self, provider_name: str, cache_expiry_hours: int = 12):
"""
Initialize provider-specific cache.
Args:
provider_name: Name of the provider for cache directory
cache_expiry_hours: Cache expiry time in hours
"""
self.provider_name = provider_name
self.cache_expiry = cache_expiry_hours * 3600 # Convert to seconds
self.cache_dir = os.path.join('.cache', provider_name)
self._lock = threading.Lock()
# Ensure cache directory exists with thread-safe creation
os.makedirs(self.cache_dir, exist_ok=True)
def _generate_cache_key(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> str:
"""Generate unique cache key for request."""
cache_data = f"{method}:{url}:{json.dumps(params or {}, sort_keys=True)}"
return hashlib.md5(cache_data.encode()).hexdigest() + ".json"
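
Because json.dumps(..., sort_keys=True) canonicalizes parameter order, logically identical requests map to the same cache file. A quick illustrative check (the provider name, URL, and params are made up):

    cache = ProviderCache("example_provider")
    key_a = cache._generate_cache_key("GET", "https://api.example.com/lookup", {"a": 1, "b": 2})
    key_b = cache._generate_cache_key("GET", "https://api.example.com/lookup", {"b": 2, "a": 1})
    assert key_a == key_b  # same MD5-derived filename either way
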
def get_cached_response(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> Optional[requests.Response]:
"""
Retrieve cached response if available and not expired.
Returns:
Cached Response object or None if cache miss/expired
"""
cache_key = self._generate_cache_key(method, url, params)
cache_path = os.path.join(self.cache_dir, cache_key)
with self._lock:
if not os.path.exists(cache_path):
return None
# Check if cache is expired
cache_age = time.time() - os.path.getmtime(cache_path)
if cache_age >= self.cache_expiry:
try:
os.remove(cache_path)
except OSError:
pass # File might have been removed by another thread
return None
try:
with open(cache_path, 'r', encoding='utf-8') as f:
cached_data = json.load(f)
# Reconstruct Response object
response = requests.Response()
response.status_code = cached_data['status_code']
response._content = cached_data['content'].encode('utf-8')
response.headers.update(cached_data['headers'])
return response
            except (json.JSONDecodeError, KeyError, IOError):
# Cache file corrupted, remove it
try:
os.remove(cache_path)
except OSError:
pass
return None
def cache_response(self, method: str, url: str, params: Optional[Dict[str, Any]],
response: requests.Response) -> bool:
"""
Cache successful response to disk.
Returns:
True if cached successfully, False otherwise
"""
if response.status_code != 200:
return False
cache_key = self._generate_cache_key(method, url, params)
cache_path = os.path.join(self.cache_dir, cache_key)
with self._lock:
try:
cache_data = {
'status_code': response.status_code,
'content': response.text,
'headers': dict(response.headers),
'cached_at': datetime.now(timezone.utc).isoformat()
}
# Write to temporary file first, then rename for atomic operation
temp_path = cache_path + '.tmp'
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(cache_data, f)
# Atomic rename to prevent partial cache files
os.rename(temp_path, cache_path)
return True
            except (IOError, OSError):
# Clean up temp file if it exists
try:
if os.path.exists(temp_path):
os.remove(temp_path)
except OSError:
pass
return False
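
An end-to-end sketch of the new cache (illustrative: the provider name, URL, and payload are invented, and the Response is built by hand so no network is needed):

    import requests

    cache = ProviderCache("example_provider", cache_expiry_hours=1)

    resp = requests.Response()
    resp.status_code = 200
    resp._content = b'{"answer": 42}'
    resp.headers['Content-Type'] = 'application/json'

    cache.cache_response("GET", "https://api.example.com/q", {"domain": "example.org"}, resp)
    hit = cache.get_cached_response("GET", "https://api.example.com/q", {"domain": "example.org"})
    assert hit is not None and hit.json() == {"answer": 42}
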
class BaseProvider(ABC):
"""
Abstract base class for all DNSRecon data providers.
-    Now supports session-specific configuration.
+    Now supports global provider-specific caching and session-specific configuration.
"""
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
"""
-        Initialize base provider with session-specific configuration.
+        Initialize base provider with global caching and session-specific configuration.
Args:
name: Provider name for logging
@@ -80,28 +198,25 @@ class BaseProvider(ABC):
self.logger = get_forensic_logger()
self._stop_event = None
-        # Caching configuration (per session)
-        self.cache_dir = f'.cache/{id(self.config)}'  # Unique cache per session config
-        self.cache_expiry = 12 * 3600  # 12 hours in seconds
-        if not os.path.exists(self.cache_dir):
-            os.makedirs(self.cache_dir)
+        # GLOBAL provider-specific caching (not session-based)
+        self.cache = ProviderCache(name, cache_expiry_hours=12)
# Statistics (per provider instance)
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.total_relationships_found = 0
self.cache_hits = 0
self.cache_misses = 0
-        print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")
+        print(f"Initialized {name} provider with global cache and session config (rate: {actual_rate_limit}/min)")
def __getstate__(self):
"""Prepare BaseProvider for pickling by excluding unpicklable objects."""
state = self.__dict__.copy()
-        # Exclude the unpickleable '_local' attribute and stop event
-        unpicklable_attrs = ['_local', '_stop_event']
-        for attr in unpicklable_attrs:
-            if attr in state:
-                del state[attr]
+        # Null out the unpicklable thread-local and stop event; __setstate__ recreates them
+        state['_local'] = None
+        state['_stop_event'] = None
return state
def __setstate__(self, state):
@@ -116,7 +231,7 @@ class BaseProvider(ABC):
if not hasattr(self._local, 'session'):
self._local.session = requests.Session()
self._local.session.headers.update({
-                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
+                'User-Agent': 'DNSRecon/2.0 (Passive Reconnaissance Tool)'
})
return self._local.session
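
Each worker thread lazily gets its own requests.Session through the thread-local above. A sketch of that guarantee (assuming `provider` is an instance of some concrete BaseProvider subclass and the property shown here is named `session`):

    import threading

    sessions = []

    def grab():
        sessions.append(provider.session)  # first access creates this thread's Session

    t1 = threading.Thread(target=grab)
    t2 = threading.Thread(target=grab)
    t1.start(); t2.start(); t1.join(); t2.join()
    assert sessions[0] is not sessions[1]  # distinct Session objects per thread
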
@@ -177,37 +292,28 @@ class BaseProvider(ABC):
target_indicator: str = "",
max_retries: int = 3) -> Optional[requests.Response]:
"""
-        Make a rate-limited HTTP request with aggressive stop signal handling.
-        Terminates immediately when stop is requested, including during retries.
+        Make a rate-limited HTTP request with global caching and aggressive stop signal handling.
"""
# Check for cancellation before starting
if self._is_stop_requested():
print(f"Request cancelled before start: {url}")
return None
-        # Create a unique cache key
-        cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
-        cache_path = os.path.join(self.cache_dir, cache_key)
-
-        # Check cache
-        if os.path.exists(cache_path):
-            cache_age = time.time() - os.path.getmtime(cache_path)
-            if cache_age < self.cache_expiry:
-                print(f"Returning cached response for: {url}")
-                with open(cache_path, 'r') as f:
-                    cached_data = json.load(f)
-                response = requests.Response()
-                response.status_code = cached_data['status_code']
-                response._content = cached_data['content'].encode('utf-8')
-                response.headers = cached_data['headers']
-                return response
+        # Check global cache first
+        cached_response = self.cache.get_cached_response(method, url, params)
+        if cached_response is not None:
+            print(f"Cache hit for {self.name}: {url}")
+            self.cache_hits += 1
+            return cached_response
+
+        self.cache_misses += 1
# Determine effective max_retries based on stop signal
effective_max_retries = 0 if self._is_stop_requested() else max_retries
last_exception = None
for attempt in range(effective_max_retries + 1):
-            # AGGRESSIVE: Check for cancellation before each attempt
+            # Check for cancellation before each attempt
if self._is_stop_requested():
print(f"Request cancelled during attempt {attempt + 1}: {url}")
return None
@@ -217,7 +323,7 @@ class BaseProvider(ABC):
print(f"Request cancelled during rate limiting: {url}")
return None
-            # AGGRESSIVE: Final check before making HTTP request
+            # Final check before making HTTP request
if self._is_stop_requested():
print(f"Request cancelled before HTTP call: {url}")
return None
@@ -236,11 +342,8 @@ class BaseProvider(ABC):
print(f"Making {method} request to: {url} (attempt {attempt + 1})")
-                # AGGRESSIVE: Use much shorter timeout if termination is requested
-                request_timeout = self.timeout
-                if self._is_stop_requested():
-                    request_timeout = 2  # Max 2 seconds if termination requested
-                    print(f"Stop requested - using short timeout: {request_timeout}s")
+                # Use shorter timeout if termination is requested
+                request_timeout = 2 if self._is_stop_requested() else self.timeout
# Make request
if method.upper() == "GET":
@@ -276,13 +379,9 @@ class BaseProvider(ABC):
error=None,
target_indicator=target_indicator
)
-                # Cache the successful response to disk
-                with open(cache_path, 'w') as f:
-                    json.dump({
-                        'status_code': response.status_code,
-                        'content': response.text,
-                        'headers': dict(response.headers)
-                    }, f)
+                # Cache the successful response globally
+                self.cache.cache_response(method, url, params, response)
return response
except requests.exceptions.RequestException as e:
@@ -291,23 +390,21 @@ class BaseProvider(ABC):
print(f"Request failed (attempt {attempt + 1}): {error}")
last_exception = e
-                # AGGRESSIVE: Immediately abort retries if stop requested
+                # Immediately abort retries if stop requested
if self._is_stop_requested():
print(f"Stop requested - aborting retries for: {url}")
break
-                # Check if we should retry (but only if stop not requested)
+                # Check if we should retry
if attempt < effective_max_retries and self._should_retry(e):
-                    # Use a longer, more respectful backoff for 429 errors
+                    # Capped exponential backoff for 429 errors
if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429:
-                        # Start with a 10-second backoff and increase exponentially
-                        backoff_time = 10 * (2 ** attempt)
+                        backoff_time = min(60, 10 * (2 ** attempt))
print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
else:
-                    backoff_time = min(1.0, (2 ** attempt) * 0.5)  # Shorter backoff for other errors
+                    backoff_time = min(2.0, (2 ** attempt) * 0.5)
print(f"Retrying in {backoff_time} seconds...")
-                # AGGRESSIVE: Much shorter backoff and more frequent checking
if not self._sleep_with_cancellation_check(backoff_time):
print(f"Stop requested during backoff - aborting: {url}")
return None
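
For reference, a throwaway sketch of the schedule the two capped backoff formulas above produce:

    for attempt in range(5):
        rate_limited = min(60, 10 * (2 ** attempt))   # 429s: 10, 20, 40, 60, 60 seconds
        other_error = min(2.0, (2 ** attempt) * 0.5)  # else: 0.5, 1.0, 2.0, 2.0, 2.0 seconds
        print(attempt, rate_limited, other_error)
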
@@ -348,7 +445,6 @@ class BaseProvider(ABC):
return True
return False
def _wait_with_cancellation_check(self) -> bool:
"""
Wait for rate limiting while aggressively checking for cancellation.
@@ -447,7 +543,7 @@ class BaseProvider(ABC):
def get_statistics(self) -> Dict[str, Any]:
"""
-        Get provider statistics.
+        Get provider statistics including cache performance.
Returns:
Dictionary containing provider performance metrics
@@ -459,5 +555,8 @@ class BaseProvider(ABC):
'failed_requests': self.failed_requests,
'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
'relationships_found': self.total_relationships_found,
-            'rate_limit': self.rate_limiter.requests_per_minute
+            'rate_limit': self.rate_limiter.requests_per_minute,
+            'cache_hits': self.cache_hits,
+            'cache_misses': self.cache_misses,
+            'cache_hit_rate': (self.cache_hits / (self.cache_hits + self.cache_misses) * 100) if (self.cache_hits + self.cache_misses) > 0 else 0
}
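
Reading the new counters back (a sketch; `provider` stands in for an instance of any concrete BaseProvider subclass):

    stats = provider.get_statistics()
    print(f"cache: {stats['cache_hits']} hits / {stats['cache_misses']} misses "
          f"({stats['cache_hit_rate']:.1f}% hit rate)")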