395 lines
15 KiB
Python
395 lines
15 KiB
Python
# dnsrecon/providers/base_provider.py
|
|
|
|
import time
|
|
import requests
|
|
import threading
|
|
import os
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
|
|
from core.logger import get_forensic_logger
|
|
from core.graph_manager import RelationshipType
|
|
|
|
|
|
class RateLimiter:
    """Simple rate limiter for API calls.

    Enforces a minimum spacing of ``60 / requests_per_minute`` seconds between
    consecutive calls to :meth:`wait_if_needed` by sleeping the caller.
    """

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute. Used as a
                divisor, so it must be non-zero.
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        # Wall-clock timestamp (time.time()) of the previous request; 0 means
        # "no request yet", so the first call never sleeps.
        self.last_request_time = 0

    def __getstate__(self):
        """RateLimiter is fully picklable, return full state."""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore RateLimiter state."""
        self.__dict__.update(state)

    def wait_if_needed(self) -> None:
        """Wait if necessary to respect rate limits.

        The sleep is capped at ``min_interval``. ``last_request_time`` is
        wall-clock time and is carried across pickling, so if the clock steps
        backwards (or the state is restored where the clock is behind) the
        elapsed time is negative and an uncapped ``min_interval - elapsed``
        could block the caller for an arbitrarily long time.
        """
        elapsed = time.time() - self.last_request_time

        if elapsed < self.min_interval:
            sleep_time = min(self.min_interval - elapsed, self.min_interval)
            if sleep_time > 0:
                time.sleep(sleep_time)

        self.last_request_time = time.time()
|
|
|
|
|
|
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.

    Now supports session-specific configuration. Supplies the shared plumbing
    concrete providers build on: rate limiting, a per-thread ``requests``
    session, an on-disk response cache, cancellation via a stop event,
    forensic logging of requests/relationships, and request statistics.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging; also passed to
                ``session_config.get_rate_limit`` and used to prefix cache files
            rate_limit: Requests per minute limit; only used when
                ``session_config`` is None
            timeout: Request timeout in seconds; only used when
                ``session_config`` is None
            session_config: Session-specific configuration. When given, its
                ``get_rate_limit(name)`` and ``default_timeout`` take
                precedence over ``rate_limit`` and ``timeout``.
        """
        # Use session config if provided, otherwise fall back to global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        self._local = threading.local()  # holds one requests.Session per thread
        self.logger = get_forensic_logger()
        self._stop_event = None  # assigned via set_stop_event()

        # Caching configuration (per session). Note that id() is only unique
        # among objects alive at the same time in this process, so this is a
        # per-process, per-config directory.
        self.cache_dir = f'.cache/{id(self.config)}'
        self.cache_expiry = 12 * 3600  # 12 hours in seconds
        # exist_ok avoids the check-then-create race when several providers
        # sharing one config are constructed concurrently.
        os.makedirs(self.cache_dir, exist_ok=True)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

        print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpickleable '_local' attribute and stop event
        unpicklable_attrs = ['_local', '_stop_event']
        for attr in unpicklable_attrs:
            if attr in state:
                del state[attr]
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_local' attribute and stop event
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Return this thread's ``requests.Session``, creating it on first use."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    def _is_cancelled(self) -> bool:
        """Return True if a stop event has been set and is signalled."""
        stop_event = getattr(self, '_stop_event', None)
        return stop_event is not None and stop_event.is_set()

    def _cache_path(self, url: str, method: str,
                    params: Optional[Dict[str, Any]]) -> str:
        """Return the cache file path for a (method, url, params) request."""
        import hashlib

        key_source = f'{method}:{url}:{json.dumps(params, sort_keys=True)}'
        # A hashlib digest is stable across interpreter runs, unlike the
        # built-in hash() of a str, which is salted per process.
        digest = hashlib.sha256(key_source.encode('utf-8')).hexdigest()
        return os.path.join(self.cache_dir, f"{self.name}_{digest}.json")

    def _load_cached_response(self, cache_path: str, url: str) -> Optional[requests.Response]:
        """
        Rebuild a Response from a fresh cache file, or return None on a miss.

        Missing, expired, unreadable or malformed cache files are all treated
        as misses, so a bad cache entry falls through to a network request
        instead of raising.
        """
        try:
            cache_age = time.time() - os.path.getmtime(cache_path)
        except OSError:
            return None  # no cache file (or it disappeared)
        if cache_age >= self.cache_expiry:
            return None

        try:
            with open(cache_path, 'r', encoding='utf-8') as f:
                cached_data = json.load(f)
            status_code = cached_data['status_code']
            body = cached_data['content'].encode('utf-8')
            cached_headers = cached_data['headers']
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            print(f"Ignoring unreadable cache entry {cache_path}: {e}")
            return None

        print(f"Returning cached response for: {url}")
        response = requests.Response()
        response.status_code = status_code
        response._content = body
        response.encoding = 'utf-8'  # the body was re-encoded as UTF-8 above
        # Keep header lookups case-insensitive like a live response.
        response.headers = requests.structures.CaseInsensitiveDict(cached_headers)
        response.url = url
        return response

    def _store_cached_response(self, cache_path: str, response: requests.Response) -> None:
        """
        Best-effort, atomic write of a successful response to the cache.

        The payload goes to a temporary file in the cache directory and is then
        moved into place with os.replace, so concurrent readers never observe a
        half-written file. Failures are reported and swallowed: a cache problem
        must not turn a successful request into a failed one.
        """
        import tempfile

        tmp_path = None
        try:
            payload = {
                'status_code': response.status_code,
                'content': response.text,
                'headers': dict(response.headers)
            }
            fd, tmp_path = tempfile.mkstemp(dir=self.cache_dir, suffix='.tmp')
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                json.dump(payload, f)
            os.replace(tmp_path, cache_path)
            tmp_path = None  # now owned by cache_path; nothing to clean up
        except (OSError, TypeError, ValueError) as e:
            print(f"Failed to write cache entry {cache_path}: {e}")
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with forensic logging and retry logic.

        Now supports cancellation via stop_event from scanner. If an on-disk
        cache entry younger than ``cache_expiry`` exists, it is returned without
        any network access. Otherwise up to ``max_retries + 1`` attempts are
        made, with exponential backoff between retryable failures (see
        ``_should_retry``).

        Args:
            url: Target URL
            method: "GET" (params sent as query string) or "POST" (params sent
                as JSON body); anything else is logged as an unexpected error
            params: Query parameters / JSON body; also part of the cache key
            headers: Extra headers merged over the session defaults
            target_indicator: Indicator under investigation, for forensic logs
            max_retries: Number of retries after the first attempt

        Returns:
            The successful response, or None if the request was cancelled or
            every attempt failed.
        """
        # Check for cancellation before starting
        if self._is_cancelled():
            print(f"Request cancelled before start: {url}")
            return None

        cache_path = self._cache_path(url, method, params)
        cached_response = self._load_cached_response(cache_path, url)
        if cached_response is not None:
            return cached_response

        # Initialised before the loop so the failure log below never hits an
        # unbound name (e.g. max_retries < 0 means no attempt is made).
        start_time = time.time()
        response = None
        error = None

        for attempt in range(max_retries + 1):
            # Check for cancellation before each attempt
            if self._is_cancelled():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None

            self.rate_limiter.wait_if_needed()

            # Rate limiting may have slept; check again before hitting the network
            if self._is_cancelled():
                print(f"Request cancelled after rate limiting: {url}")
                return None

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                # Prepare request
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # Use shorter timeout if termination was requested since the
                # check above
                request_timeout = self.timeout
                if self._is_cancelled():
                    request_timeout = min(5, self.timeout)  # Max 5 seconds if termination requested

                # Make request
                if method.upper() == "GET":
                    response = self.session.get(
                        url,
                        params=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url,
                        json=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )
                # Best-effort and non-raising: a cache write problem must not
                # discard this successful response.
                self._store_cached_response(cache_path, response)
                return response

            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")

                # Check for cancellation before retrying
                if self._is_cancelled():
                    print(f"Request cancelled, not retrying: {url}")
                    break

                # Check if we should retry
                if attempt < max_retries and self._should_retry(e):
                    backoff_time = (2 ** attempt) * 1  # Exponential backoff: 1s, 2s, 4s
                    print(f"Retrying in {backoff_time} seconds...")

                    # Sleep in 100 ms slices so cancellation is noticed promptly
                    sleep_start = time.time()
                    while time.time() - sleep_start < backoff_time:
                        if self._is_cancelled():
                            print(f"Request cancelled during backoff: {url}")
                            return None
                        time.sleep(0.1)  # Check every 100ms
                    continue
                break

            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                break

        # All attempts failed - log and return None.
        # Compare with None explicitly: Response.__bool__ is False for 4xx/5xx,
        # which is exactly when the status code is most useful in the log.
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        return None

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True for connection errors, timeouts and HTTP 5xx responses;
            False otherwise (including 4xx client errors).
        """
        # Retry on connection errors, timeouts, and 5xx server errors
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True

        if isinstance(exception, requests.exceptions.HTTPError):
            # Must be an explicit None check: Response.__bool__ returns
            # response.ok, which is False for every 4xx/5xx response, so a
            # truthiness test would never allow a 5xx retry.
            response = getattr(exception, 'response', None)
            if response is not None:
                # Retry on server errors (5xx) but not client errors (4xx)
                return response.status_code >= 500

        return False

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: RelationshipType,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type.relationship_name,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics. Note that
            ``total_requests`` counts individual attempts (retries included).
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute
        }