full re implementation

This commit is contained in:
overcuriousity
2025-09-10 13:53:32 +02:00
parent 29e36e34be
commit 696cec0723
32 changed files with 4731 additions and 7955 deletions

15
providers/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
"""
Data provider modules for DNSRecon.
Contains implementations for various reconnaissance data sources.
"""
from .base_provider import BaseProvider, RateLimiter
from .crtsh_provider import CrtShProvider
__all__ = [
'BaseProvider',
'RateLimiter',
'CrtShProvider'
]
__version__ = "1.0.0-phase1"

326
providers/base_provider.py Normal file
View File

@@ -0,0 +1,326 @@
"""
Abstract base provider class for DNSRecon data sources.
Defines the interface and common functionality for all providers.
"""
import time
import requests
import threading
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from core.logger import get_forensic_logger
from core.graph_manager import NodeType, RelationshipType
class RateLimiter:
"""Simple rate limiter for API calls."""
def __init__(self, requests_per_minute: int):
"""
Initialize rate limiter.
Args:
requests_per_minute: Maximum requests allowed per minute
"""
self.requests_per_minute = requests_per_minute
self.min_interval = 60.0 / requests_per_minute
self.last_request_time = 0
def wait_if_needed(self) -> None:
"""Wait if necessary to respect rate limits."""
current_time = time.time()
time_since_last = current_time - self.last_request_time
if time_since_last < self.min_interval:
sleep_time = self.min_interval - time_since_last
time.sleep(sleep_time)
self.last_request_time = time.time()
class BaseProvider(ABC):
"""
Abstract base class for all DNSRecon data providers.
Provides common functionality and defines the provider interface.
"""
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30):
"""
Initialize base provider.
Args:
name: Provider name for logging
rate_limit: Requests per minute limit
timeout: Request timeout in seconds
"""
self.name = name
self.rate_limiter = RateLimiter(rate_limit)
self.timeout = timeout
self._local = threading.local()
self.logger = get_forensic_logger()
# Statistics
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.total_relationships_found = 0
@property
def session(self):
if not hasattr(self._local, 'session'):
self._local.session = requests.Session()
self._local.session.headers.update({
'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
})
return self._local.session
@abstractmethod
def get_name(self) -> str:
"""Return the provider name."""
pass
@abstractmethod
def is_available(self) -> bool:
"""Check if the provider is available and properly configured."""
pass
@abstractmethod
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query the provider for information about a domain.
Args:
domain: Domain to investigate
Returns:
List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
"""
pass
@abstractmethod
def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query the provider for information about an IP address.
Args:
ip: IP address to investigate
Returns:
List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
"""
pass
def make_request(self, url: str, method: str = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
target_indicator: str = "") -> Optional[requests.Response]:
"""
Make a rate-limited HTTP request with forensic logging.
Args:
url: Request URL
method: HTTP method
params: Query parameters
headers: Additional headers
target_indicator: The indicator being investigated
Returns:
Response object or None if request failed
"""
# Apply rate limiting
self.rate_limiter.wait_if_needed()
start_time = time.time()
response = None
error = None
try:
self.total_requests += 1
# Prepare request
request_headers = self.session.headers.copy()
if headers:
request_headers.update(headers)
print(f"Making {method} request to: {url}")
# Make request
if method.upper() == "GET":
response = self.session.get(
url,
params=params,
headers=request_headers,
timeout=self.timeout
)
elif method.upper() == "POST":
response = self.session.post(
url,
json=params,
headers=request_headers,
timeout=self.timeout
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
print(f"Response status: {response.status_code}")
response.raise_for_status()
self.successful_requests += 1
except requests.exceptions.RequestException as e:
error = str(e)
self.failed_requests += 1
print(f"Request failed: {error}")
except Exception as e:
error = f"Unexpected error: {str(e)}"
self.failed_requests += 1
print(f"Unexpected error: {error}")
# Calculate duration and log request
duration_ms = (time.time() - start_time) * 1000
self.logger.log_api_request(
provider=self.name,
url=url,
method=method.upper(),
status_code=response.status_code if response else None,
response_size=len(response.content) if response else None,
duration_ms=duration_ms,
error=error,
target_indicator=target_indicator
)
return response if error is None else None
def log_relationship_discovery(self, source_node: str, target_node: str,
relationship_type: RelationshipType,
confidence_score: float,
raw_data: Dict[str, Any],
discovery_method: str) -> None:
"""
Log discovery of a new relationship.
Args:
source_node: Source node identifier
target_node: Target node identifier
relationship_type: Type of relationship
confidence_score: Confidence score
raw_data: Raw data from provider
discovery_method: Method used for discovery
"""
self.total_relationships_found += 1
self.logger.log_relationship_discovery(
source_node=source_node,
target_node=target_node,
relationship_type=relationship_type.relationship_name,
confidence_score=confidence_score,
provider=self.name,
raw_data=raw_data,
discovery_method=discovery_method
)
def get_statistics(self) -> Dict[str, Any]:
"""
Get provider statistics.
Returns:
Dictionary containing provider performance metrics
"""
return {
'name': self.name,
'total_requests': self.total_requests,
'successful_requests': self.successful_requests,
'failed_requests': self.failed_requests,
'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
'relationships_found': self.total_relationships_found,
'rate_limit': self.rate_limiter.requests_per_minute
}
def reset_statistics(self) -> None:
"""Reset provider statistics."""
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.total_relationships_found = 0
def _extract_domain_from_url(self, url: str) -> Optional[str]:
"""
Extract domain from URL.
Args:
url: URL string
Returns:
Domain name or None if extraction fails
"""
try:
# Remove protocol
if '://' in url:
url = url.split('://', 1)[1]
# Remove path
if '/' in url:
url = url.split('/', 1)[0]
# Remove port
if ':' in url:
url = url.split(':', 1)[0]
return url.lower()
except Exception:
return None
def _is_valid_domain(self, domain: str) -> bool:
"""
Basic domain validation.
Args:
domain: Domain string to validate
Returns:
True if domain appears valid
"""
if not domain or len(domain) > 253:
return False
# Check for valid characters and structure
parts = domain.split('.')
if len(parts) < 2:
return False
for part in parts:
if not part or len(part) > 63:
return False
if not part.replace('-', '').replace('_', '').isalnum():
return False
return True
def _is_valid_ip(self, ip: str) -> bool:
"""
Basic IP address validation.
Args:
ip: IP address string to validate
Returns:
True if IP appears valid
"""
try:
parts = ip.split('.')
if len(parts) != 4:
return False
for part in parts:
num = int(part)
if not 0 <= num <= 255:
return False
return True
except (ValueError, AttributeError):
return False

272
providers/crtsh_provider.py Normal file
View File

@@ -0,0 +1,272 @@
"""
Certificate Transparency provider using crt.sh.
Discovers domain relationships through certificate SAN analysis.
"""
import json
import re
from typing import List, Dict, Any, Tuple, Set
from urllib.parse import quote
from .base_provider import BaseProvider
from core.graph_manager import RelationshipType
class CrtShProvider(BaseProvider):
"""
Provider for querying crt.sh certificate transparency database.
Discovers domain relationships through certificate Subject Alternative Names (SANs).
"""
def __init__(self):
"""Initialize CrtSh provider with appropriate rate limiting."""
super().__init__(
name="crtsh",
rate_limit=60, # Be respectful to the free service
timeout=30
)
self.base_url = "https://crt.sh/"
def get_name(self) -> str:
"""Return the provider name."""
return "crtsh"
def is_available(self) -> bool:
"""
Check if the provider is configured to be used.
This method is intentionally simple and does not perform a network request
to avoid blocking application startup.
"""
return True
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query crt.sh for certificates containing the domain.
Args:
domain: Domain to investigate
Returns:
List of relationships discovered from certificate analysis
"""
if not self._is_valid_domain(domain):
return []
relationships = []
try:
# Query crt.sh for certificates
url = f"{self.base_url}?q={quote(domain)}&output=json"
response = self.make_request(url, target_indicator=domain)
if not response or response.status_code != 200:
return []
certificates = response.json()
if not certificates:
return []
# Process certificates to extract relationships
seen_certificates = set()
for cert_data in certificates:
cert_id = cert_data.get('id')
if not cert_id or cert_id in seen_certificates:
continue
seen_certificates.add(cert_id)
# Extract domains from certificate
cert_domains = self._extract_domains_from_certificate(cert_data)
if domain in cert_domains and len(cert_domains) > 1:
# Create relationships between domains found in the same certificate
for related_domain in cert_domains:
if related_domain != domain and self._is_valid_domain(related_domain):
# Create SAN relationship
raw_data = {
'certificate_id': cert_id,
'issuer': cert_data.get('issuer_name', ''),
'not_before': cert_data.get('not_before', ''),
'not_after': cert_data.get('not_after', ''),
'serial_number': cert_data.get('serial_number', ''),
'all_domains': list(cert_domains)
}
relationships.append((
domain,
related_domain,
RelationshipType.SAN_CERTIFICATE,
RelationshipType.SAN_CERTIFICATE.default_confidence,
raw_data
))
# Log the discovery
self.log_relationship_discovery(
source_node=domain,
target_node=related_domain,
relationship_type=RelationshipType.SAN_CERTIFICATE,
confidence_score=RelationshipType.SAN_CERTIFICATE.default_confidence,
raw_data=raw_data,
discovery_method="certificate_san_analysis"
)
except json.JSONDecodeError as e:
self.logger.logger.error(f"Failed to parse JSON response from crt.sh: {e}")
except Exception as e:
self.logger.logger.error(f"Error querying crt.sh for {domain}: {e}")
return relationships
def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query crt.sh for certificates containing the IP address.
Note: crt.sh doesn't typically index by IP, so this returns empty results.
Args:
ip: IP address to investigate
Returns:
Empty list (crt.sh doesn't support IP-based certificate queries effectively)
"""
# crt.sh doesn't effectively support IP-based certificate queries
# This would require parsing certificate details for IP SANs, which is complex
return []
def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]:
"""
Extract all domains from certificate data.
Args:
cert_data: Certificate data from crt.sh API
Returns:
Set of unique domain names found in the certificate
"""
domains = set()
# Extract from common name
common_name = cert_data.get('common_name', '')
if common_name:
cleaned_cn = self._clean_domain_name(common_name)
if cleaned_cn and self._is_valid_domain(cleaned_cn):
domains.add(cleaned_cn)
# Extract from name_value field (contains SANs)
name_value = cert_data.get('name_value', '')
if name_value:
# Split by newlines and clean each domain
for line in name_value.split('\n'):
cleaned_domain = self._clean_domain_name(line.strip())
if cleaned_domain and self._is_valid_domain(cleaned_domain):
domains.add(cleaned_domain)
return domains
def _clean_domain_name(self, domain_name: str) -> str:
"""
Clean and normalize domain name from certificate data.
Args:
domain_name: Raw domain name from certificate
Returns:
Cleaned domain name or empty string if invalid
"""
if not domain_name:
return ""
# Remove common prefixes and clean up
domain = domain_name.strip().lower()
# Remove protocol if present
if domain.startswith(('http://', 'https://')):
domain = domain.split('://', 1)[1]
# Remove path if present
if '/' in domain:
domain = domain.split('/', 1)[0]
# Remove port if present
if ':' in domain and not domain.count(':') > 1: # Avoid breaking IPv6
domain = domain.split(':', 1)[0]
# Handle wildcard domains
if domain.startswith('*.'):
domain = domain[2:]
# Remove any remaining invalid characters
domain = re.sub(r'[^\w\-\.]', '', domain)
# Ensure it's not empty and doesn't start/end with dots or hyphens
if domain and not domain.startswith(('.', '-')) and not domain.endswith(('.', '-')):
return domain
return ""
def get_certificate_details(self, certificate_id: str) -> Dict[str, Any]:
"""
Get detailed information about a specific certificate.
Args:
certificate_id: Certificate ID from crt.sh
Returns:
Dictionary containing certificate details
"""
try:
url = f"{self.base_url}?id={certificate_id}&output=json"
response = self.make_request(url, target_indicator=f"cert_{certificate_id}")
if response and response.status_code == 200:
return response.json()
except Exception as e:
self.logger.logger.error(f"Error fetching certificate details for {certificate_id}: {e}")
return {}
def search_certificates_by_serial(self, serial_number: str) -> List[Dict[str, Any]]:
"""
Search for certificates by serial number.
Args:
serial_number: Certificate serial number
Returns:
List of matching certificates
"""
try:
url = f"{self.base_url}?serial={quote(serial_number)}&output=json"
response = self.make_request(url, target_indicator=f"serial_{serial_number}")
if response and response.status_code == 200:
return response.json()
except Exception as e:
self.logger.logger.error(f"Error searching certificates by serial {serial_number}: {e}")
return []
def get_issuer_certificates(self, issuer_name: str) -> List[Dict[str, Any]]:
"""
Get certificates issued by a specific CA.
Args:
issuer_name: Certificate Authority name
Returns:
List of certificates from the specified issuer
"""
try:
url = f"{self.base_url}?issuer={quote(issuer_name)}&output=json"
response = self.make_request(url, target_indicator=f"issuer_{issuer_name}")
if response and response.status_code == 200:
return response.json()
except Exception as e:
self.logger.logger.error(f"Error fetching certificates for issuer {issuer_name}: {e}")
return []

View File