This commit is contained in:
overcuriousity
2025-09-11 14:01:15 +02:00
parent 2d485c5703
commit d3e1fcf35f
18 changed files with 1806 additions and 843 deletions

View File

@@ -7,10 +7,9 @@ import os
import json
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from core.logger import get_forensic_logger
from core.graph_manager import NodeType, RelationshipType
from core.graph_manager import RelationshipType
class RateLimiter:
@@ -42,36 +41,52 @@ class RateLimiter:
class BaseProvider(ABC):
"""
Abstract base class for all DNSRecon data providers.
Provides common functionality and defines the provider interface.
Now supports session-specific configuration.
"""
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30):
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
"""
Initialize base provider.
Initialize base provider with session-specific configuration.
Args:
name: Provider name for logging
rate_limit: Requests per minute limit
rate_limit: Requests per minute limit (default override)
timeout: Request timeout in seconds
session_config: Session-specific configuration
"""
# Use session config if provided, otherwise fall back to global config
if session_config is not None:
self.config = session_config
actual_rate_limit = self.config.get_rate_limit(name)
actual_timeout = self.config.default_timeout
else:
# Fallback to global config for backwards compatibility
from config import config as global_config
self.config = global_config
actual_rate_limit = rate_limit
actual_timeout = timeout
self.name = name
self.rate_limiter = RateLimiter(rate_limit)
self.timeout = timeout
self.rate_limiter = RateLimiter(actual_rate_limit)
self.timeout = actual_timeout
self._local = threading.local()
self.logger = get_forensic_logger()
self._stop_event = None
# Caching configuration
self.cache_dir = '.cache'
# Caching configuration (per session)
self.cache_dir = f'.cache/{id(self.config)}' # Unique cache per session config
self.cache_expiry = 12 * 3600 # 12 hours in seconds
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
# Statistics
# Statistics (per provider instance)
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.total_relationships_found = 0
print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")
@property
def session(self):
if not hasattr(self._local, 'session'):
@@ -118,136 +133,174 @@ class BaseProvider(ABC):
pass
def make_request(self, url: str, method: str = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
target_indicator: str = "",
max_retries: int = 3) -> Optional[requests.Response]:
"""
Make a rate-limited HTTP request with forensic logging and retry logic.
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
target_indicator: str = "",
max_retries: int = 3) -> Optional[requests.Response]:
"""
Make a rate-limited HTTP request with forensic logging and retry logic.
Now supports cancellation via stop_event from scanner.
"""
# Check for cancellation before starting
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled before start: {url}")
return None
Args:
url: Request URL
method: HTTP method
params: Query parameters
headers: Additional headers
target_indicator: The indicator being investigated
max_retries: Maximum number of retry attempts
# Create a unique cache key
cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
cache_path = os.path.join(self.cache_dir, cache_key)
Returns:
Response object or None if request failed
"""
# Create a unique cache key
cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
cache_path = os.path.join(self.cache_dir, cache_key)
# Check cache
if os.path.exists(cache_path):
cache_age = time.time() - os.path.getmtime(cache_path)
if cache_age < self.cache_expiry:
print(f"Returning cached response for: {url}")
with open(cache_path, 'r') as f:
cached_data = json.load(f)
response = requests.Response()
response.status_code = cached_data['status_code']
response._content = cached_data['content'].encode('utf-8')
response.headers = cached_data['headers']
return response
for attempt in range(max_retries + 1):
# Apply rate limiting
self.rate_limiter.wait_if_needed()
start_time = time.time()
response = None
error = None
try:
self.total_requests += 1
# Prepare request
request_headers = self.session.headers.copy()
if headers:
request_headers.update(headers)
print(f"Making {method} request to: {url} (attempt {attempt + 1})")
# Make request
if method.upper() == "GET":
response = self.session.get(
url,
params=params,
headers=request_headers,
timeout=self.timeout
)
elif method.upper() == "POST":
response = self.session.post(
url,
json=params,
headers=request_headers,
timeout=self.timeout
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
print(f"Response status: {response.status_code}")
response.raise_for_status()
self.successful_requests += 1
# Success - log, cache, and return
duration_ms = (time.time() - start_time) * 1000
self.logger.log_api_request(
provider=self.name,
url=url,
method=method.upper(),
status_code=response.status_code,
response_size=len(response.content),
duration_ms=duration_ms,
error=None,
target_indicator=target_indicator
)
# Cache the successful response to disk
with open(cache_path, 'w') as f:
json.dump({
'status_code': response.status_code,
'content': response.text,
'headers': dict(response.headers)
}, f)
# Check cache
if os.path.exists(cache_path):
cache_age = time.time() - os.path.getmtime(cache_path)
if cache_age < self.cache_expiry:
print(f"Returning cached response for: {url}")
with open(cache_path, 'r') as f:
cached_data = json.load(f)
response = requests.Response()
response.status_code = cached_data['status_code']
response._content = cached_data['content'].encode('utf-8')
response.headers = cached_data['headers']
return response
except requests.exceptions.RequestException as e:
error = str(e)
self.failed_requests += 1
print(f"Request failed (attempt {attempt + 1}): {error}")
# Check if we should retry
if attempt < max_retries and self._should_retry(e):
backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s
print(f"Retrying in {backoff_time} seconds...")
time.sleep(backoff_time)
continue
else:
break
for attempt in range(max_retries + 1):
# Check for cancellation before each attempt
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled during attempt {attempt + 1}: {url}")
return None
except Exception as e:
error = f"Unexpected error: {str(e)}"
self.failed_requests += 1
print(f"Unexpected error: {error}")
# Apply rate limiting (but reduce wait time if cancellation is requested)
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
break
self.rate_limiter.wait_if_needed()
# Check again after rate limiting
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled after rate limiting: {url}")
return None
start_time = time.time()
response = None
error = None
try:
self.total_requests += 1
# Prepare request
request_headers = self.session.headers.copy()
if headers:
request_headers.update(headers)
print(f"Making {method} request to: {url} (attempt {attempt + 1})")
# Use shorter timeout if termination is requested
request_timeout = self.timeout
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
request_timeout = min(5, self.timeout) # Max 5 seconds if termination requested
# Make request
if method.upper() == "GET":
response = self.session.get(
url,
params=params,
headers=request_headers,
timeout=request_timeout
)
elif method.upper() == "POST":
response = self.session.post(
url,
json=params,
headers=request_headers,
timeout=request_timeout
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
print(f"Response status: {response.status_code}")
response.raise_for_status()
self.successful_requests += 1
# Success - log, cache, and return
duration_ms = (time.time() - start_time) * 1000
self.logger.log_api_request(
provider=self.name,
url=url,
method=method.upper(),
status_code=response.status_code,
response_size=len(response.content),
duration_ms=duration_ms,
error=None,
target_indicator=target_indicator
)
# Cache the successful response to disk
with open(cache_path, 'w') as f:
json.dump({
'status_code': response.status_code,
'content': response.text,
'headers': dict(response.headers)
}, f)
return response
except requests.exceptions.RequestException as e:
error = str(e)
self.failed_requests += 1
print(f"Request failed (attempt {attempt + 1}): {error}")
# Check for cancellation before retrying
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled, not retrying: {url}")
break
# Check if we should retry
if attempt < max_retries and self._should_retry(e):
backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s
print(f"Retrying in {backoff_time} seconds...")
# Shorter backoff if termination is requested
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
backoff_time = min(0.5, backoff_time)
# Sleep with cancellation checking
sleep_start = time.time()
while time.time() - sleep_start < backoff_time:
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
print(f"Request cancelled during backoff: {url}")
return None
time.sleep(0.1) # Check every 100ms
continue
else:
break
# All attempts failed - log and return None
duration_ms = (time.time() - start_time) * 1000
self.logger.log_api_request(
provider=self.name,
url=url,
method=method.upper(),
status_code=response.status_code if response else None,
response_size=len(response.content) if response else None,
duration_ms=duration_ms,
error=error,
target_indicator=target_indicator
)
return None
except Exception as e:
error = f"Unexpected error: {str(e)}"
self.failed_requests += 1
print(f"Unexpected error: {error}")
break
# All attempts failed - log and return None
duration_ms = (time.time() - start_time) * 1000
self.logger.log_api_request(
provider=self.name,
url=url,
method=method.upper(),
status_code=response.status_code if response else None,
response_size=len(response.content) if response else None,
duration_ms=duration_ms,
error=error,
target_indicator=target_indicator
)
return None
def set_stop_event(self, stop_event: threading.Event) -> None:
    """Attach a cancellation event so in-flight provider work can be aborted.

    Args:
        stop_event: Event that, once set, signals this provider to stop.
    """
    self._stop_event = stop_event
def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
"""
@@ -314,90 +367,4 @@ class BaseProvider(ABC):
'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
'relationships_found': self.total_relationships_found,
'rate_limit': self.rate_limiter.requests_per_minute
}
def reset_statistics(self) -> None:
    """Zero out all per-provider request and relationship counters."""
    for counter in (
        'total_requests',
        'successful_requests',
        'failed_requests',
        'total_relationships_found',
    ):
        setattr(self, counter, 0)
def _extract_domain_from_url(self, url: str) -> Optional[str]:
"""
Extract domain from URL.
Args:
url: URL string
Returns:
Domain name or None if extraction fails
"""
try:
# Remove protocol
if '://' in url:
url = url.split('://', 1)[1]
# Remove path
if '/' in url:
url = url.split('/', 1)[0]
# Remove port
if ':' in url:
url = url.split(':', 1)[0]
return url.lower()
except Exception:
return None
def _is_valid_domain(self, domain: str) -> bool:
"""
Basic domain validation.
Args:
domain: Domain string to validate
Returns:
True if domain appears valid
"""
if not domain or len(domain) > 253:
return False
# Check for valid characters and structure
parts = domain.split('.')
if len(parts) < 2:
return False
for part in parts:
if not part or len(part) > 63:
return False
if not part.replace('-', '').replace('_', '').isalnum():
return False
return True
def _is_valid_ip(self, ip: str) -> bool:
"""
Basic IP address validation.
Args:
ip: IP address string to validate
Returns:
True if IP appears valid
"""
try:
parts = ip.split('.')
if len(parts) != 4:
return False
for part in parts:
num = int(part)
if not 0 <= num <= 255:
return False
return True
except (ValueError, AttributeError):
return False
}

View File

@@ -1,6 +1,7 @@
"""
Certificate Transparency provider using crt.sh.
Discovers domain relationships through certificate SAN analysis.
Discovers domain relationships through certificate SAN analysis with comprehensive certificate tracking.
Stores certificates as metadata on domain nodes rather than creating certificate nodes.
"""
import json
@@ -10,23 +11,26 @@ from urllib.parse import quote
from datetime import datetime, timezone
from .base_provider import BaseProvider
from utils.helpers import _is_valid_domain
from core.graph_manager import RelationshipType
class CrtShProvider(BaseProvider):
"""
Provider for querying crt.sh certificate transparency database.
Discovers domain relationships through certificate Subject Alternative Names (SANs).
Now uses session-specific configuration and caching.
"""
def __init__(self):
"""Initialize CrtSh provider with appropriate rate limiting."""
def __init__(self, session_config=None):
"""Initialize CrtSh provider with session-specific configuration."""
super().__init__(
name="crtsh",
rate_limit=60, # Be respectful to the free service
timeout=30
rate_limit=60,
timeout=15,
session_config=session_config
)
self.base_url = "https://crt.sh/"
self._stop_event = None
def get_name(self) -> str:
"""Return the provider name."""
@@ -40,31 +44,128 @@ class CrtShProvider(BaseProvider):
"""
return True
def _parse_certificate_date(self, date_string: str) -> datetime:
"""
Parse certificate date from crt.sh format.
Args:
date_string: Date string from crt.sh API
Returns:
Parsed datetime object in UTC
"""
if not date_string:
raise ValueError("Empty date string")
try:
# Handle various possible formats from crt.sh
if date_string.endswith('Z'):
return datetime.fromisoformat(date_string[:-1]).replace(tzinfo=timezone.utc)
elif '+' in date_string or date_string.endswith('UTC'):
# Handle timezone-aware strings
date_string = date_string.replace('UTC', '').strip()
if '+' in date_string:
date_string = date_string.split('+')[0]
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
else:
# Assume UTC if no timezone specified
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
except Exception as e:
# Fallback: try parsing without timezone info and assume UTC
try:
return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc)
except Exception:
raise ValueError(f"Unable to parse date: {date_string}") from e
def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool:
"""Check if a certificate is currently valid."""
"""
Check if a certificate is currently valid based on its expiry date.
Args:
cert_data: Certificate data from crt.sh
Returns:
True if certificate is currently valid (not expired)
"""
try:
not_after_str = cert_data.get('not_after')
if not_after_str:
# Append 'Z' to indicate UTC if it's not present
if not not_after_str.endswith('Z'):
not_after_str += 'Z'
not_after_date = datetime.fromisoformat(not_after_str.replace('Z', '+00:00'))
return not_after_date > datetime.now(timezone.utc)
except Exception:
if not not_after_str:
return False
not_after_date = self._parse_certificate_date(not_after_str)
not_before_str = cert_data.get('not_before')
now = datetime.now(timezone.utc)
# Check if certificate is within valid date range
is_not_expired = not_after_date > now
if not_before_str:
not_before_date = self._parse_certificate_date(not_before_str)
is_not_before_valid = not_before_date <= now
return is_not_expired and is_not_before_valid
return is_not_expired
except Exception as e:
self.logger.logger.debug(f"Certificate validity check failed: {e}")
return False
return False
def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract comprehensive metadata from certificate data.
Args:
cert_data: Raw certificate data from crt.sh
Returns:
Comprehensive certificate metadata dictionary
"""
metadata = {
'certificate_id': cert_data.get('id'),
'serial_number': cert_data.get('serial_number'),
'issuer_name': cert_data.get('issuer_name'),
'issuer_ca_id': cert_data.get('issuer_ca_id'),
'common_name': cert_data.get('common_name'),
'not_before': cert_data.get('not_before'),
'not_after': cert_data.get('not_after'),
'entry_timestamp': cert_data.get('entry_timestamp'),
'source': 'crt.sh'
}
# Add computed fields
try:
if metadata['not_before'] and metadata['not_after']:
not_before = self._parse_certificate_date(metadata['not_before'])
not_after = self._parse_certificate_date(metadata['not_after'])
metadata['validity_period_days'] = (not_after - not_before).days
metadata['is_currently_valid'] = self._is_cert_valid(cert_data)
metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30
# Add human-readable dates
metadata['not_before_formatted'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC')
metadata['not_after_formatted'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC')
except Exception as e:
self.logger.logger.debug(f"Error computing certificate metadata: {e}")
metadata['is_currently_valid'] = False
metadata['expires_soon'] = False
return metadata
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query crt.sh for certificates containing the domain.
Args:
domain: Domain to investigate
Returns:
List of relationships discovered from certificate analysis
Creates domain-to-domain relationships and stores certificate data as metadata.
Now supports early termination via stop_event.
"""
if not self._is_valid_domain(domain):
if not _is_valid_domain(domain):
return []
# Check for cancellation before starting
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before start for domain: {domain}")
return []
relationships = []
@@ -72,56 +173,113 @@ class CrtShProvider(BaseProvider):
try:
# Query crt.sh for certificates
url = f"{self.base_url}?q={quote(domain)}&output=json"
response = self.make_request(url, target_indicator=domain)
response = self.make_request(url, target_indicator=domain, max_retries=1) # Reduce retries for faster cancellation
if not response or response.status_code != 200:
return []
# Check for cancellation after request
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled after request for domain: {domain}")
return []
certificates = response.json()
if not certificates:
return []
# Process certificates to extract relationships
discovered_subdomains = {}
# Check for cancellation before processing
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before processing for domain: {domain}")
return []
for cert_data in certificates:
# Aggregate certificate data by domain
domain_certificates = {}
all_discovered_domains = set()
# Process certificates and group by domain (with cancellation checks)
for i, cert_data in enumerate(certificates):
# Check for cancellation every 10 certificates
if i % 10 == 0 and self._stop_event and self._stop_event.is_set():
print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}")
break
cert_metadata = self._extract_certificate_metadata(cert_data)
cert_domains = self._extract_domains_from_certificate(cert_data)
is_valid = self._is_cert_valid(cert_data)
# Add all domains from this certificate to our tracking
for cert_domain in cert_domains:
if not _is_valid_domain(cert_domain):
continue
all_discovered_domains.add(cert_domain)
# Initialize domain certificate list if needed
if cert_domain not in domain_certificates:
domain_certificates[cert_domain] = []
# Add this certificate to the domain's certificate list
domain_certificates[cert_domain].append(cert_metadata)
# Final cancellation check before creating relationships
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before relationship creation for domain: {domain}")
return []
for subdomain in cert_domains:
if self._is_valid_domain(subdomain) and subdomain != domain:
if subdomain not in discovered_subdomains:
discovered_subdomains[subdomain] = {'has_valid_cert': False, 'issuers': set()}
if is_valid:
discovered_subdomains[subdomain]['has_valid_cert'] = True
issuer = cert_data.get('issuer_name')
if issuer:
discovered_subdomains[subdomain]['issuers'].add(issuer)
# Create relationships from query domain to ALL discovered domains
for discovered_domain in all_discovered_domains:
if discovered_domain == domain:
continue # Skip self-relationships
# Check for cancellation during relationship creation
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh relationship creation cancelled for domain: {domain}")
break
# Create relationships from the discovered subdomains
for subdomain, data in discovered_subdomains.items():
raw_data = {
'has_valid_cert': data['has_valid_cert'],
'issuers': list(data['issuers']),
'source': 'crt.sh'
if not _is_valid_domain(discovered_domain):
continue
# Get certificates for both domains
query_domain_certs = domain_certificates.get(domain, [])
discovered_domain_certs = domain_certificates.get(discovered_domain, [])
# Find shared certificates (for metadata purposes)
shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs)
# Calculate confidence based on relationship type and shared certificates
confidence = self._calculate_domain_relationship_confidence(
domain, discovered_domain, shared_certificates, all_discovered_domains
)
# Create comprehensive raw data for the relationship
relationship_raw_data = {
'relationship_type': 'certificate_discovery',
'shared_certificates': shared_certificates,
'total_shared_certs': len(shared_certificates),
'discovery_context': self._determine_relationship_context(discovered_domain, domain),
'domain_certificates': {
domain: self._summarize_certificates(query_domain_certs),
discovered_domain: self._summarize_certificates(discovered_domain_certs)
}
}
# Create domain -> domain relationship
relationships.append((
domain,
subdomain,
discovered_domain,
RelationshipType.SAN_CERTIFICATE,
RelationshipType.SAN_CERTIFICATE.default_confidence,
raw_data
confidence,
relationship_raw_data
))
# Log the relationship discovery
self.log_relationship_discovery(
source_node=domain,
target_node=subdomain,
target_node=discovered_domain,
relationship_type=RelationshipType.SAN_CERTIFICATE,
confidence_score=RelationshipType.SAN_CERTIFICATE.default_confidence,
raw_data=raw_data,
discovery_method="certificate_san_analysis"
confidence_score=confidence,
raw_data=relationship_raw_data,
discovery_method="certificate_transparency_analysis"
)
except json.JSONDecodeError as e:
@@ -130,6 +288,165 @@ class CrtShProvider(BaseProvider):
self.logger.logger.error(f"Error querying crt.sh for {domain}: {e}")
return relationships
def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Find certificates that are shared between two domain certificate lists.
Args:
certs1: First domain's certificates
certs2: Second domain's certificates
Returns:
List of shared certificate metadata
"""
shared = []
# Create a set of certificate IDs from the first list for quick lookup
cert1_ids = {cert.get('certificate_id') for cert in certs1 if cert.get('certificate_id')}
# Find certificates in the second list that match
for cert in certs2:
if cert.get('certificate_id') in cert1_ids:
shared.append(cert)
return shared
def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Create a summary of certificates for a domain.
Args:
certificates: List of certificate metadata
Returns:
Summary dictionary with aggregate statistics
"""
if not certificates:
return {
'total_certificates': 0,
'valid_certificates': 0,
'expired_certificates': 0,
'expires_soon_count': 0,
'unique_issuers': [],
'latest_certificate': None,
'has_valid_cert': False
}
valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid'))
expired_count = len(certificates) - valid_count
expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon'))
# Get unique issuers
unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name')))
# Find the most recent certificate
latest_cert = None
latest_date = None
for cert in certificates:
try:
if cert.get('not_before'):
cert_date = self._parse_certificate_date(cert['not_before'])
if latest_date is None or cert_date > latest_date:
latest_date = cert_date
latest_cert = cert
except Exception:
continue
return {
'total_certificates': len(certificates),
'valid_certificates': valid_count,
'expired_certificates': expired_count,
'expires_soon_count': expires_soon_count,
'unique_issuers': unique_issuers,
'latest_certificate': latest_cert,
'has_valid_cert': valid_count > 0,
'certificate_details': certificates # Full details for forensic analysis
}
def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str,
                                              shared_certificates: List[Dict[str, Any]],
                                              all_discovered_domains: Set[str]) -> float:
    """Score the confidence of a domain1 -> domain2 certificate link.

    Starts from the SAN_CERTIFICATE default confidence and adds small
    bonuses for subdomain/parent-domain relationships, for shared
    certificates (more certificates and currently-valid ones count
    extra) and for well-known issuers. The result is clamped to
    [0.1, 1.0].

    Args:
        domain1: Source (query) domain.
        domain2: Target (discovered) domain.
        shared_certificates: Certificate metadata shared by both domains.
        all_discovered_domains: All domains discovered in this query.

    Returns:
        Confidence score between 0.1 and 1.0.
    """
    confidence = RelationshipType.SAN_CERTIFICATE.default_confidence

    # Context bonus: how the discovered domain relates to the query domain.
    context_bonuses = {
        'exact_match': 0.0,   # should never occur for a relationship
        'subdomain': 0.1,
        'parent_domain': 0.05,
    }
    context = self._determine_relationship_context(domain2, domain1)
    confidence += context_bonuses.get(context, 0.0)

    if shared_certificates:
        shared_count = len(shared_certificates)
        if shared_count >= 3:
            confidence += 0.1
        elif shared_count >= 2:
            confidence += 0.05
        else:
            confidence += 0.02
        # Currently-valid shared certificates strengthen the link.
        if any(cert.get('is_currently_valid') for cert in shared_certificates):
            confidence += 0.05
        # Reputable issuers add a small, one-off bonus.
        trusted_cas = ("let's encrypt", 'digicert', 'sectigo', 'globalsign')
        for cert in shared_certificates:
            issuer = cert.get('issuer_name', '').lower()
            if any(ca in issuer for ca in trusted_cas):
                confidence += 0.03
                break

    return max(0.1, min(1.0, confidence))
def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str:
"""
Determine the context of the relationship between certificate domain and query domain.
Args:
cert_domain: Domain found in certificate
query_domain: Original query domain
Returns:
String describing the relationship context
"""
if cert_domain == query_domain:
return 'exact_match'
elif cert_domain.endswith(f'.{query_domain}'):
return 'subdomain'
elif query_domain.endswith(f'.{cert_domain}'):
return 'parent_domain'
else:
return 'related_domain'
def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
@@ -143,7 +460,6 @@ class CrtShProvider(BaseProvider):
Empty list (crt.sh doesn't support IP-based certificate queries effectively)
"""
# crt.sh doesn't effectively support IP-based certificate queries
# This would require parsing certificate details for IP SANs, which is complex
return []
def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]:
@@ -162,7 +478,7 @@ class CrtShProvider(BaseProvider):
common_name = cert_data.get('common_name', '')
if common_name:
cleaned_cn = self._clean_domain_name(common_name)
if cleaned_cn and self._is_valid_domain(cleaned_cn):
if cleaned_cn and _is_valid_domain(cleaned_cn):
domains.add(cleaned_cn)
# Extract from name_value field (contains SANs)
@@ -171,7 +487,7 @@ class CrtShProvider(BaseProvider):
# Split by newlines and clean each domain
for line in name_value.split('\n'):
cleaned_domain = self._clean_domain_name(line.strip())
if cleaned_domain and self._is_valid_domain(cleaned_domain):
if cleaned_domain and _is_valid_domain(cleaned_domain):
domains.add(cleaned_domain)
return domains
@@ -215,70 +531,4 @@ class CrtShProvider(BaseProvider):
if domain and not domain.startswith(('.', '-')) and not domain.endswith(('.', '-')):
return domain
return ""
def get_certificate_details(self, certificate_id: str) -> Dict[str, Any]:
    """Fetch the full crt.sh record for a single certificate.

    Args:
        certificate_id: Certificate ID from crt.sh.

    Returns:
        Parsed JSON details, or an empty dict on any failure.
    """
    try:
        query_url = f"{self.base_url}?id={certificate_id}&output=json"
        response = self.make_request(query_url, target_indicator=f"cert_{certificate_id}")
        if response and response.status_code == 200:
            return response.json()
    except Exception as e:
        self.logger.logger.error(f"Error fetching certificate details for {certificate_id}: {e}")
    # Any failure path (no response, bad status, exception) yields {}.
    return {}
def search_certificates_by_serial(self, serial_number: str) -> List[Dict[str, Any]]:
    """Look up certificates on crt.sh by serial number.

    Args:
        serial_number: Certificate serial number.

    Returns:
        List of matching certificate records; empty on any failure.
    """
    try:
        query_url = f"{self.base_url}?serial={quote(serial_number)}&output=json"
        response = self.make_request(query_url, target_indicator=f"serial_{serial_number}")
        if response and response.status_code == 200:
            return response.json()
    except Exception as e:
        self.logger.logger.error(f"Error searching certificates by serial {serial_number}: {e}")
    # Any failure path (no response, bad status, exception) yields [].
    return []
def get_issuer_certificates(self, issuer_name: str) -> List[Dict[str, Any]]:
    """List certificates issued by a specific certificate authority.

    Args:
        issuer_name: Certificate Authority name.

    Returns:
        Certificates from the specified issuer; empty on any failure.
    """
    try:
        query_url = f"{self.base_url}?issuer={quote(issuer_name)}&output=json"
        response = self.make_request(query_url, target_indicator=f"issuer_{issuer_name}")
        if response and response.status_code == 200:
            return response.json()
    except Exception as e:
        self.logger.logger.error(f"Error fetching certificates for issuer {issuer_name}: {e}")
    # Any failure path (no response, bad status, exception) yields [].
    return []
return ""

View File

@@ -1,25 +1,26 @@
# dnsrecon/providers/dns_provider.py
import socket
import dns.resolver
import dns.reversename
from typing import List, Dict, Any, Tuple, Optional
from typing import List, Dict, Any, Tuple
from .base_provider import BaseProvider
from core.graph_manager import RelationshipType, NodeType
from utils.helpers import _is_valid_ip, _is_valid_domain
from core.graph_manager import RelationshipType
class DNSProvider(BaseProvider):
"""
Provider for standard DNS resolution and reverse DNS lookups.
Discovers domain-to-IP and IP-to-domain relationships through DNS records.
Now uses session-specific configuration.
"""
def __init__(self):
"""Initialize DNS provider with appropriate rate limiting."""
def __init__(self, session_config=None):
"""Initialize DNS provider with session-specific configuration."""
super().__init__(
name="dns",
rate_limit=100, # DNS queries can be faster
timeout=10
rate_limit=100,
timeout=10,
session_config=session_config
)
# Configure DNS resolver
@@ -45,7 +46,7 @@ class DNSProvider(BaseProvider):
Returns:
List of relationships discovered from DNS analysis
"""
if not self._is_valid_domain(domain):
if not _is_valid_domain(domain):
return []
relationships = []
@@ -66,7 +67,7 @@ class DNSProvider(BaseProvider):
Returns:
List of relationships discovered from reverse DNS
"""
if not self._is_valid_ip(ip):
if not _is_valid_ip(ip):
return []
relationships = []
@@ -81,7 +82,7 @@ class DNSProvider(BaseProvider):
for ptr_record in response:
hostname = str(ptr_record).rstrip('.')
if self._is_valid_domain(hostname):
if _is_valid_domain(hostname):
raw_data = {
'query_type': 'PTR',
'ip_address': ip,

View File

@@ -4,38 +4,37 @@ Discovers IP relationships and infrastructure context through Shodan API.
"""
import json
from typing import List, Dict, Any, Tuple, Optional
from urllib.parse import quote
from typing import List, Dict, Any, Tuple
from .base_provider import BaseProvider
from utils.helpers import _is_valid_ip, _is_valid_domain
from core.graph_manager import RelationshipType
from config import config
class ShodanProvider(BaseProvider):
"""
Provider for querying Shodan API for IP address and hostname information.
Requires valid API key and respects Shodan's rate limits.
Now uses session-specific API keys.
"""
def __init__(self):
"""Initialize Shodan provider with appropriate rate limiting."""
def __init__(self, session_config=None):
"""Initialize Shodan provider with session-specific configuration."""
super().__init__(
name="shodan",
rate_limit=60, # Shodan API has various rate limits depending on plan
timeout=30
rate_limit=60,
timeout=30,
session_config=session_config
)
self.base_url = "https://api.shodan.io"
self.api_key = config.get_api_key('shodan')
self.api_key = self.config.get_api_key('shodan')
def is_available(self) -> bool:
"""Check if Shodan provider is available (has valid API key in this session)."""
return self.api_key is not None and len(self.api_key.strip()) > 0
def get_name(self) -> str:
"""Return the provider name."""
return "shodan"
def is_available(self) -> bool:
"""
Check if Shodan provider is available (has valid API key).
"""
return self.api_key is not None and len(self.api_key.strip()) > 0
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
@@ -48,7 +47,7 @@ class ShodanProvider(BaseProvider):
Returns:
List of relationships discovered from Shodan data
"""
if not self._is_valid_domain(domain) or not self.is_available():
if not _is_valid_domain(domain) or not self.is_available():
return []
relationships = []
@@ -109,7 +108,7 @@ class ShodanProvider(BaseProvider):
# Also create relationships to other hostnames on the same IP
for hostname in hostnames:
if hostname != domain and self._is_valid_domain(hostname):
if hostname != domain and _is_valid_domain(hostname):
hostname_raw_data = {
'shared_ip': ip_address,
'all_hostnames': hostnames,
@@ -150,7 +149,7 @@ class ShodanProvider(BaseProvider):
Returns:
List of relationships discovered from Shodan IP data
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return []
relationships = []
@@ -170,7 +169,7 @@ class ShodanProvider(BaseProvider):
# Extract hostname relationships
hostnames = data.get('hostnames', [])
for hostname in hostnames:
if self._is_valid_domain(hostname):
if _is_valid_domain(hostname):
raw_data = {
'ip_address': ip,
'hostname': hostname,
@@ -280,7 +279,7 @@ class ShodanProvider(BaseProvider):
Returns:
List of service information dictionaries
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return []
try:

View File

@@ -4,38 +4,37 @@ Discovers domain relationships through passive DNS and URL analysis.
"""
import json
from typing import List, Dict, Any, Tuple, Optional
from typing import List, Dict, Any, Tuple
from .base_provider import BaseProvider
from utils.helpers import _is_valid_ip, _is_valid_domain
from core.graph_manager import RelationshipType
from config import config
class VirusTotalProvider(BaseProvider):
"""
Provider for querying VirusTotal API for passive DNS and domain reputation data.
Requires valid API key and strictly respects free tier rate limits.
Now uses session-specific API keys and rate limits.
"""
def __init__(self):
"""Initialize VirusTotal provider with strict rate limiting for free tier."""
def __init__(self, session_config=None):
"""Initialize VirusTotal provider with session-specific configuration."""
super().__init__(
name="virustotal",
rate_limit=4, # Free tier: 4 requests per minute
timeout=30
timeout=30,
session_config=session_config
)
self.base_url = "https://www.virustotal.com/vtapi/v2"
self.api_key = config.get_api_key('virustotal')
self.api_key = self.config.get_api_key('virustotal')
def is_available(self) -> bool:
"""Check if VirusTotal provider is available (has valid API key in this session)."""
return self.api_key is not None and len(self.api_key.strip()) > 0
def get_name(self) -> str:
"""Return the provider name."""
return "virustotal"
def is_available(self) -> bool:
"""
Check if VirusTotal provider is available (has valid API key).
"""
return self.api_key is not None and len(self.api_key.strip()) > 0
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
"""
Query VirusTotal for domain information including passive DNS.
@@ -46,7 +45,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
List of relationships discovered from VirusTotal data
"""
if not self._is_valid_domain(domain) or not self.is_available():
if not _is_valid_domain(domain) or not self.is_available():
return []
relationships = []
@@ -71,7 +70,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
List of relationships discovered from VirusTotal IP data
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return []
relationships = []
@@ -114,7 +113,7 @@ class VirusTotalProvider(BaseProvider):
ip_address = resolution.get('ip_address')
last_resolved = resolution.get('last_resolved')
if ip_address and self._is_valid_ip(ip_address):
if ip_address and _is_valid_ip(ip_address):
raw_data = {
'domain': domain,
'ip_address': ip_address,
@@ -142,7 +141,7 @@ class VirusTotalProvider(BaseProvider):
# Extract subdomains
subdomains = data.get('subdomains', [])
for subdomain in subdomains:
if subdomain != domain and self._is_valid_domain(subdomain):
if subdomain != domain and _is_valid_domain(subdomain):
raw_data = {
'parent_domain': domain,
'subdomain': subdomain,
@@ -200,7 +199,7 @@ class VirusTotalProvider(BaseProvider):
hostname = resolution.get('hostname')
last_resolved = resolution.get('last_resolved')
if hostname and self._is_valid_domain(hostname):
if hostname and _is_valid_domain(hostname):
raw_data = {
'ip_address': ip,
'hostname': hostname,
@@ -254,7 +253,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
Dictionary containing reputation data
"""
if not self._is_valid_domain(domain) or not self.is_available():
if not _is_valid_domain(domain) or not self.is_available():
return {}
try:
@@ -293,7 +292,7 @@ class VirusTotalProvider(BaseProvider):
Returns:
Dictionary containing reputation data
"""
if not self._is_valid_ip(ip) or not self.is_available():
if not _is_valid_ip(ip) or not self.is_available():
return {}
try: