# dnsrecon/providers/base_provider.py
import hashlib
import json
import os
import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import requests

from core.logger import get_forensic_logger
from core.graph_manager import RelationshipType

class RateLimiter:
    """Simple rate limiter for API calls."""

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        self.last_request_time = 0.0

    def wait_if_needed(self) -> None:
        """Wait if necessary to respect rate limits."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.min_interval:
            sleep_time = self.min_interval - time_since_last
            time.sleep(sleep_time)
        self.last_request_time = time.time()
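
# A minimal usage sketch (illustrative, not part of the original module):
# with requests_per_minute=60 the limiter computes a 1.0 second minimum
# interval, so back-to-back calls are spaced out automatically.
#
#     limiter = RateLimiter(requests_per_minute=60)
#     for _ in range(3):
#         limiter.wait_if_needed()  # sleeps just enough between calls
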
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.
    Provides common functionality and defines the provider interface.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30):
        """
        Initialize base provider.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit
            timeout: Request timeout in seconds
        """
        self.name = name
        self.rate_limiter = RateLimiter(rate_limit)
        self.timeout = timeout
        # One requests.Session per thread, so concurrent scans do not share
        # connection state.
        self._local = threading.local()
        self.logger = get_forensic_logger()

        # Caching configuration
        self.cache_dir = '.cache'
        self.cache_expiry = 12 * 3600  # 12 hours in seconds
        os.makedirs(self.cache_dir, exist_ok=True)

        # Statistics
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

    @property
    def session(self) -> requests.Session:
        """Return (lazily creating) this thread's requests.Session."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session
    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass
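
    # Contract sketch (illustrative): a concrete provider's query_domain()
    # and query_ip() are expected to return tuples shaped like
    #
    #     (source_node, target_node, relationship_type, confidence, raw_data)
    #
    # e.g. ("example.com", "ns1.example.com", <RelationshipType member>,
    # 0.8, {...}). The actual member names of RelationshipType live in
    # core.graph_manager and are not repeated here.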
    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with forensic logging and retry logic.

        Args:
            url: Request URL
            method: HTTP method
            params: Query parameters
            headers: Additional headers
            target_indicator: The indicator being investigated
            max_retries: Maximum number of retry attempts

        Returns:
            Response object or None if request failed
        """
        # Create a unique, process-stable cache key. The built-in hash() is
        # salted per interpreter run, which would defeat the on-disk cache;
        # a SHA-256 digest maps the same request to the same file every run.
        request_fingerprint = f"{method}:{url}:{json.dumps(params, sort_keys=True)}"
        cache_key = f"{self.name}_{hashlib.sha256(request_fingerprint.encode('utf-8')).hexdigest()}.json"
        cache_path = os.path.join(self.cache_dir, cache_key)

        # Check cache
        if os.path.exists(cache_path):
            cache_age = time.time() - os.path.getmtime(cache_path)
            if cache_age < self.cache_expiry:
                try:
                    with open(cache_path, 'r') as f:
                        cached_data = json.load(f)
                    print(f"Returning cached response for: {url}")
                    response = requests.Response()
                    response.status_code = cached_data['status_code']
                    response._content = cached_data['content'].encode('utf-8')
                    response.headers = requests.structures.CaseInsensitiveDict(cached_data['headers'])
                    return response
                except (OSError, ValueError, KeyError):
                    # Corrupt or unreadable cache entry: fall through to a
                    # live request instead of failing the query.
                    pass

        for attempt in range(max_retries + 1):
            # Apply rate limiting
            self.rate_limiter.wait_if_needed()

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                # Prepare request
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # Make request
                if method.upper() == "GET":
                    response = self.session.get(
                        url,
                        params=params,
                        headers=request_headers,
                        timeout=self.timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url,
                        json=params,
                        headers=request_headers,
                        timeout=self.timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )

                # Cache the successful response to disk
                with open(cache_path, 'w') as f:
                    json.dump({
                        'status_code': response.status_code,
                        'content': response.text,
                        'headers': dict(response.headers)
                    }, f)

                return response

            except requests.exceptions.RequestException as e:
                error = str(e)
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")

                # Check if we should retry
                if attempt < max_retries and self._should_retry(e):
                    backoff_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s
                    print(f"Retrying in {backoff_time} seconds...")
                    time.sleep(backoff_time)
                    continue
                else:
                    break
            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                break

        # All attempts failed - log and return None. Note the explicit
        # `is not None` checks: a requests.Response is falsy for 4xx/5xx
        # status codes, so a bare truthiness test would discard the status
        # of a response that was received but failed.
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )
        return None
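
    # Typical call pattern from a concrete provider (sketch; the endpoint
    # and parameter names below are placeholders, not a real provider API):
    #
    #     response = self.make_request(
    #         "https://api.example.com/v1/domain",
    #         params={"q": domain},
    #         target_indicator=domain,
    #     )
    #     if response is not None:
    #         data = response.json()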
    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True if the request should be retried
        """
        # Retry on connection errors, timeouts, and 5xx server errors
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True

        if isinstance(exception, requests.exceptions.HTTPError):
            # Compare against None explicitly: Response.__bool__ is False for
            # 4xx/5xx, so a bare `if exception.response` would never retry.
            if exception.response is not None:
                # Retry on server errors (5xx) but not client errors (4xx)
                return exception.response.status_code >= 500

        return False
    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: RelationshipType,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type.relationship_name,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )
    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute
        }
    def reset_statistics(self) -> None:
        """Reset provider statistics."""
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0
    def _extract_domain_from_url(self, url: str) -> Optional[str]:
        """
        Extract domain from URL.

        Args:
            url: URL string

        Returns:
            Domain name or None if extraction fails
        """
        try:
            # Remove protocol
            if '://' in url:
                url = url.split('://', 1)[1]
            # Remove path
            if '/' in url:
                url = url.split('/', 1)[0]
            # Remove port
            if ':' in url:
                url = url.split(':', 1)[0]
            return url.lower()
        except Exception:
            return None
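
    # Example behaviour (derived directly from the logic above):
    #     _extract_domain_from_url("https://sub.example.com:8080/path")
    #     -> "sub.example.com"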
    def _is_valid_domain(self, domain: str) -> bool:
        """
        Basic domain validation.

        Args:
            domain: Domain string to validate

        Returns:
            True if domain appears valid
        """
        if not domain or len(domain) > 253:
            return False

        # Check for valid characters and structure
        parts = domain.split('.')
        if len(parts) < 2:
            return False

        for part in parts:
            if not part or len(part) > 63:
                return False
            if not part.replace('-', '').replace('_', '').isalnum():
                return False

        return True
    def _is_valid_ip(self, ip: str) -> bool:
        """
        Basic IPv4 address validation.

        Args:
            ip: IP address string to validate

        Returns:
            True if IP appears valid
        """
        try:
            parts = ip.split('.')
            if len(parts) != 4:
                return False
            for part in parts:
                # isdigit() rejects signs and whitespace that int() would
                # otherwise silently accept (e.g. "+1" or " 1").
                if not part.isdigit():
                    return False
                if not 0 <= int(part) <= 255:
                    return False
            return True
        except (ValueError, AttributeError):
            return False
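

if __name__ == "__main__":
    # Illustrative smoke-test sketch (an editorial addition, not part of the
    # original module; it assumes the dnsrecon package layout is on sys.path
    # so the core.* imports above resolve). _NullProvider is a hypothetical
    # do-nothing subclass that satisfies the abstract interface so the
    # concrete helpers can be exercised from the command line.
    class _NullProvider(BaseProvider):
        def get_name(self) -> str:
            return self.name

        def is_available(self) -> bool:
            return True

        def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
            return []  # a real provider would return relationship tuples here

        def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
            return []

    provider = _NullProvider(name="null", rate_limit=60)
    print(provider._is_valid_domain("example.com"))    # True
    print(provider._is_valid_domain("localhost"))      # False (single label)
    print(provider._is_valid_ip("192.0.2.1"))          # True
    print(provider._is_valid_ip("999.0.2.1"))          # False (octet > 255)
    print(provider._extract_domain_from_url("https://sub.example.com:8080/x"))
    print(provider.get_statistics())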