dnsrecon/providers/base_provider.py
overcuriousity 7fe7ca41ba it
2025-09-14 17:40:18 +02:00

329 lines
11 KiB
Python

# dnsrecon/providers/base_provider.py
import hashlib
import json
import os
import threading
import time
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple

import requests

from core.logger import get_forensic_logger
class RateLimiter:
    """Simple rate limiter for API calls.

    Enforces a minimum spacing of ``60 / requests_per_minute`` seconds
    between consecutive calls to :meth:`wait_if_needed` by sleeping in the
    calling thread. No locking is performed inside this class.
    """

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute. A value
                of zero or less disables throttling (no minimum interval).
        """
        self.requests_per_minute = requests_per_minute
        # A non-positive limit has no meaningful interval (0 would divide by
        # zero), so it is treated as "no throttling".
        if requests_per_minute > 0:
            self.min_interval = 60.0 / requests_per_minute
        else:
            self.min_interval = 0.0
        self.last_request_time = 0

    def __getstate__(self):
        """RateLimiter is fully picklable, return full state."""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore RateLimiter state."""
        self.__dict__.update(state)

    def wait_if_needed(self) -> None:
        """Sleep just long enough to respect the configured rate limit.

        The timestamp of this call is recorded afterwards so the next call is
        measured from the moment this one was released.
        """
        # Wall-clock time is used because ``last_request_time`` is part of the
        # pickled state and has to stay comparable after being restored.
        elapsed = time.time() - self.last_request_time
        if elapsed < self.min_interval:
            # If the stored timestamp lies in the future (clock stepped back,
            # or state restored from elsewhere) ``elapsed`` is negative; never
            # sleep longer than one full interval in that case.
            time.sleep(min(self.min_interval, self.min_interval - elapsed))
        self.last_request_time = time.time()
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.

    Provides the shared plumbing for concrete providers: a per-thread HTTP
    session, rate limiting, an on-disk response cache, request/relationship
    logging and simple statistics. Supports session-specific configuration.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit; only used when no
                session_config is given
            timeout: Request timeout in seconds; only used when no
                session_config is given
            session_config: Session-specific configuration. When provided,
                the rate limit and timeout are read from it instead of the
                arguments above.
        """
        # Use session config if provided, otherwise fall back to global config
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None

        # Caching configuration (per session).
        # NOTE(review): id() is only unique among live objects in this
        # process, so this directory is not stable across restarts - confirm
        # whether a persistent session identifier should be used instead.
        self.cache_dir = f'.cache/{id(self.config)}'
        self.cache_expiry = self.config.cache_expiry_hours * 3600
        # exist_ok avoids a check-then-create race between providers that
        # share the same config object.
        os.makedirs(self.cache_dir, exist_ok=True)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # The thread-local storage and the stop event are dropped here and
        # recreated in __setstate__.
        for attr in ('_local', '_stop_event'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_local' attribute and stop event
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Return the calling thread's ``requests.Session``, creating it on first use."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """

    def _cache_path(self, method: str, url: str,
                    params: Optional[Dict[str, Any]]) -> str:
        """Return the cache file path for a request.

        The file name is derived from a SHA-256 digest of the method, URL and
        sorted parameters. The built-in ``hash()`` is salted per process for
        strings, so it cannot be used to find a file written by an earlier run.
        """
        fingerprint = f"{method.upper()}:{url}:{json.dumps(params, sort_keys=True)}"
        digest = hashlib.sha256(fingerprint.encode('utf-8')).hexdigest()
        return os.path.join(self.cache_dir, f"{self.name}_{digest}.json")

    def _load_cached_response(self, cache_path: str, url: str) -> Optional["requests.Response"]:
        """Rebuild a response from ``cache_path``, or return None on a cache miss.

        A missing, expired, unreadable or malformed cache file all count as a
        miss so that the caller simply performs the real request.
        """
        try:
            cache_age = time.time() - os.path.getmtime(cache_path)
        except OSError:
            # No cache entry (or it vanished between calls).
            return None
        if cache_age >= self.cache_expiry:
            return None

        try:
            with open(cache_path, 'r', encoding='utf-8') as f:
                cached_data = json.load(f)
            status_code = cached_data['status_code']
            content = cached_data['content'].encode('utf-8')
            cached_headers = cached_data['headers']
        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            # Corrupt or partially written entry: ignore it.
            return None

        print(f"Returning cached response for: {url}")
        response = requests.Response()
        response.status_code = status_code
        response._content = content
        # The body was stored as text and re-encoded as UTF-8 above, so pin
        # the encoding instead of letting requests guess it.
        response.encoding = 'utf-8'
        response.headers = requests.structures.CaseInsensitiveDict(cached_headers)
        return response

    def _store_cached_response(self, cache_path: str, response: "requests.Response") -> None:
        """Write a successful response to ``cache_path`` (best effort).

        A failure to write the cache must not discard a response that was
        fetched successfully, so I/O errors are reported and swallowed.
        """
        payload = {
            'status_code': response.status_code,
            'content': response.text,
            'headers': dict(response.headers)
        }
        try:
            with open(cache_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f)
        except OSError as exc:
            print(f"Failed to write cache file {cache_path}: {exc}")

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "") -> Optional["requests.Response"]:
        """
        Make a rate-limited, cached HTTP request.

        Args:
            url: Request URL
            method: "GET" or "POST" (case-insensitive)
            params: Query parameters for GET, JSON body for POST
            headers: Extra headers merged over the session defaults
            target_indicator: Passed through to the API request log

        Returns:
            The response (fresh or rebuilt from the cache), or None when a
            stop has been requested before the request started.

        Raises:
            requests.exceptions.RequestException: On transport errors or a
                4xx/5xx status (after logging the failure).
            ValueError: If ``method`` is neither GET nor POST.
        """
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        cache_path = self._cache_path(method, url, params)
        cached_response = self._load_cached_response(cache_path, url)
        if cached_response is not None:
            return cached_response

        self.rate_limiter.wait_if_needed()
        start_time = time.time()
        response = None
        verb = method.upper()

        try:
            self.total_requests += 1
            request_headers = dict(self.session.headers)
            if headers:
                request_headers.update(headers)

            print(f"Making {method} request to: {url}")
            if verb == "GET":
                response = self.session.get(
                    url,
                    params=params,
                    headers=request_headers,
                    timeout=self.timeout
                )
            elif verb == "POST":
                response = self.session.post(
                    url,
                    json=params,
                    headers=request_headers,
                    timeout=self.timeout
                )
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            print(f"Response status: {response.status_code}")
            response.raise_for_status()
            self.successful_requests += 1

            duration_ms = (time.time() - start_time) * 1000
            self.logger.log_api_request(
                provider=self.name,
                url=url,
                method=verb,
                status_code=response.status_code,
                response_size=len(response.content),
                duration_ms=duration_ms,
                error=None,
                target_indicator=target_indicator
            )
            self._store_cached_response(cache_path, response)
            return response
        except requests.exceptions.RequestException as e:
            self.failed_requests += 1
            duration_ms = (time.time() - start_time) * 1000
            # A Response is falsy whenever its status is 4xx/5xx, which is
            # exactly when raise_for_status() raised, so test identity rather
            # than truthiness to keep the status code in the log.
            got_response = response is not None
            self.logger.log_api_request(
                provider=self.name,
                url=url,
                method=verb,
                status_code=response.status_code if got_response else None,
                response_size=len(response.content) if got_response else None,
                duration_ms=duration_ms,
                error=str(e),
                target_indicator=target_indicator
            )
            raise

    def _is_stop_requested(self) -> bool:
        """Return True if a stop event has been set on this provider and is signalled."""
        stop_event = getattr(self, '_stop_event', None)
        return bool(stop_event and stop_event.is_set())

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship and count it in the statistics.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1
        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics;
            ``success_rate`` is a percentage and is 0 when no request has
            been made yet.
        """
        if self.total_requests > 0:
            success_rate = self.successful_requests / self.total_requests * 100
        else:
            success_rate = 0
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': success_rate,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute
        }