@@ -5,14 +5,16 @@ import requests
 import threading
 import os
 import json
+import hashlib
 from abc import ABC, abstractmethod
 from typing import List, Dict, Any, Optional, Tuple
 from datetime import datetime, timezone

 from core.logger import get_forensic_logger


 class RateLimiter:
-    """Simple rate limiter for API calls."""
+    """Thread-safe rate limiter for API calls."""

     def __init__(self, requests_per_minute: int):
         """
@@ -24,36 +26,152 @@ class RateLimiter:
         self.requests_per_minute = requests_per_minute
         self.min_interval = 60.0 / requests_per_minute
         self.last_request_time = 0
+        self._lock = threading.Lock()

     def __getstate__(self):
-        """RateLimiter is fully picklable, return full state."""
-        return self.__dict__.copy()
+        """Prepare RateLimiter for pickling by excluding the unpicklable lock."""
+        state = self.__dict__.copy()
+        # Exclude unpickleable lock
+        if '_lock' in state:
+            del state['_lock']
+        return state
+
+    def __setstate__(self, state):
+        """Restore RateLimiter state."""
+        self.__dict__.update(state)
+        self._lock = threading.Lock()

     def wait_if_needed(self) -> None:
         """Wait if necessary to respect rate limits."""
-        current_time = time.time()
-        time_since_last = current_time - self.last_request_time
+        with self._lock:
+            current_time = time.time()
+            time_since_last = current_time - self.last_request_time

-        if time_since_last < self.min_interval:
-            sleep_time = self.min_interval - time_since_last
-            time.sleep(sleep_time)
+            if time_since_last < self.min_interval:
+                sleep_time = self.min_interval - time_since_last
+                time.sleep(sleep_time)

-        self.last_request_time = time.time()
+            self.last_request_time = time.time()


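Review note: a quick sanity check of the new pickling hooks (a minimal sketch, assuming RateLimiter is importable as defined above; pickle and threading are stdlib):

    import pickle
    import threading

    limiter = RateLimiter(requests_per_minute=30)
    limiter.wait_if_needed()

    # __getstate__ drops the lock, so this no longer raises
    # "TypeError: cannot pickle '_thread.lock' object".
    restored = pickle.loads(pickle.dumps(limiter))

    assert restored.min_interval == limiter.min_interval
    assert restored.last_request_time == limiter.last_request_time
    # __setstate__ hands the restored object a fresh lock for this process.
    assert isinstance(restored._lock, type(threading.Lock()))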
+class ProviderCache:
+    """Thread-safe global cache for provider queries."""
+
+    def __init__(self, provider_name: str, cache_expiry_hours: int = 12):
+        """
+        Initialize provider-specific cache.
+
+        Args:
+            provider_name: Name of the provider for cache directory
+            cache_expiry_hours: Cache expiry time in hours
+        """
+        self.provider_name = provider_name
+        self.cache_expiry = cache_expiry_hours * 3600  # Convert to seconds
+        self.cache_dir = os.path.join('.cache', provider_name)
+        self._lock = threading.Lock()
+
+        # Ensure cache directory exists with thread-safe creation
+        os.makedirs(self.cache_dir, exist_ok=True)
+
+    def _generate_cache_key(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> str:
+        """Generate unique cache key for request."""
+        cache_data = f"{method}:{url}:{json.dumps(params or {}, sort_keys=True)}"
+        return hashlib.md5(cache_data.encode()).hexdigest() + ".json"
+
+    def get_cached_response(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> Optional[requests.Response]:
+        """
+        Retrieve cached response if available and not expired.
+
+        Returns:
+            Cached Response object or None if cache miss/expired
+        """
+        cache_key = self._generate_cache_key(method, url, params)
+        cache_path = os.path.join(self.cache_dir, cache_key)
+
+        with self._lock:
+            if not os.path.exists(cache_path):
+                return None
+
+            # Check if cache is expired
+            cache_age = time.time() - os.path.getmtime(cache_path)
+            if cache_age >= self.cache_expiry:
+                try:
+                    os.remove(cache_path)
+                except OSError:
+                    pass  # File might have been removed by another thread
+                return None
+
+            try:
+                with open(cache_path, 'r', encoding='utf-8') as f:
+                    cached_data = json.load(f)
+
+                # Reconstruct Response object
+                response = requests.Response()
+                response.status_code = cached_data['status_code']
+                response._content = cached_data['content'].encode('utf-8')
+                response.headers.update(cached_data['headers'])
+
+                return response
+
+            except (json.JSONDecodeError, KeyError, IOError) as e:
+                # Cache file corrupted, remove it
+                try:
+                    os.remove(cache_path)
+                except OSError:
+                    pass
+                return None
+
+    def cache_response(self, method: str, url: str, params: Optional[Dict[str, Any]],
+                       response: requests.Response) -> bool:
+        """
+        Cache successful response to disk.
+
+        Returns:
+            True if cached successfully, False otherwise
+        """
+        if response.status_code != 200:
+            return False
+
+        cache_key = self._generate_cache_key(method, url, params)
+        cache_path = os.path.join(self.cache_dir, cache_key)
+
+        with self._lock:
+            try:
+                cache_data = {
+                    'status_code': response.status_code,
+                    'content': response.text,
+                    'headers': dict(response.headers),
+                    'cached_at': datetime.now(timezone.utc).isoformat()
+                }
+
+                # Write to temporary file first, then rename for atomic operation
+                temp_path = cache_path + '.tmp'
+                with open(temp_path, 'w', encoding='utf-8') as f:
+                    json.dump(cache_data, f)
+
+                # Atomic rename to prevent partial cache files
+                os.rename(temp_path, cache_path)
+                return True
+
+            except (IOError, OSError) as e:
+                # Clean up temp file if it exists
+                try:
+                    if os.path.exists(temp_path):
+                        os.remove(temp_path)
+                except OSError:
+                    pass
+                return False
+
+
 class BaseProvider(ABC):
     """
     Abstract base class for all DNSRecon data providers.
-    Now supports session-specific configuration.
+    Now supports global provider-specific caching and session-specific configuration.
     """

     def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
         """
-        Initialize base provider with session-specific configuration.
+        Initialize base provider with global caching and session-specific configuration.

         Args:
             name: Provider name for logging
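Review note: the cache round trip can be exercised in isolation (a minimal sketch; the provider name 'crtsh', URL, and params here are illustrative, not taken from this diff):

    import requests

    cache = ProviderCache('crtsh', cache_expiry_hours=12)

    # Fabricate a response shaped like the ones get_cached_response rebuilds.
    resp = requests.Response()
    resp.status_code = 200
    resp._content = b'{"records": []}'
    resp.headers['Content-Type'] = 'application/json'

    params = {'q': 'example.com', 'output': 'json'}
    assert cache.cache_response('GET', 'https://crt.sh/', params, resp)

    # Same method/url/params -> same MD5 key -> cache hit.
    hit = cache.get_cached_response('GET', 'https://crt.sh/', params)
    assert hit is not None and hit.status_code == 200
    assert hit.json() == {'records': []}

The write-to-temp-then-rename in cache_response means a concurrent reader never observes a half-written JSON file, since os.rename is atomic on POSIX filesystems within one filesystem.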
@@ -80,28 +198,25 @@ class BaseProvider(ABC):
         self.logger = get_forensic_logger()
         self._stop_event = None

-        # Caching configuration (per session)
-        self.cache_dir = f'.cache/{id(self.config)}'  # Unique cache per session config
-        self.cache_expiry = 12 * 3600  # 12 hours in seconds
-        if not os.path.exists(self.cache_dir):
-            os.makedirs(self.cache_dir)
+        # GLOBAL provider-specific caching (not session-based)
+        self.cache = ProviderCache(name, cache_expiry_hours=12)

         # Statistics (per provider instance)
         self.total_requests = 0
         self.successful_requests = 0
         self.failed_requests = 0
         self.total_relationships_found = 0
+        self.cache_hits = 0
+        self.cache_misses = 0

-        print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")
+        print(f"Initialized {name} provider with global cache and session config (rate: {actual_rate_limit}/min)")

     def __getstate__(self):
         """Prepare BaseProvider for pickling by excluding unpicklable objects."""
         state = self.__dict__.copy()
         # Exclude the unpickleable '_local' attribute and stop event
-        unpicklable_attrs = ['_local', '_stop_event']
-        for attr in unpicklable_attrs:
-            if attr in state:
-                del state[attr]
+        state['_local'] = None
+        state['_stop_event'] = None
         return state

     def __setstate__(self, state):
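Review note: dropping the id()-based cache directory is the right call; a sketch of why the old path never survived a restart (Config here is a stand-in class, not from this diff):

    # id() is only unique for an object's lifetime, so a path like
    # f'.cache/{id(self.config)}' changes on every interpreter start
    # (and an id can even be recycled within one run after GC).
    class Config:
        pass

    print(f'.cache/{id(Config())}')  # almost certainly differs between runs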
@@ -116,7 +231,7 @@ class BaseProvider(ABC):
         if not hasattr(self._local, 'session'):
             self._local.session = requests.Session()
             self._local.session.headers.update({
-                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
+                'User-Agent': 'DNSRecon/2.0 (Passive Reconnaissance Tool)'
             })
         return self._local.session

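Review note: the session accessor shown here is the standard thread-local pattern; a self-contained sketch of the same idea (names are illustrative):

    import threading
    import requests

    _local = threading.local()

    def get_session() -> requests.Session:
        # Each worker thread lazily builds its own Session, since a shared
        # requests.Session is not guaranteed safe for concurrent use.
        if not hasattr(_local, 'session'):
            _local.session = requests.Session()
            _local.session.headers.update({
                'User-Agent': 'DNSRecon/2.0 (Passive Reconnaissance Tool)'
            })
        return _local.session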
@@ -177,37 +292,28 @@
                      target_indicator: str = "",
                      max_retries: int = 3) -> Optional[requests.Response]:
         """
-        Make a rate-limited HTTP request with aggressive stop signal handling.
-        Terminates immediately when stop is requested, including during retries.
+        Make a rate-limited HTTP request with global caching and aggressive stop signal handling.
         """
         # Check for cancellation before starting
         if self._is_stop_requested():
             print(f"Request cancelled before start: {url}")
             return None

-        # Create a unique cache key
-        cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
-        cache_path = os.path.join(self.cache_dir, cache_key)
-
-        # Check cache
-        if os.path.exists(cache_path):
-            cache_age = time.time() - os.path.getmtime(cache_path)
-            if cache_age < self.cache_expiry:
-                print(f"Returning cached response for: {url}")
-                with open(cache_path, 'r') as f:
-                    cached_data = json.load(f)
-                response = requests.Response()
-                response.status_code = cached_data['status_code']
-                response._content = cached_data['content'].encode('utf-8')
-                response.headers = cached_data['headers']
-                return response
+        # Check global cache first
+        cached_response = self.cache.get_cached_response(method, url, params)
+        if cached_response is not None:
+            print(f"Cache hit for {self.name}: {url}")
+            self.cache_hits += 1
+            return cached_response
+
+        self.cache_misses += 1

         # Determine effective max_retries based on stop signal
         effective_max_retries = 0 if self._is_stop_requested() else max_retries
         last_exception = None

         for attempt in range(effective_max_retries + 1):
-            # AGGRESSIVE: Check for cancellation before each attempt
+            # Check for cancellation before each attempt
             if self._is_stop_requested():
                 print(f"Request cancelled during attempt {attempt + 1}: {url}")
                 return None
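Review note: besides the structural cleanup, the old key removed above relied on Python's built-in hash(), which is salted per process for str inputs (PYTHONHASHSEED), so keys never matched across interpreter restarts; the MD5 key in ProviderCache is stable. A quick demonstration (sketch; run it twice):

    import hashlib
    import json

    data = "GET:https://example.com/api:" + json.dumps({'q': 'example.com'}, sort_keys=True)
    print(hashlib.md5(data.encode()).hexdigest())  # identical on every run
    print(hash(data))                              # differs between runs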
@@ -217,7 +323,7 @@
                 print(f"Request cancelled during rate limiting: {url}")
                 return None

-            # AGGRESSIVE: Final check before making HTTP request
+            # Final check before making HTTP request
             if self._is_stop_requested():
                 print(f"Request cancelled before HTTP call: {url}")
                 return None
@@ -236,11 +342,8 @@

                 print(f"Making {method} request to: {url} (attempt {attempt + 1})")

-                # AGGRESSIVE: Use much shorter timeout if termination is requested
-                request_timeout = self.timeout
-                if self._is_stop_requested():
-                    request_timeout = 2  # Max 2 seconds if termination requested
-                    print(f"Stop requested - using short timeout: {request_timeout}s")
+                # Use shorter timeout if termination is requested
+                request_timeout = 2 if self._is_stop_requested() else self.timeout

                 # Make request
                 if method.upper() == "GET":
@@ -276,13 +379,9 @@
                     error=None,
                     target_indicator=target_indicator
                 )
-                # Cache the successful response to disk
-                with open(cache_path, 'w') as f:
-                    json.dump({
-                        'status_code': response.status_code,
-                        'content': response.text,
-                        'headers': dict(response.headers)
-                    }, f)
+
+                # Cache the successful response globally
+                self.cache.cache_response(method, url, params, response)
                 return response

             except requests.exceptions.RequestException as e:
@@ -291,23 +390,21 @@
                 print(f"Request failed (attempt {attempt + 1}): {error}")
                 last_exception = e

-                # AGGRESSIVE: Immediately abort retries if stop requested
+                # Immediately abort retries if stop requested
                 if self._is_stop_requested():
                     print(f"Stop requested - aborting retries for: {url}")
                     break

-                # Check if we should retry (but only if stop not requested)
+                # Check if we should retry
                 if attempt < effective_max_retries and self._should_retry(e):
-                    # Use a longer, more respectful backoff for 429 errors
+                    # Capped exponential backoff for 429 errors
                     if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429:
-                        # Start with a 10-second backoff and increase exponentially
-                        backoff_time = 10 * (2 ** attempt)
+                        backoff_time = min(60, 10 * (2 ** attempt))
                         print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
                     else:
-                        backoff_time = min(1.0, (2 ** attempt) * 0.5)  # Shorter backoff for other errors
+                        backoff_time = min(2.0, (2 ** attempt) * 0.5)
                         print(f"Retrying in {backoff_time} seconds...")

-                    # AGGRESSIVE: Much shorter backoff and more frequent checking
                     if not self._sleep_with_cancellation_check(backoff_time):
                         print(f"Stop requested during backoff - aborting: {url}")
                         return None
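Review note: the new caps make the retry schedule easy to tabulate (a sketch of the values the code above produces):

    # attempt | 429 backoff: min(60, 10 * 2**attempt) | other: min(2.0, 0.5 * 2**attempt)
    #    0    |  10 s                                 | 0.5 s
    #    1    |  20 s                                 | 1.0 s
    #    2    |  40 s                                 | 2.0 s
    #    3+   |  60 s (capped)                        | 2.0 s (capped)
    for attempt in range(4):
        print(attempt, min(60, 10 * 2 ** attempt), min(2.0, (2 ** attempt) * 0.5))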
@@ -348,7 +445,6 @@
             return True
         return False

-
     def _wait_with_cancellation_check(self) -> bool:
         """
         Wait for rate limiting while aggressively checking for cancellation.
@@ -447,7 +543,7 @@

     def get_statistics(self) -> Dict[str, Any]:
         """
-        Get provider statistics.
+        Get provider statistics including cache performance.

         Returns:
             Dictionary containing provider performance metrics
@@ -459,5 +555,8 @@
             'failed_requests': self.failed_requests,
             'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
             'relationships_found': self.total_relationships_found,
-            'rate_limit': self.rate_limiter.requests_per_minute
+            'rate_limit': self.rate_limiter.requests_per_minute,
+            'cache_hits': self.cache_hits,
+            'cache_misses': self.cache_misses,
+            'cache_hit_rate': (self.cache_hits / (self.cache_hits + self.cache_misses) * 100) if (self.cache_hits + self.cache_misses) > 0 else 0
         }
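Review note: the hit-rate guard mirrors the existing success_rate pattern; a minimal check of the formula (standalone sketch):

    def cache_hit_rate(hits: int, misses: int) -> float:
        total = hits + misses
        return (hits / total * 100) if total > 0 else 0

    assert cache_hit_rate(42, 14) == 75.0
    assert cache_hit_rate(0, 0) == 0  # no division by zero on a fresh provider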