# dnsrecon/providers/base_provider.py

import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import redis
import requests

from core.logger import get_forensic_logger
from core.rate_limiter import GlobalRateLimiter
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.

    Supports session-specific configuration; when no session config is
    supplied it falls back to the global ``config`` module for backwards
    compatibility. Each worker thread gets its own ``requests.Session``
    via the thread-local ``session`` property.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging.
            rate_limit: Requests-per-minute limit (only used when no
                session_config is provided).
            timeout: Request timeout in seconds (only used when no
                session_config is provided).
            session_config: Session-specific configuration object; must
                expose ``get_rate_limit(name)`` and ``default_timeout``.
        """
        # Use session config if provided, otherwise fall back to global config.
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility.
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        # NOTE(review): actual_rate_limit is computed but never stored or
        # enforced in this class — presumably GlobalRateLimiter handles
        # throttling elsewhere; confirm and either store or drop it.

        self.name = name
        self.timeout = actual_timeout
        # Thread-local storage: one requests.Session per worker thread.
        self._local = threading.local()
        self.logger = get_forensic_logger()
        self._stop_event = None

        # Statistics (per provider instance).
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # threading.local and threading.Event cannot be pickled.
        for attr in ('_local', '_stop_event'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Recreate the attributes dropped in __getstate__.
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Lazily create and return this thread's requests.Session."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "") -> Optional[requests.Response]:
        """
        Make an HTTP request with forensic logging.

        NOTE(review): despite the original "rate-limited" description, no
        rate limiting is applied in this method — presumably handled
        elsewhere (GlobalRateLimiter); confirm.

        Args:
            url: Target URL.
            method: HTTP method, "GET" or "POST" (POST sends params as JSON).
            params: Query parameters (GET) or JSON body (POST).
            headers: Extra headers merged over the session defaults.
            target_indicator: Investigated entity, recorded in the audit log.

        Returns:
            The response on success, or None if a stop was requested before
            the request started.

        Raises:
            requests.exceptions.RequestException: re-raised after the failed
                request has been logged.
            ValueError: for an unsupported HTTP method.
        """
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        start_time = time.time()
        response = None
        error = None

        try:
            self.total_requests += 1

            # dict(...) already produces a copy; merge per-call headers on top.
            request_headers = dict(self.session.headers)
            if headers:
                request_headers.update(headers)

            print(f"Making {method} request to: {url}")

            if method.upper() == "GET":
                response = self.session.get(
                    url,
                    params=params,
                    headers=request_headers,
                    timeout=self.timeout
                )
            elif method.upper() == "POST":
                response = self.session.post(
                    url,
                    json=params,
                    headers=request_headers,
                    timeout=self.timeout
                )
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            print(f"Response status: {response.status_code}")
            response.raise_for_status()
            self.successful_requests += 1

            duration_ms = (time.time() - start_time) * 1000
            self.logger.log_api_request(
                provider=self.name,
                url=url,
                method=method.upper(),
                status_code=response.status_code,
                response_size=len(response.content),
                duration_ms=duration_ms,
                error=None,
                target_indicator=target_indicator
            )

            return response

        except requests.exceptions.RequestException as e:
            error = str(e)
            self.failed_requests += 1
            duration_ms = (time.time() - start_time) * 1000
            # BUG FIX: requests.Response.__bool__ is `self.ok`, so a 4xx/5xx
            # response is falsy — `if response` dropped the status code and
            # size exactly when raise_for_status() fired. Compare to None.
            self.logger.log_api_request(
                provider=self.name,
                url=url,
                method=method.upper(),
                status_code=response.status_code if response is not None else None,
                response_size=len(response.content) if response is not None else None,
                duration_ms=duration_ms,
                error=error,
                target_indicator=target_indicator
            )
            # Bare raise preserves the original traceback.
            raise

    def _is_stop_requested(self) -> bool:
        """
        Return True if cancellation has been signalled via the local stop event.

        NOTE(review): the original docstring mentioned Redis-based signals,
        but only the local threading.Event is checked here — confirm whether
        a Redis check was intended.
        """
        if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
            return True
        return False

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            # Guard against division by zero before any request was made.
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.config.get_rate_limit(self.name)
        }