full re implementation
This commit is contained in:
15
providers/__init__.py
Normal file
15
providers/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Data provider modules for DNSRecon.
|
||||
Contains implementations for various reconnaissance data sources.
|
||||
"""
|
||||
|
||||
from .base_provider import BaseProvider, RateLimiter
|
||||
from .crtsh_provider import CrtShProvider
|
||||
|
||||
__all__ = [
|
||||
'BaseProvider',
|
||||
'RateLimiter',
|
||||
'CrtShProvider'
|
||||
]
|
||||
|
||||
__version__ = "1.0.0-phase1"
|
||||
326
providers/base_provider.py
Normal file
326
providers/base_provider.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
Abstract base provider class for DNSRecon data sources.
|
||||
Defines the interface and common functionality for all providers.
|
||||
"""
|
||||
|
||||
import time
|
||||
import requests
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from core.logger import get_forensic_logger
|
||||
from core.graph_manager import NodeType, RelationshipType
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Simple rate limiter for API calls."""
|
||||
|
||||
def __init__(self, requests_per_minute: int):
|
||||
"""
|
||||
Initialize rate limiter.
|
||||
|
||||
Args:
|
||||
requests_per_minute: Maximum requests allowed per minute
|
||||
"""
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.min_interval = 60.0 / requests_per_minute
|
||||
self.last_request_time = 0
|
||||
|
||||
def wait_if_needed(self) -> None:
|
||||
"""Wait if necessary to respect rate limits."""
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - self.last_request_time
|
||||
|
||||
if time_since_last < self.min_interval:
|
||||
sleep_time = self.min_interval - time_since_last
|
||||
time.sleep(sleep_time)
|
||||
|
||||
self.last_request_time = time.time()
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""
|
||||
Abstract base class for all DNSRecon data providers.
|
||||
Provides common functionality and defines the provider interface.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30):
|
||||
"""
|
||||
Initialize base provider.
|
||||
|
||||
Args:
|
||||
name: Provider name for logging
|
||||
rate_limit: Requests per minute limit
|
||||
timeout: Request timeout in seconds
|
||||
"""
|
||||
self.name = name
|
||||
self.rate_limiter = RateLimiter(rate_limit)
|
||||
self.timeout = timeout
|
||||
self._local = threading.local()
|
||||
self.logger = get_forensic_logger()
|
||||
|
||||
# Statistics
|
||||
self.total_requests = 0
|
||||
self.successful_requests = 0
|
||||
self.failed_requests = 0
|
||||
self.total_relationships_found = 0
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
if not hasattr(self._local, 'session'):
|
||||
self._local.session = requests.Session()
|
||||
self._local.session.headers.update({
|
||||
'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
|
||||
})
|
||||
return self._local.session
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""Check if the provider is available and properly configured."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
Query the provider for information about a domain.
|
||||
|
||||
Args:
|
||||
domain: Domain to investigate
|
||||
|
||||
Returns:
|
||||
List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
Query the provider for information about an IP address.
|
||||
|
||||
Args:
|
||||
ip: IP address to investigate
|
||||
|
||||
Returns:
|
||||
List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
|
||||
"""
|
||||
pass
|
||||
|
||||
def make_request(self, url: str, method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
target_indicator: str = "") -> Optional[requests.Response]:
|
||||
"""
|
||||
Make a rate-limited HTTP request with forensic logging.
|
||||
|
||||
Args:
|
||||
url: Request URL
|
||||
method: HTTP method
|
||||
params: Query parameters
|
||||
headers: Additional headers
|
||||
target_indicator: The indicator being investigated
|
||||
|
||||
Returns:
|
||||
Response object or None if request failed
|
||||
"""
|
||||
# Apply rate limiting
|
||||
self.rate_limiter.wait_if_needed()
|
||||
|
||||
start_time = time.time()
|
||||
response = None
|
||||
error = None
|
||||
|
||||
try:
|
||||
self.total_requests += 1
|
||||
|
||||
# Prepare request
|
||||
request_headers = self.session.headers.copy()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
print(f"Making {method} request to: {url}")
|
||||
|
||||
# Make request
|
||||
if method.upper() == "GET":
|
||||
response = self.session.get(
|
||||
url,
|
||||
params=params,
|
||||
headers=request_headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = self.session.post(
|
||||
url,
|
||||
json=params,
|
||||
headers=request_headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
print(f"Response status: {response.status_code}")
|
||||
response.raise_for_status()
|
||||
self.successful_requests += 1
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error = str(e)
|
||||
self.failed_requests += 1
|
||||
print(f"Request failed: {error}")
|
||||
|
||||
except Exception as e:
|
||||
error = f"Unexpected error: {str(e)}"
|
||||
self.failed_requests += 1
|
||||
print(f"Unexpected error: {error}")
|
||||
|
||||
# Calculate duration and log request
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
self.logger.log_api_request(
|
||||
provider=self.name,
|
||||
url=url,
|
||||
method=method.upper(),
|
||||
status_code=response.status_code if response else None,
|
||||
response_size=len(response.content) if response else None,
|
||||
duration_ms=duration_ms,
|
||||
error=error,
|
||||
target_indicator=target_indicator
|
||||
)
|
||||
|
||||
return response if error is None else None
|
||||
|
||||
def log_relationship_discovery(self, source_node: str, target_node: str,
|
||||
relationship_type: RelationshipType,
|
||||
confidence_score: float,
|
||||
raw_data: Dict[str, Any],
|
||||
discovery_method: str) -> None:
|
||||
"""
|
||||
Log discovery of a new relationship.
|
||||
|
||||
Args:
|
||||
source_node: Source node identifier
|
||||
target_node: Target node identifier
|
||||
relationship_type: Type of relationship
|
||||
confidence_score: Confidence score
|
||||
raw_data: Raw data from provider
|
||||
discovery_method: Method used for discovery
|
||||
"""
|
||||
self.total_relationships_found += 1
|
||||
|
||||
self.logger.log_relationship_discovery(
|
||||
source_node=source_node,
|
||||
target_node=target_node,
|
||||
relationship_type=relationship_type.relationship_name,
|
||||
confidence_score=confidence_score,
|
||||
provider=self.name,
|
||||
raw_data=raw_data,
|
||||
discovery_method=discovery_method
|
||||
)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get provider statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing provider performance metrics
|
||||
"""
|
||||
return {
|
||||
'name': self.name,
|
||||
'total_requests': self.total_requests,
|
||||
'successful_requests': self.successful_requests,
|
||||
'failed_requests': self.failed_requests,
|
||||
'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
|
||||
'relationships_found': self.total_relationships_found,
|
||||
'rate_limit': self.rate_limiter.requests_per_minute
|
||||
}
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""Reset provider statistics."""
|
||||
self.total_requests = 0
|
||||
self.successful_requests = 0
|
||||
self.failed_requests = 0
|
||||
self.total_relationships_found = 0
|
||||
|
||||
def _extract_domain_from_url(self, url: str) -> Optional[str]:
|
||||
"""
|
||||
Extract domain from URL.
|
||||
|
||||
Args:
|
||||
url: URL string
|
||||
|
||||
Returns:
|
||||
Domain name or None if extraction fails
|
||||
"""
|
||||
try:
|
||||
# Remove protocol
|
||||
if '://' in url:
|
||||
url = url.split('://', 1)[1]
|
||||
|
||||
# Remove path
|
||||
if '/' in url:
|
||||
url = url.split('/', 1)[0]
|
||||
|
||||
# Remove port
|
||||
if ':' in url:
|
||||
url = url.split(':', 1)[0]
|
||||
|
||||
return url.lower()
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _is_valid_domain(self, domain: str) -> bool:
|
||||
"""
|
||||
Basic domain validation.
|
||||
|
||||
Args:
|
||||
domain: Domain string to validate
|
||||
|
||||
Returns:
|
||||
True if domain appears valid
|
||||
"""
|
||||
if not domain or len(domain) > 253:
|
||||
return False
|
||||
|
||||
# Check for valid characters and structure
|
||||
parts = domain.split('.')
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
|
||||
for part in parts:
|
||||
if not part or len(part) > 63:
|
||||
return False
|
||||
if not part.replace('-', '').replace('_', '').isalnum():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_valid_ip(self, ip: str) -> bool:
|
||||
"""
|
||||
Basic IP address validation.
|
||||
|
||||
Args:
|
||||
ip: IP address string to validate
|
||||
|
||||
Returns:
|
||||
True if IP appears valid
|
||||
"""
|
||||
try:
|
||||
parts = ip.split('.')
|
||||
if len(parts) != 4:
|
||||
return False
|
||||
|
||||
for part in parts:
|
||||
num = int(part)
|
||||
if not 0 <= num <= 255:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
272
providers/crtsh_provider.py
Normal file
272
providers/crtsh_provider.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Certificate Transparency provider using crt.sh.
|
||||
Discovers domain relationships through certificate SAN analysis.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Dict, Any, Tuple, Set
|
||||
from urllib.parse import quote
|
||||
|
||||
from .base_provider import BaseProvider
|
||||
from core.graph_manager import RelationshipType
|
||||
|
||||
|
||||
class CrtShProvider(BaseProvider):
|
||||
"""
|
||||
Provider for querying crt.sh certificate transparency database.
|
||||
Discovers domain relationships through certificate Subject Alternative Names (SANs).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CrtSh provider with appropriate rate limiting."""
|
||||
super().__init__(
|
||||
name="crtsh",
|
||||
rate_limit=60, # Be respectful to the free service
|
||||
timeout=30
|
||||
)
|
||||
self.base_url = "https://crt.sh/"
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "crtsh"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Check if the provider is configured to be used.
|
||||
This method is intentionally simple and does not perform a network request
|
||||
to avoid blocking application startup.
|
||||
"""
|
||||
return True
|
||||
|
||||
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
Query crt.sh for certificates containing the domain.
|
||||
|
||||
Args:
|
||||
domain: Domain to investigate
|
||||
|
||||
Returns:
|
||||
List of relationships discovered from certificate analysis
|
||||
"""
|
||||
if not self._is_valid_domain(domain):
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
|
||||
try:
|
||||
# Query crt.sh for certificates
|
||||
url = f"{self.base_url}?q={quote(domain)}&output=json"
|
||||
response = self.make_request(url, target_indicator=domain)
|
||||
|
||||
if not response or response.status_code != 200:
|
||||
return []
|
||||
|
||||
certificates = response.json()
|
||||
|
||||
if not certificates:
|
||||
return []
|
||||
|
||||
# Process certificates to extract relationships
|
||||
seen_certificates = set()
|
||||
|
||||
for cert_data in certificates:
|
||||
cert_id = cert_data.get('id')
|
||||
if not cert_id or cert_id in seen_certificates:
|
||||
continue
|
||||
|
||||
seen_certificates.add(cert_id)
|
||||
|
||||
# Extract domains from certificate
|
||||
cert_domains = self._extract_domains_from_certificate(cert_data)
|
||||
|
||||
if domain in cert_domains and len(cert_domains) > 1:
|
||||
# Create relationships between domains found in the same certificate
|
||||
for related_domain in cert_domains:
|
||||
if related_domain != domain and self._is_valid_domain(related_domain):
|
||||
# Create SAN relationship
|
||||
raw_data = {
|
||||
'certificate_id': cert_id,
|
||||
'issuer': cert_data.get('issuer_name', ''),
|
||||
'not_before': cert_data.get('not_before', ''),
|
||||
'not_after': cert_data.get('not_after', ''),
|
||||
'serial_number': cert_data.get('serial_number', ''),
|
||||
'all_domains': list(cert_domains)
|
||||
}
|
||||
|
||||
relationships.append((
|
||||
domain,
|
||||
related_domain,
|
||||
RelationshipType.SAN_CERTIFICATE,
|
||||
RelationshipType.SAN_CERTIFICATE.default_confidence,
|
||||
raw_data
|
||||
))
|
||||
|
||||
# Log the discovery
|
||||
self.log_relationship_discovery(
|
||||
source_node=domain,
|
||||
target_node=related_domain,
|
||||
relationship_type=RelationshipType.SAN_CERTIFICATE,
|
||||
confidence_score=RelationshipType.SAN_CERTIFICATE.default_confidence,
|
||||
raw_data=raw_data,
|
||||
discovery_method="certificate_san_analysis"
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.logger.error(f"Failed to parse JSON response from crt.sh: {e}")
|
||||
except Exception as e:
|
||||
self.logger.logger.error(f"Error querying crt.sh for {domain}: {e}")
|
||||
|
||||
return relationships
|
||||
|
||||
def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
Query crt.sh for certificates containing the IP address.
|
||||
Note: crt.sh doesn't typically index by IP, so this returns empty results.
|
||||
|
||||
Args:
|
||||
ip: IP address to investigate
|
||||
|
||||
Returns:
|
||||
Empty list (crt.sh doesn't support IP-based certificate queries effectively)
|
||||
"""
|
||||
# crt.sh doesn't effectively support IP-based certificate queries
|
||||
# This would require parsing certificate details for IP SANs, which is complex
|
||||
return []
|
||||
|
||||
def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]:
|
||||
"""
|
||||
Extract all domains from certificate data.
|
||||
|
||||
Args:
|
||||
cert_data: Certificate data from crt.sh API
|
||||
|
||||
Returns:
|
||||
Set of unique domain names found in the certificate
|
||||
"""
|
||||
domains = set()
|
||||
|
||||
# Extract from common name
|
||||
common_name = cert_data.get('common_name', '')
|
||||
if common_name:
|
||||
cleaned_cn = self._clean_domain_name(common_name)
|
||||
if cleaned_cn and self._is_valid_domain(cleaned_cn):
|
||||
domains.add(cleaned_cn)
|
||||
|
||||
# Extract from name_value field (contains SANs)
|
||||
name_value = cert_data.get('name_value', '')
|
||||
if name_value:
|
||||
# Split by newlines and clean each domain
|
||||
for line in name_value.split('\n'):
|
||||
cleaned_domain = self._clean_domain_name(line.strip())
|
||||
if cleaned_domain and self._is_valid_domain(cleaned_domain):
|
||||
domains.add(cleaned_domain)
|
||||
|
||||
return domains
|
||||
|
||||
def _clean_domain_name(self, domain_name: str) -> str:
|
||||
"""
|
||||
Clean and normalize domain name from certificate data.
|
||||
|
||||
Args:
|
||||
domain_name: Raw domain name from certificate
|
||||
|
||||
Returns:
|
||||
Cleaned domain name or empty string if invalid
|
||||
"""
|
||||
if not domain_name:
|
||||
return ""
|
||||
|
||||
# Remove common prefixes and clean up
|
||||
domain = domain_name.strip().lower()
|
||||
|
||||
# Remove protocol if present
|
||||
if domain.startswith(('http://', 'https://')):
|
||||
domain = domain.split('://', 1)[1]
|
||||
|
||||
# Remove path if present
|
||||
if '/' in domain:
|
||||
domain = domain.split('/', 1)[0]
|
||||
|
||||
# Remove port if present
|
||||
if ':' in domain and not domain.count(':') > 1: # Avoid breaking IPv6
|
||||
domain = domain.split(':', 1)[0]
|
||||
|
||||
# Handle wildcard domains
|
||||
if domain.startswith('*.'):
|
||||
domain = domain[2:]
|
||||
|
||||
# Remove any remaining invalid characters
|
||||
domain = re.sub(r'[^\w\-\.]', '', domain)
|
||||
|
||||
# Ensure it's not empty and doesn't start/end with dots or hyphens
|
||||
if domain and not domain.startswith(('.', '-')) and not domain.endswith(('.', '-')):
|
||||
return domain
|
||||
|
||||
return ""
|
||||
|
||||
def get_certificate_details(self, certificate_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get detailed information about a specific certificate.
|
||||
|
||||
Args:
|
||||
certificate_id: Certificate ID from crt.sh
|
||||
|
||||
Returns:
|
||||
Dictionary containing certificate details
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}?id={certificate_id}&output=json"
|
||||
response = self.make_request(url, target_indicator=f"cert_{certificate_id}")
|
||||
|
||||
if response and response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.error(f"Error fetching certificate details for {certificate_id}: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
def search_certificates_by_serial(self, serial_number: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for certificates by serial number.
|
||||
|
||||
Args:
|
||||
serial_number: Certificate serial number
|
||||
|
||||
Returns:
|
||||
List of matching certificates
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}?serial={quote(serial_number)}&output=json"
|
||||
response = self.make_request(url, target_indicator=f"serial_{serial_number}")
|
||||
|
||||
if response and response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.error(f"Error searching certificates by serial {serial_number}: {e}")
|
||||
|
||||
return []
|
||||
|
||||
def get_issuer_certificates(self, issuer_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get certificates issued by a specific CA.
|
||||
|
||||
Args:
|
||||
issuer_name: Certificate Authority name
|
||||
|
||||
Returns:
|
||||
List of certificates from the specified issuer
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}?issuer={quote(issuer_name)}&output=json"
|
||||
response = self.make_request(url, target_indicator=f"issuer_{issuer_name}")
|
||||
|
||||
if response and response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.error(f"Error fetching certificates for issuer {issuer_name}: {e}")
|
||||
|
||||
return []
|
||||
0
providers/dns_provider.py
Normal file
0
providers/dns_provider.py
Normal file
Reference in New Issue
Block a user