# dnsrecon/providers/base_provider.py

import time
import requests
import threading
import os
import json
import hashlib
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timezone

from core.logger import get_forensic_logger


class RateLimiter:
    """Thread-safe rate limiter for API calls."""

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        self.last_request_time = 0
        self._lock = threading.Lock()

    def __getstate__(self):
        """Return picklable state, excluding the unpicklable lock."""
        state = self.__dict__.copy()
        # The lock cannot be pickled; it is recreated in __setstate__
        if '_lock' in state:
            del state['_lock']
        return state

    def __setstate__(self, state):
        """Restore RateLimiter state and recreate the lock."""
        self.__dict__.update(state)
        self._lock = threading.Lock()

    def wait_if_needed(self) -> None:
        """Wait if necessary to respect rate limits."""
        with self._lock:
            current_time = time.time()
            time_since_last = current_time - self.last_request_time

            if time_since_last < self.min_interval:
                sleep_time = self.min_interval - time_since_last
                time.sleep(sleep_time)

            self.last_request_time = time.time()
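
# Illustrative sketch (not used by the provider code itself): how RateLimiter spaces
# out calls. The 120 requests/minute figure is an arbitrary example value.
#
#   limiter = RateLimiter(requests_per_minute=120)  # min_interval == 0.5s
#   limiter.wait_if_needed()                        # first call returns immediately
#   limiter.wait_if_needed()                        # sleeps ~0.5s so calls stay >= 0.5s apart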


class ProviderCache:
    """Thread-safe global cache for provider queries."""

    def __init__(self, provider_name: str, cache_expiry_hours: int = 12):
        """
        Initialize provider-specific cache.

        Args:
            provider_name: Name of the provider for cache directory
            cache_expiry_hours: Cache expiry time in hours
        """
        self.provider_name = provider_name
        self.cache_expiry = cache_expiry_hours * 3600  # Convert to seconds
        self.cache_dir = os.path.join('.cache', provider_name)
        self._lock = threading.Lock()

        # Ensure the cache directory exists (exist_ok makes concurrent creation safe)
        os.makedirs(self.cache_dir, exist_ok=True)

    def _generate_cache_key(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> str:
        """Generate a unique cache key for a request."""
        cache_data = f"{method}:{url}:{json.dumps(params or {}, sort_keys=True)}"
        return hashlib.md5(cache_data.encode()).hexdigest() + ".json"

    def get_cached_response(self, method: str, url: str, params: Optional[Dict[str, Any]]) -> Optional[requests.Response]:
        """
        Retrieve cached response if available and not expired.

        Returns:
            Cached Response object or None on cache miss/expiry
        """
        cache_key = self._generate_cache_key(method, url, params)
        cache_path = os.path.join(self.cache_dir, cache_key)

        with self._lock:
            if not os.path.exists(cache_path):
                return None

            # Check if the cache entry is expired
            cache_age = time.time() - os.path.getmtime(cache_path)
            if cache_age >= self.cache_expiry:
                try:
                    os.remove(cache_path)
                except OSError:
                    pass  # File might have been removed by another thread or process
                return None

            try:
                with open(cache_path, 'r', encoding='utf-8') as f:
                    cached_data = json.load(f)

                # Reconstruct a Response object from the cached fields
                response = requests.Response()
                response.status_code = cached_data['status_code']
                response._content = cached_data['content'].encode('utf-8')
                response.headers.update(cached_data['headers'])

                return response

            except (json.JSONDecodeError, KeyError, IOError):
                # Cache file corrupted, remove it
                try:
                    os.remove(cache_path)
                except OSError:
                    pass
                return None

    def cache_response(self, method: str, url: str, params: Optional[Dict[str, Any]],
                       response: requests.Response) -> bool:
        """
        Cache successful response to disk.

        Returns:
            True if cached successfully, False otherwise
        """
        if response.status_code != 200:
            return False

        cache_key = self._generate_cache_key(method, url, params)
        cache_path = os.path.join(self.cache_dir, cache_key)
        temp_path = cache_path + '.tmp'

        with self._lock:
            try:
                cache_data = {
                    'status_code': response.status_code,
                    'content': response.text,
                    'headers': dict(response.headers),
                    'cached_at': datetime.now(timezone.utc).isoformat()
                }

                # Write to a temporary file first, then rename for an atomic update
                with open(temp_path, 'w', encoding='utf-8') as f:
                    json.dump(cache_data, f)

                # Atomic rename prevents readers from seeing partial cache files
                os.rename(temp_path, cache_path)
                return True

            except (IOError, OSError):
                # Clean up the temp file if it exists
                try:
                    if os.path.exists(temp_path):
                        os.remove(temp_path)
                except OSError:
                    pass
                return False
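
# Illustrative sketch (not exercised by the providers themselves): cache entries are
# keyed on method + URL + sorted params, so the same logical query hits the cache
# regardless of parameter dict ordering. The URL below is a placeholder.
#
#   cache = ProviderCache("example_provider", cache_expiry_hours=1)
#   resp = requests.get("https://api.example.com/lookup", params={"q": "example.com"})
#   cache.cache_response("GET", "https://api.example.com/lookup", {"q": "example.com"}, resp)
#   hit = cache.get_cached_response("GET", "https://api.example.com/lookup", {"q": "example.com"})
#   # hit is a reconstructed requests.Response (status, body, headers), or None on a miss;
#   # only 200 responses are ever written to disk.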


class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.
    Supports global provider-specific caching and session-specific configuration.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with global caching and session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit (default override)
            timeout: Request timeout in seconds
            session_config: Session-specific configuration
        """
        # Use the session config if provided, otherwise fall back to the global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fall back to the global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None

        # Global provider-specific caching (not session-based)
        self.cache = ProviderCache(name, cache_expiry_hours=12)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0
        self.cache_hits = 0
        self.cache_misses = 0

        print(f"Initialized {name} provider with global cache and session config (rate: {actual_rate_limit}/min)")

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpicklable thread-local storage and stop event
        state['_local'] = None
        state['_stop_event'] = None
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize thread-local storage; the stop event is reattached by the caller
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Thread-local requests.Session, created lazily with the DNSRecon User-Agent."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/2.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating whether the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass
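
    # Illustrative shape of a query_domain()/query_ip() result (values are made up):
    #   [("example.com", "ns1.example.net", "ns_record", 0.8, {"source": "dns"})]
    # i.e. (source_node, target_node, relationship_type, confidence, raw_data).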

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with global caching and aggressive stop signal handling.
        """
        # Check for cancellation before starting
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        # Check the global cache first
        cached_response = self.cache.get_cached_response(method, url, params)
        if cached_response is not None:
            print(f"Cache hit for {self.name}: {url}")
            self.cache_hits += 1
            return cached_response

        self.cache_misses += 1

        # Determine effective max_retries based on the stop signal
        effective_max_retries = 0 if self._is_stop_requested() else max_retries
        last_exception = None

        for attempt in range(effective_max_retries + 1):
            # Check for cancellation before each attempt
            if self._is_stop_requested():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None

            # Apply rate limiting with cancellation awareness
            if not self._wait_with_cancellation_check():
                print(f"Request cancelled during rate limiting: {url}")
                return None

            # Final check before making the HTTP request
            if self._is_stop_requested():
                print(f"Request cancelled before HTTP call: {url}")
                return None

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                # Prepare request headers
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # Use a shorter timeout if termination has been requested
                request_timeout = 2 if self._is_stop_requested() else self.timeout

                # Make the request
                if method.upper() == "GET":
                    response = self.session.get(
                        url,
                        params=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url,
                        json=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )

                # Cache the successful response globally
                self.cache.cache_response(method, url, params, response)
                return response

            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")
                last_exception = e

                # Immediately abort retries if a stop was requested
                if self._is_stop_requested():
                    print(f"Stop requested - aborting retries for: {url}")
                    break

                # Check if we should retry
                if attempt < effective_max_retries and self._should_retry(e):
                    # Longer exponential backoff for rate-limit (429) errors; compare the
                    # response against None because error responses evaluate as falsy
                    if (isinstance(e, requests.exceptions.HTTPError)
                            and e.response is not None
                            and e.response.status_code == 429):
                        backoff_time = min(60, 10 * (2 ** attempt))  # 10s, 20s, 40s, capped at 60s
                        print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
                    else:
                        backoff_time = min(2.0, (2 ** attempt) * 0.5)  # 0.5s, 1s, then capped at 2s
                        print(f"Retrying in {backoff_time} seconds...")

                    if not self._sleep_with_cancellation_check(backoff_time):
                        print(f"Stop requested during backoff - aborting: {url}")
                        return None
                    continue
                else:
                    break

            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                last_exception = e
                break

        # All attempts failed - log, then re-raise the last exception or return None
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        if error and last_exception:
            raise last_exception

        return None
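
    # Typical call from a concrete provider's query_domain() (illustrative URL and params):
    #   response = self.make_request("https://api.example.com/v1/domain",
    #                                params={"domain": domain},
    #                                target_indicator=domain)
    #   if response is not None:
    #       data = response.json()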

    def _is_stop_requested(self) -> bool:
        """Check whether cancellation has been signalled via the provider's stop event."""
        if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
            return True
        return False

    def _wait_with_cancellation_check(self) -> bool:
        """
        Wait for rate limiting while aggressively checking for cancellation.
        Returns False if cancelled during the wait.
        """
        current_time = time.time()
        time_since_last = current_time - self.rate_limiter.last_request_time

        if time_since_last < self.rate_limiter.min_interval:
            sleep_time = self.rate_limiter.min_interval - time_since_last
            if not self._sleep_with_cancellation_check(sleep_time):
                return False

        self.rate_limiter.last_request_time = time.time()
        return True

    def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
        """
        Sleep for the specified time while aggressively checking for cancellation.

        Args:
            sleep_time: Time to sleep in seconds

        Returns:
            bool: True if the sleep completed, False if cancelled
        """
        sleep_start = time.time()
        check_interval = 0.05  # Check every 50ms for responsive cancellation

        while time.time() - sleep_start < sleep_time:
            if self._is_stop_requested():
                return False
            remaining_time = sleep_time - (time.time() - sleep_start)
            time.sleep(min(check_interval, remaining_time))

        return True

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True if the request should be retried
        """
        # Retry on connection errors and timeouts
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True

        if isinstance(exception, requests.exceptions.HTTPError):
            # Compare against None: a Response with an error status evaluates as falsy
            if exception.response is not None:
                # Retry on server errors (5xx) and on rate-limiting errors (429)
                return exception.response.status_code >= 500 or exception.response.status_code == 429

        return False

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics including cache performance.

        Returns:
            Dictionary containing provider performance metrics
        """
        total_cache_lookups = self.cache_hits + self.cache_misses
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'cache_hit_rate': (self.cache_hits / total_cache_lookups * 100) if total_cache_lookups > 0 else 0
        }
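

# Minimal, illustrative sketch of how a concrete provider wires into BaseProvider.
# It is only executed when this module is run directly; the provider name, stub
# config values and the dummy relationship below are made up for demonstration.
if __name__ == "__main__":

    class _ExampleProvider(BaseProvider):
        """Toy provider that returns a single hard-coded relationship."""

        def get_name(self) -> str:
            return "example"

        def get_display_name(self) -> str:
            return "Example Provider"

        def requires_api_key(self) -> bool:
            return False

        def get_eligibility(self) -> Dict[str, bool]:
            return {'domains': True, 'ips': False}

        def is_available(self) -> bool:
            return True

        def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
            # A real provider would call self.make_request(...) and parse the response here.
            return [(domain, f"www.{domain}", "subdomain", 0.9, {'note': 'dummy data'})]

        def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
            return []

    class _StubConfig:
        """Stand-in session config so the demo avoids importing the global config."""
        default_timeout = 10

        def get_rate_limit(self, name: str) -> int:
            return 60

    provider = _ExampleProvider("example", session_config=_StubConfig())
    print(provider.query_domain("example.com"))
    print(provider.get_statistics())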