# dnsrecon/providers/base_provider.py
import time
import hashlib
import requests
import threading
import os
import json
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple

from core.logger import get_forensic_logger


class RateLimiter:
    """Simple rate limiter for API calls."""

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        self.last_request_time = 0

    def __getstate__(self):
        """RateLimiter is fully picklable, return full state."""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore RateLimiter state."""
        self.__dict__.update(state)

    def wait_if_needed(self) -> None:
        """Wait if necessary to respect rate limits."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.min_interval:
            sleep_time = self.min_interval - time_since_last
            time.sleep(sleep_time)
        self.last_request_time = time.time()
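
# Illustrative usage of RateLimiter (comment only, not part of the runtime code):
# a limiter configured for 120 requests/minute enforces a minimum spacing of
# 60 / 120 = 0.5 seconds between calls.
#
#   limiter = RateLimiter(requests_per_minute=120)
#   limiter.wait_if_needed()   # first call returns immediately
#   limiter.wait_if_needed()   # sleeps roughly 0.5s if called right away
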
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.
    Now supports session-specific configuration.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit (default override)
            timeout: Request timeout in seconds
            session_config: Session-specific configuration
        """
        # Use session config if provided, otherwise fall back to global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fall back to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None

        # Caching configuration (per session)
        self.cache_dir = f'.cache/{id(self.config)}'  # Unique cache per session config
        self.cache_expiry = 12 * 3600  # 12 hours in seconds
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

        print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")
    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpicklable '_local' attribute and stop event
        unpicklable_attrs = ['_local', '_stop_event']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_local' attribute and stop event
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Return a thread-local requests.Session, creating it on first use."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session
    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass
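
    # For reference, one relationship tuple returned by query_domain / query_ip has
    # the shape (source_node, target_node, relationship_type, confidence, raw_data);
    # the concrete values below are illustrative only:
    #
    #   ('example.com', 'www.example.com', 'subdomain', 0.8, {'source': 'api_response'})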
    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with aggressive stop signal handling.
        Terminates immediately when stop is requested, including during retries.
        """
        # Check for cancellation before starting
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        # Create a unique, process-stable cache key (hashlib instead of hash(),
        # which is randomized per process and would defeat the on-disk cache)
        request_fingerprint = hashlib.sha256(
            f'{method}:{url}:{json.dumps(params, sort_keys=True)}'.encode('utf-8')
        ).hexdigest()
        cache_key = f"{self.name}_{request_fingerprint}.json"
        cache_path = os.path.join(self.cache_dir, cache_key)

        # Check cache
        if os.path.exists(cache_path):
            cache_age = time.time() - os.path.getmtime(cache_path)
            if cache_age < self.cache_expiry:
                print(f"Returning cached response for: {url}")
                with open(cache_path, 'r') as f:
                    cached_data = json.load(f)
                response = requests.Response()
                response.status_code = cached_data['status_code']
                response._content = cached_data['content'].encode('utf-8')
                # Restore headers case-insensitively, as requests normally does
                response.headers = requests.structures.CaseInsensitiveDict(cached_data['headers'])
                return response

        # Determine effective max_retries based on stop signal
        effective_max_retries = 0 if self._is_stop_requested() else max_retries
        last_exception = None

        for attempt in range(effective_max_retries + 1):
            # Check for cancellation before each attempt
            if self._is_stop_requested():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None

            # Apply rate limiting with cancellation awareness
            if not self._wait_with_cancellation_check():
                print(f"Request cancelled during rate limiting: {url}")
                return None

            # Final check before making the HTTP request
            if self._is_stop_requested():
                print(f"Request cancelled before HTTP call: {url}")
                return None

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                # Prepare request
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # Use a much shorter timeout if termination has been requested
                request_timeout = self.timeout
                if self._is_stop_requested():
                    request_timeout = 2  # Max 2 seconds if termination requested
                    print(f"Stop requested - using short timeout: {request_timeout}s")

                # Make request
                if method.upper() == "GET":
                    response = self.session.get(
                        url,
                        params=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url,
                        json=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )

                # Cache the successful response to disk
                with open(cache_path, 'w') as f:
                    json.dump({
                        'status_code': response.status_code,
                        'content': response.text,
                        'headers': dict(response.headers)
                    }, f)

                return response

            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")
                last_exception = e

                # Immediately abort retries if stop requested
                if self._is_stop_requested():
                    print(f"Stop requested - aborting retries for: {url}")
                    break

                # Check if we should retry (but only if stop not requested)
                if attempt < effective_max_retries and self._should_retry(e):
                    # Use a longer, more respectful backoff for 429 errors.
                    # Compare e.response against None explicitly: a Response with an
                    # error status code evaluates as falsy.
                    if isinstance(e, requests.exceptions.HTTPError) and e.response is not None and e.response.status_code == 429:
                        # Start with a 10-second backoff and increase exponentially
                        backoff_time = 10 * (2 ** attempt)
                        print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
                    else:
                        backoff_time = min(1.0, (2 ** attempt) * 0.5)  # Shorter backoff for other errors
                        print(f"Retrying in {backoff_time} seconds...")

                    # Back off with frequent cancellation checks
                    if not self._sleep_with_cancellation_check(backoff_time):
                        print(f"Stop requested during backoff - aborting: {url}")
                        return None
                    continue
                else:
                    break

            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                last_exception = e
                break

        # All attempts failed - log, then re-raise the last exception or return None
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        if error and last_exception:
            raise last_exception
        return None
    def _is_stop_requested(self) -> bool:
        """
        Enhanced stop signal checking that handles both local and Redis-based signals.
        """
        if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
            return True
        return False

    def _wait_with_cancellation_check(self) -> bool:
        """
        Wait for rate limiting while aggressively checking for cancellation.
        Returns False if cancelled during wait.
        """
        current_time = time.time()
        time_since_last = current_time - self.rate_limiter.last_request_time
        if time_since_last < self.rate_limiter.min_interval:
            sleep_time = self.rate_limiter.min_interval - time_since_last
            if not self._sleep_with_cancellation_check(sleep_time):
                return False
        self.rate_limiter.last_request_time = time.time()
        return True

    def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
        """
        Sleep for the specified time while aggressively checking for cancellation.

        Args:
            sleep_time: Time to sleep in seconds

        Returns:
            bool: True if sleep completed, False if cancelled
        """
        sleep_start = time.time()
        check_interval = 0.05  # Check every 50ms for aggressive responsiveness
        while time.time() - sleep_start < sleep_time:
            if self._is_stop_requested():
                return False
            remaining_time = sleep_time - (time.time() - sleep_start)
            time.sleep(min(check_interval, remaining_time))
        return True

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event
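
    # Typical wiring (illustrative sketch; the scanner object implied here is an
    # assumption, not part of this module):
    #
    #   stop_event = threading.Event()
    #   provider.set_stop_event(stop_event)
    #   ...
    #   stop_event.set()   # any in-flight make_request() aborts at its next
    #                      # cancellation checkpoint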
    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True if the request should be retried
        """
        # Retry on connection errors and timeouts
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True
        if isinstance(exception, requests.exceptions.HTTPError):
            # Compare against None explicitly: a Response carrying a 4xx/5xx status
            # is falsy, so a plain truthiness check would silently skip these retries.
            if exception.response is not None:
                # Retry on server errors (5xx) AND on rate-limiting errors (429)
                return exception.response.status_code >= 500 or exception.response.status_code == 429
        return False
    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1
        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute
        }
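

# --- Illustrative example (commented out) -------------------------------------
# A minimal sketch of a concrete provider built on this base class. The provider
# name, endpoint URL, and response shape below are assumptions made purely for
# illustration; they do not correspond to a real service used by this project.
# It is left commented out so it is never registered by any subclass-based
# provider discovery.
#
# class ExampleProvider(BaseProvider):
#     def __init__(self, session_config=None):
#         super().__init__(name="example", rate_limit=60, timeout=30,
#                          session_config=session_config)
#
#     def get_name(self) -> str:
#         return "example"
#
#     def get_display_name(self) -> str:
#         return "Example Provider"
#
#     def requires_api_key(self) -> bool:
#         return False
#
#     def get_eligibility(self) -> Dict[str, bool]:
#         return {'domains': True, 'ips': False}
#
#     def is_available(self) -> bool:
#         return True
#
#     def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
#         relationships = []
#         # Hypothetical endpoint; replace with a real data source.
#         response = self.make_request(f"https://api.example.com/v1/domains/{domain}",
#                                      target_indicator=domain)
#         if response is None:
#             return relationships
#         data = response.json()
#         for subdomain in data.get('subdomains', []):
#             relationships.append((domain, subdomain, 'subdomain', 0.8, data))
#             self.log_relationship_discovery(domain, subdomain, 'subdomain', 0.8,
#                                             data, 'example_subdomain_lookup')
#         return relationships
#
#     def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
#         return []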