326 lines
9.7 KiB
Python
326 lines
9.7 KiB
Python
"""
|
|
Abstract base provider class for DNSRecon data sources.
|
|
Defines the interface and common functionality for all providers.
|
|
"""
|
|
|
|
import time
|
|
import requests
|
|
import threading
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
from datetime import datetime
|
|
|
|
from core.logger import get_forensic_logger
|
|
from core.graph_manager import NodeType, RelationshipType
|
|
|
|
|
|
class RateLimiter:
    """Thread-safe minimum-interval rate limiter for API calls.

    Consecutive requests are spaced at least ``60 / requests_per_minute``
    seconds apart. Each caller reserves the next free time slot while holding
    a lock and then sleeps *outside* the lock. Concurrent threads therefore
    queue up one interval behind each other instead of all reading the same
    stale timestamp and firing simultaneously.
    """

    def __init__(self, requests_per_minute: int):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute; must be
                positive.

        Raises:
            ValueError: If ``requests_per_minute`` is zero or negative.
        """
        if requests_per_minute <= 0:
            raise ValueError(
                f"requests_per_minute must be positive, got {requests_per_minute!r}"
            )
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        # Monotonic timestamp of the most recently reserved request slot.
        # -inf means "no request yet", so the very first call never waits.
        self.last_request_time = float("-inf")
        self._lock = threading.Lock()

    def wait_if_needed(self) -> None:
        """Block until the caller may issue its next request."""
        with self._lock:
            now = time.monotonic()
            # Next permitted slot: immediately, or one interval after the
            # previously reserved slot, whichever is later.
            slot = max(now, self.last_request_time + self.min_interval)
            self.last_request_time = slot
        # Sleep without holding the lock so other threads can reserve the
        # slots after ours while we wait.
        delay = slot - now
        if delay > 0:
            time.sleep(delay)
|
|
|
|
|
|
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.

    Provides common functionality and defines the provider interface:
    rate-limited HTTP requests with forensic logging, one ``requests``
    session per thread, request/relationship statistics, and small
    domain/IP helper checks.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30):
        """
        Initialize base provider.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit
            timeout: Request timeout in seconds
        """
        self.name = name
        self.rate_limiter = RateLimiter(rate_limit)
        self.timeout = timeout
        # Holds one requests.Session per thread (see the ``session`` property),
        # presumably because a Session is not guaranteed to be thread-safe.
        self._local = threading.local()
        self.logger = get_forensic_logger()

        # Statistics. Per-thread sessions imply providers are called from
        # several threads, and ``+=`` on an attribute is not atomic, so all
        # counter updates/reads go through this lock.
        self._stats_lock = threading.Lock()
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

    @property
    def session(self) -> requests.Session:
        """Return the calling thread's Session, creating it on first use."""
        session = getattr(self._local, 'session', None)
        if session is None:
            session = requests.Session()
            session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
            self._local.session = session
        return session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "") -> Optional[requests.Response]:
        """
        Make a rate-limited HTTP request with forensic logging.

        Every attempt, successful or not, is passed to
        ``self.logger.log_api_request`` together with the status code and body
        size (whenever a response was received), the duration and any error.

        Args:
            url: Request URL
            method: HTTP method, "GET" or "POST" (case-insensitive)
            params: Query parameters for GET; sent as the JSON body for POST
            headers: Additional headers, layered over the session defaults
            target_indicator: The indicator being investigated

        Returns:
            Response object, or None if the request failed for any reason
            (network error, timeout, non-2xx status, unsupported method).
        """
        # Apply rate limiting
        self.rate_limiter.wait_if_needed()

        http_method = method.upper()
        start = time.perf_counter()
        response = None
        error = None

        with self._stats_lock:
            self.total_requests += 1

        try:
            print(f"Making {method} request to: {url}")

            # requests merges the session's default headers with ``headers``
            # (per-request values win), so no manual copy is needed.
            if http_method == "GET":
                response = self.session.get(
                    url,
                    params=params,
                    headers=headers,
                    timeout=self.timeout
                )
            elif http_method == "POST":
                response = self.session.post(
                    url,
                    json=params,
                    headers=headers,
                    timeout=self.timeout
                )
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            print(f"Response status: {response.status_code}")
            response.raise_for_status()

        except requests.exceptions.RequestException as e:
            error = str(e)
            print(f"Request failed: {error}")

        except Exception as e:
            error = f"Unexpected error: {str(e)}"
            print(error)

        with self._stats_lock:
            if error is None:
                self.successful_requests += 1
            else:
                self.failed_requests += 1

        duration_ms = (time.perf_counter() - start) * 1000

        # Test against None explicitly: ``requests.Response.__bool__`` returns
        # ``response.ok``, so a truthiness check would drop the status code and
        # size of every 4xx/5xx response from the forensic log.
        has_response = response is not None
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=http_method,
            status_code=response.status_code if has_response else None,
            response_size=len(response.content) if has_response else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        return response if error is None else None

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: RelationshipType,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Increments the relationship counter and forwards the details, tagged
        with this provider's name, to the forensic logger.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        with self._stats_lock:
            self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type.relationship_name,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Counters are read together under the stats lock so the reported
        numbers (and the derived success rate) are mutually consistent.

        Returns:
            Dictionary containing provider performance metrics
        """
        with self._stats_lock:
            total = self.total_requests
            successful = self.successful_requests
            failed = self.failed_requests
            relationships = self.total_relationships_found

        return {
            'name': self.name,
            'total_requests': total,
            'successful_requests': successful,
            'failed_requests': failed,
            'success_rate': (successful / total * 100) if total > 0 else 0,
            'relationships_found': relationships,
            'rate_limit': self.rate_limiter.requests_per_minute
        }

    def reset_statistics(self) -> None:
        """Reset provider statistics."""
        with self._stats_lock:
            self.total_requests = 0
            self.successful_requests = 0
            self.failed_requests = 0
            self.total_relationships_found = 0

    def _extract_domain_from_url(self, url: str) -> Optional[str]:
        """
        Extract the host name from a URL.

        Accepts full URLs (``https://user@Example.com:8443/p?q#f``) and
        scheme-less ones (``example.com/path``). Userinfo, port, path, query
        and fragment are discarded, and the host is lower-cased. IPv6
        literals are returned without their brackets.

        Args:
            url: URL string

        Returns:
            Domain name or None if extraction fails
        """
        from urllib.parse import urlsplit

        try:
            # urlsplit only treats text after "//" as the network location;
            # without it "host/path" would be parsed entirely as a path.
            if '://' not in url:
                url = '//' + url
            return urlsplit(url).hostname
        except (TypeError, ValueError):
            # TypeError: non-str input; ValueError: malformed URL such as an
            # unbalanced IPv6 bracket.
            return None

    def _is_valid_domain(self, domain: str) -> bool:
        """
        Basic domain validation.

        Requires a total length of 1-253 characters, at least two
        dot-separated labels, and every label to be 1-63 characters made of
        alphanumerics, hyphens or underscores.

        NOTE(review): ``str.isalnum`` is Unicode-aware, so non-ASCII letters
        and digits (e.g. Unicode-form IDNs) are accepted, and labels may
        start/end with a hyphen. Confirm this leniency is intended.

        Args:
            domain: Domain string to validate

        Returns:
            True if domain appears valid
        """
        if not domain or len(domain) > 253:
            return False

        labels = domain.split('.')
        if len(labels) < 2:
            return False

        for label in labels:
            if not label or len(label) > 63:
                return False
            if not label.replace('-', '').replace('_', '').isalnum():
                return False

        return True

    def _is_valid_ip(self, ip: str) -> bool:
        """
        Basic IPv4 address validation (dotted-quad form only).

        Each of the four octets must be a non-empty run of ASCII digits with a
        value of 0-255. ``int()`` alone is too lenient for this: it also
        accepts surrounding whitespace (" 1"), signs ("+1", "-0"),
        digit-group underscores ("1_0") and non-ASCII Unicode digits. Leading
        zeros ("010") are still accepted, and IPv6 is not recognised.

        Args:
            ip: IP address string to validate

        Returns:
            True if IP appears valid
        """
        if not isinstance(ip, str):
            return False

        octets = ip.split('.')
        if len(octets) != 4:
            return False

        for octet in octets:
            if not octet or any(ch not in '0123456789' for ch in octet):
                return False
            if int(octet) > 255:
                return False

        return True