This commit is contained in:
overcuriousity 2025-09-14 22:54:37 +02:00
parent eb9eea127b
commit 9f3b17e658
2 changed files with 249 additions and 479 deletions

View File

@ -5,32 +5,46 @@ import re
import os import os
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any, Tuple, Set from typing import List, Dict, Any, Tuple, Set
from urllib.parse import quote
from datetime import datetime, timezone from datetime import datetime, timezone
import requests
# New dependency required for this provider
try:
import psycopg2
import psycopg2.extras
PSYCOPG2_AVAILABLE = True
except ImportError:
PSYCOPG2_AVAILABLE = False
from .base_provider import BaseProvider from .base_provider import BaseProvider
from utils.helpers import _is_valid_domain from utils.helpers import _is_valid_domain
# We use requests only to raise the same exception type for compatibility with core retry logic
import requests
class CrtShProvider(BaseProvider): class CrtShProvider(BaseProvider):
""" """
Provider for querying crt.sh certificate transparency database. Provider for querying crt.sh certificate transparency database via its public PostgreSQL endpoint.
Now uses session-specific configuration and caching with accumulative behavior. This version is designed to be a drop-in, high-performance replacement for the API-based provider.
It preserves the same caching and data processing logic.
""" """
def __init__(self, name=None, session_config=None): def __init__(self, name=None, session_config=None):
"""Initialize CrtSh provider with session-specific configuration.""" """Initialize CrtShDB provider with session-specific configuration."""
super().__init__( super().__init__(
name="crtsh", name="crtsh",
rate_limit=60, rate_limit=0, # No rate limit for direct DB access
timeout=15, timeout=60, # Increased timeout for potentially long DB queries
session_config=session_config session_config=session_config
) )
self.base_url = "https://crt.sh/" # Database connection details
self.db_host = "crt.sh"
self.db_port = 5432
self.db_name = "certwatch"
self.db_user = "guest"
self._stop_event = None self._stop_event = None
# Initialize cache directory # Initialize cache directory (same as original provider)
self.cache_dir = Path('cache') / 'crtsh' self.cache_dir = Path('cache') / 'crtsh'
self.cache_dir.mkdir(parents=True, exist_ok=True) self.cache_dir.mkdir(parents=True, exist_ok=True)
@ -40,7 +54,7 @@ class CrtShProvider(BaseProvider):
def get_display_name(self) -> str: def get_display_name(self) -> str:
"""Return the provider display name for the UI.""" """Return the provider display name for the UI."""
return "crt.sh" return "crt.sh (DB)"
def requires_api_key(self) -> bool: def requires_api_key(self) -> bool:
"""Return True if the provider requires an API key.""" """Return True if the provider requires an API key."""
@ -52,23 +66,161 @@ class CrtShProvider(BaseProvider):
def is_available(self) -> bool: def is_available(self) -> bool:
""" """
Check if the provider is configured to be used. Check if the provider can be used. Requires the psycopg2 library.
This method is intentionally simple and does not perform a network request
to avoid blocking application startup.
""" """
if not PSYCOPG2_AVAILABLE:
self.logger.logger.warning("psycopg2 library not found. CrtShDBProvider is unavailable. "
"Please run 'pip install psycopg2-binary'.")
return False
return True return True
def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
"""
Query the crt.sh PostgreSQL database for raw certificate data.

Returns a list of row dicts shaped like the crt.sh JSON API output
(id, serial_number, not_before, not_after, entry_timestamp,
issuer_ca_id, issuer_name, name_value, common_name).
Raises requests.exceptions.RequestException on any psycopg2 error so
the core retry logic (which catches that exception type) can retry.
"""
conn = None
certificates = []
# SQL Query to find all certificate IDs related to the domain (including subdomains),
# then retrieve comprehensive details for each certificate, mimicking the JSON API structure.
sql_query = """
WITH certificates_of_interest AS (
SELECT DISTINCT ci.certificate_id
FROM certificate_identity ci
WHERE ci.name_value ILIKE %(domain_wildcard)s OR ci.name_value = %(domain)s
)
SELECT
c.id,
c.serial_number,
c.not_before,
c.not_after,
(SELECT min(entry_timestamp) FROM ct_log_entry cle WHERE cle.certificate_id = c.id) as entry_timestamp,
ca.id as issuer_ca_id,
ca.name as issuer_name,
(SELECT array_to_string(array_agg(DISTINCT ci.name_value), E'\n') FROM certificate_identity ci WHERE ci.certificate_id = c.id) as name_value,
(SELECT name_value FROM certificate_identity ci WHERE ci.certificate_id = c.id AND ci.name_type = 'commonName' LIMIT 1) as common_name
FROM
certificate c
JOIN ca ON c.issuer_ca_id = ca.id
WHERE c.id IN (SELECT certificate_id FROM certificates_of_interest);
"""
try:
# NOTE(review): connect_timeout bounds only connection establishment,
# not query execution -- a slow query can still exceed self.timeout;
# confirm whether a server-side statement_timeout is also needed.
conn = psycopg2.connect(
dbname=self.db_name,
user=self.db_user,
host=self.db_host,
port=self.db_port,
connect_timeout=self.timeout
)
# RealDictCursor yields each row as a dict keyed by column name.
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
# '%.{domain}' matches any subdomain; the exact name is ORed in separately.
cursor.execute(sql_query, {'domain': domain, 'domain_wildcard': f'%.{domain}'})
results = cursor.fetchall()
certificates = [dict(row) for row in results]
self.logger.logger.info(f"crt.sh DB query for '{domain}' returned {len(certificates)} certificates.")
except psycopg2.Error as e:
self.logger.logger.error(f"PostgreSQL query failed for {domain}: {e}")
# Raise a RequestException to be compatible with the existing retry logic in the core application
raise requests.exceptions.RequestException(f"PostgreSQL query failed: {e}") from e
finally:
# Always release the connection, success or failure.
if conn:
conn.close()
return certificates
def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
"""
Query crt.sh for certificates containing the domain with caching support.
Properly raises exceptions for network errors to allow core logic retries.

Returns (source, target, relationship_type, confidence, raw_data)
tuples built by _process_certificates_to_relationships, or [] when the
domain is invalid, a stop was requested, or no certificates were found.
"""
if not _is_valid_domain(domain):
return []
# Honor cooperative cancellation before doing any work.
if self._stop_event and self._stop_event.is_set():
return []
cache_file = self._get_cache_file_path(domain)
cache_status = self._get_cache_status(cache_file)
certificates = []
try:
if cache_status == "fresh":
certificates = self._load_cached_certificates(cache_file)
self.logger.logger.info(f"Using cached data for {domain} ({len(certificates)} certificates)")
elif cache_status == "not_found":
# Fresh query from DB, create new cache
certificates = self._query_crtsh(domain)
if certificates:
# Datetime columns must be serialized before JSON caching.
self._create_cache_file(cache_file, domain, self._serialize_certs_for_cache(certificates))
else:
self.logger.logger.info(f"No certificates found for {domain}, not caching")
elif cache_status == "stale":
try:
new_certificates = self._query_crtsh(domain)
if new_certificates:
# Merge new rows into the existing cache (deduplicated by id).
certificates = self._append_to_cache(cache_file, self._serialize_certs_for_cache(new_certificates))
else:
certificates = self._load_cached_certificates(cache_file)
except requests.exceptions.RequestException:
# DB unreachable: fall back to the stale cache if one exists.
certificates = self._load_cached_certificates(cache_file)
if certificates:
self.logger.logger.warning(f"DB query failed for {domain}, using stale cache data.")
else:
# No fallback data at all -- re-raise so the core can retry.
raise
except requests.exceptions.RequestException as e:
# Re-raise so core logic can retry
self.logger.logger.error(f"DB query failed for {domain}: {e}")
raise e
except json.JSONDecodeError as e:
# JSON parsing errors from cache should also be handled
self.logger.logger.error(f"Failed to parse JSON from cache for {domain}: {e}")
raise e
# Cancellation may have been requested while we were querying.
if self._stop_event and self._stop_event.is_set():
return []
if not certificates:
return []
return self._process_certificates_to_relationships(domain, certificates)
def _serialize_certs_for_cache(self, certificates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Serialize certificate data for JSON caching, converting datetime objects to ISO strings.
"""
serialized_certs = []
for cert in certificates:
serialized_cert = cert.copy()
for key in ['not_before', 'not_after', 'entry_timestamp']:
if isinstance(serialized_cert.get(key), datetime):
# Ensure datetime is timezone-aware before converting
dt_obj = serialized_cert[key]
if dt_obj.tzinfo is None:
dt_obj = dt_obj.replace(tzinfo=timezone.utc)
serialized_cert[key] = dt_obj.isoformat()
serialized_certs.append(serialized_cert)
return serialized_certs
# --- All methods below are copied directly from the original CrtShProvider ---
# They are compatible because _query_crtsh returns data in the same format
# as the original _query_crtsh_api method. A small adjustment is made to
# _parse_certificate_date to handle datetime objects directly from the DB.
def _get_cache_file_path(self, domain: str) -> Path: def _get_cache_file_path(self, domain: str) -> Path:
"""Generate cache file path for a domain.""" """Generate cache file path for a domain."""
# Sanitize domain for filename safety
safe_domain = domain.replace('.', '_').replace('/', '_').replace('\\', '_') safe_domain = domain.replace('.', '_').replace('/', '_').replace('\\', '_')
return self.cache_dir / f"{safe_domain}.json" return self.cache_dir / f"{safe_domain}.json"
def _get_cache_status(self, cache_file_path: Path) -> str: def _get_cache_status(self, cache_file_path: Path) -> str:
""" """Check cache status for a domain."""
Check cache status for a domain.
Returns: 'not_found', 'fresh', or 'stale'
"""
if not cache_file_path.exists(): if not cache_file_path.exists():
return "not_found" return "not_found"
@ -78,7 +230,7 @@ class CrtShProvider(BaseProvider):
last_query_str = cache_data.get("last_upstream_query") last_query_str = cache_data.get("last_upstream_query")
if not last_query_str: if not last_query_str:
return "stale" # Invalid cache format return "stale"
last_query = datetime.fromisoformat(last_query_str.replace('Z', '+00:00')) last_query = datetime.fromisoformat(last_query_str.replace('Z', '+00:00'))
hours_since_query = (datetime.now(timezone.utc) - last_query).total_seconds() / 3600 hours_since_query = (datetime.now(timezone.utc) - last_query).total_seconds() / 3600
@ -103,24 +255,6 @@ class CrtShProvider(BaseProvider):
self.logger.logger.error(f"Failed to load cached certificates from {cache_file_path}: {e}") self.logger.logger.error(f"Failed to load cached certificates from {cache_file_path}: {e}")
return [] return []
def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]:
"""
Query crt.sh API for raw certificate data.
Raises exceptions for network errors to allow core logic to retry.

Returns the parsed JSON list from crt.sh, or [] when the response body
is empty/falsy.
"""
# output=json asks crt.sh for its JSON listing; the domain is URL-encoded.
url = f"{self.base_url}?q={quote(domain)}&output=json"
response = self.make_request(url, target_indicator=domain)
if not response or response.status_code != 200:
# This could be a temporary error - raise exception so core can retry
raise requests.exceptions.RequestException(f"crt.sh API returned status {response.status_code if response else 'None'}")
certificates = response.json()
if not certificates:
return []
return certificates
def _create_cache_file(self, cache_file_path: Path, domain: str, certificates: List[Dict[str, Any]]) -> None: def _create_cache_file(self, cache_file_path: Path, domain: str, certificates: List[Dict[str, Any]]) -> None:
"""Create new cache file with certificates.""" """Create new cache file with certificates."""
try: try:
@ -131,27 +265,20 @@ class CrtShProvider(BaseProvider):
"upstream_query_count": 1, "upstream_query_count": 1,
"certificates": certificates "certificates": certificates
} }
cache_file_path.parent.mkdir(parents=True, exist_ok=True) cache_file_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_file_path, 'w') as f: with open(cache_file_path, 'w') as f:
json.dump(cache_data, f, separators=(',', ':')) json.dump(cache_data, f, separators=(',', ':'))
self.logger.logger.info(f"Created cache file for {domain} with {len(certificates)} certificates") self.logger.logger.info(f"Created cache file for {domain} with {len(certificates)} certificates")
except Exception as e: except Exception as e:
self.logger.logger.warning(f"Failed to create cache file for {domain}: {e}") self.logger.logger.warning(f"Failed to create cache file for {domain}: {e}")
def _append_to_cache(self, cache_file_path: Path, new_certificates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _append_to_cache(self, cache_file_path: Path, new_certificates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Append new certificates to existing cache and return all certificates.""" """Append new certificates to existing cache and return all certificates."""
try: try:
# Load existing cache
with open(cache_file_path, 'r') as f: with open(cache_file_path, 'r') as f:
cache_data = json.load(f) cache_data = json.load(f)
# Track existing certificate IDs to avoid duplicates
existing_ids = {cert.get('id') for cert in cache_data.get('certificates', [])} existing_ids = {cert.get('id') for cert in cache_data.get('certificates', [])}
# Add only new certificates
added_count = 0 added_count = 0
for cert in new_certificates: for cert in new_certificates:
cert_id = cert.get('id') cert_id = cert.get('id')
@ -160,314 +287,141 @@ class CrtShProvider(BaseProvider):
existing_ids.add(cert_id) existing_ids.add(cert_id)
added_count += 1 added_count += 1
# Update metadata
cache_data['last_upstream_query'] = datetime.now(timezone.utc).isoformat() cache_data['last_upstream_query'] = datetime.now(timezone.utc).isoformat()
cache_data['upstream_query_count'] = cache_data.get('upstream_query_count', 0) + 1 cache_data['upstream_query_count'] = cache_data.get('upstream_query_count', 0) + 1
# Write updated cache
with open(cache_file_path, 'w') as f: with open(cache_file_path, 'w') as f:
json.dump(cache_data, f, separators=(',', ':')) json.dump(cache_data, f, separators=(',', ':'))
total_certs = len(cache_data['certificates']) total_certs = len(cache_data['certificates'])
self.logger.logger.info(f"Appended {added_count} new certificates to cache. Total: {total_certs}") self.logger.logger.info(f"Appended {added_count} new certificates to cache. Total: {total_certs}")
return cache_data['certificates'] return cache_data['certificates']
except Exception as e: except Exception as e:
self.logger.logger.warning(f"Failed to append to cache: {e}") self.logger.logger.warning(f"Failed to append to cache: {e}")
return new_certificates # Fallback to new certificates only return new_certificates
def _parse_issuer_organization(self, issuer_dn: str) -> str: def _parse_issuer_organization(self, issuer_dn: str) -> str:
""" """Parse the issuer Distinguished Name to extract just the organization name."""
Parse the issuer Distinguished Name to extract just the organization name. if not issuer_dn: return issuer_dn
Args:
issuer_dn: Full issuer DN string (e.g., "C=US, O=Let's Encrypt, CN=R11")
Returns:
Organization name (e.g., "Let's Encrypt") or original string if parsing fails
"""
if not issuer_dn:
return issuer_dn
try: try:
# Split by comma and look for O= component
components = [comp.strip() for comp in issuer_dn.split(',')] components = [comp.strip() for comp in issuer_dn.split(',')]
for component in components: for component in components:
if component.startswith('O='): if component.startswith('O='):
# Extract the value after O=
org_name = component[2:].strip() org_name = component[2:].strip()
# Remove quotes if present
if org_name.startswith('"') and org_name.endswith('"'): if org_name.startswith('"') and org_name.endswith('"'):
org_name = org_name[1:-1] org_name = org_name[1:-1]
return org_name return org_name
# If no O= component found, return the original string
return issuer_dn return issuer_dn
except Exception as e: except Exception as e:
self.logger.logger.debug(f"Failed to parse issuer DN '{issuer_dn}': {e}") self.logger.logger.debug(f"Failed to parse issuer DN '{issuer_dn}': {e}")
return issuer_dn return issuer_dn
def _parse_certificate_date(self, date_string: str) -> datetime: def _parse_certificate_date(self, date_input: Any) -> datetime:
""" """
Parse certificate date from crt.sh format. Parse certificate date from various formats (string from cache, datetime from DB).
Args:
date_string: Date string from crt.sh API
Returns:
Parsed datetime object in UTC
""" """
if isinstance(date_input, datetime):
# If it's already a datetime object from the DB, just ensure it's UTC
if date_input.tzinfo is None:
return date_input.replace(tzinfo=timezone.utc)
return date_input
date_string = str(date_input)
if not date_string: if not date_string:
raise ValueError("Empty date string") raise ValueError("Empty date string")
try: try:
# Handle various possible formats from crt.sh if 'Z' in date_string:
if date_string.endswith('Z'): return datetime.fromisoformat(date_string.replace('Z', '+00:00'))
return datetime.fromisoformat(date_string[:-1]).replace(tzinfo=timezone.utc) # Handle standard ISO format with or without timezone
elif '+' in date_string or date_string.endswith('UTC'): dt = datetime.fromisoformat(date_string)
# Handle timezone-aware strings if dt.tzinfo is None:
date_string = date_string.replace('UTC', '').strip() return dt.replace(tzinfo=timezone.utc)
if '+' in date_string: return dt
date_string = date_string.split('+')[0] except ValueError as e:
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
else:
# Assume UTC if no timezone specified
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
except Exception as e:
# Fallback: try parsing without timezone info and assume UTC
try: try:
# Fallback for other formats
return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc) return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc)
except Exception: except Exception:
raise ValueError(f"Unable to parse date: {date_string}") from e raise ValueError(f"Unable to parse date: {date_string}") from e
def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool: def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool:
""" """Check if a certificate is currently valid based on its expiry date."""
Check if a certificate is currently valid based on its expiry date.
Args:
cert_data: Certificate data from crt.sh
Returns:
True if certificate is currently valid (not expired)
"""
try: try:
not_after_str = cert_data.get('not_after') not_after_str = cert_data.get('not_after')
if not not_after_str: if not not_after_str: return False
return False
not_after_date = self._parse_certificate_date(not_after_str) not_after_date = self._parse_certificate_date(not_after_str)
not_before_str = cert_data.get('not_before') not_before_str = cert_data.get('not_before')
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
# Check if certificate is within valid date range
is_not_expired = not_after_date > now is_not_expired = not_after_date > now
if not_before_str: if not_before_str:
not_before_date = self._parse_certificate_date(not_before_str) not_before_date = self._parse_certificate_date(not_before_str)
is_not_before_valid = not_before_date <= now is_not_before_valid = not_before_date <= now
return is_not_expired and is_not_before_valid return is_not_expired and is_not_before_valid
return is_not_expired return is_not_expired
except Exception as e: except Exception as e:
self.logger.logger.debug(f"Certificate validity check failed: {e}") self.logger.logger.debug(f"Certificate validity check failed: {e}")
return False return False
def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]: def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]:
""" # This method works as-is.
Extract comprehensive metadata from certificate data.
Args:
cert_data: Raw certificate data from crt.sh
Returns:
Comprehensive certificate metadata dictionary
"""
# Parse the issuer name to get just the organization
raw_issuer_name = cert_data.get('issuer_name', '') raw_issuer_name = cert_data.get('issuer_name', '')
parsed_issuer_name = self._parse_issuer_organization(raw_issuer_name) parsed_issuer_name = self._parse_issuer_organization(raw_issuer_name)
metadata = { metadata = {
'certificate_id': cert_data.get('id'), 'certificate_id': cert_data.get('id'),
'serial_number': cert_data.get('serial_number'), 'serial_number': cert_data.get('serial_number'),
'issuer_name': parsed_issuer_name, # Use parsed organization name 'issuer_name': parsed_issuer_name,
#'issuer_name_full': raw_issuer_name, # deliberately left out, because its not useful in most cases
'issuer_ca_id': cert_data.get('issuer_ca_id'), 'issuer_ca_id': cert_data.get('issuer_ca_id'),
'common_name': cert_data.get('common_name'), 'common_name': cert_data.get('common_name'),
'not_before': cert_data.get('not_before'), 'not_before': cert_data.get('not_before'),
'not_after': cert_data.get('not_after'), 'not_after': cert_data.get('not_after'),
'entry_timestamp': cert_data.get('entry_timestamp'), 'entry_timestamp': cert_data.get('entry_timestamp'),
'source': 'crt.sh' 'source': 'crt.sh (DB)'
} }
try: try:
if metadata['not_before'] and metadata['not_after']: if metadata['not_before'] and metadata['not_after']:
not_before = self._parse_certificate_date(metadata['not_before']) not_before = self._parse_certificate_date(metadata['not_before'])
not_after = self._parse_certificate_date(metadata['not_after']) not_after = self._parse_certificate_date(metadata['not_after'])
metadata['validity_period_days'] = (not_after - not_before).days metadata['validity_period_days'] = (not_after - not_before).days
metadata['is_currently_valid'] = self._is_cert_valid(cert_data) metadata['is_currently_valid'] = self._is_cert_valid(cert_data)
metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30 metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30
# Add human-readable dates
metadata['not_before'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC') metadata['not_before'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC')
metadata['not_after'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC') metadata['not_after'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC')
except Exception as e: except Exception as e:
self.logger.logger.debug(f"Error computing certificate metadata: {e}") self.logger.logger.debug(f"Error computing certificate metadata: {e}")
metadata['is_currently_valid'] = False metadata['is_currently_valid'] = False
metadata['expires_soon'] = False metadata['expires_soon'] = False
return metadata return metadata
def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
"""
Query crt.sh for certificates containing the domain with caching support.
Properly raises exceptions for network errors to allow core logic retries.

Returns (source, target, relationship_type, confidence, raw_data)
tuples from _process_certificates_to_relationships, or [] when the
domain is invalid, a stop was requested, or no certificates were found.
"""
if not _is_valid_domain(domain):
return []
# Check for cancellation before starting
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before start for domain: {domain}")
return []
# === CACHING LOGIC ===
cache_file = self._get_cache_file_path(domain)
cache_status = self._get_cache_status(cache_file)
certificates = []
try:
if cache_status == "fresh":
# Use cached data
certificates = self._load_cached_certificates(cache_file)
self.logger.logger.info(f"Using cached data for {domain} ({len(certificates)} certificates)")
elif cache_status == "not_found":
# Fresh query, create new cache
certificates = self._query_crtsh_api(domain)
if certificates: # Only cache if we got results
self._create_cache_file(cache_file, domain, certificates)
self.logger.logger.info(f"Cached fresh data for {domain} ({len(certificates)} certificates)")
else:
self.logger.logger.info(f"No certificates found for {domain}, not caching")
elif cache_status == "stale":
# Append query, update existing cache
try:
new_certificates = self._query_crtsh_api(domain)
if new_certificates:
certificates = self._append_to_cache(cache_file, new_certificates)
self.logger.logger.info(f"Refreshed and appended cache for {domain}")
else:
# Use existing cache if API returns no results
certificates = self._load_cached_certificates(cache_file)
self.logger.logger.info(f"API returned no new results, using existing cache for {domain}")
except requests.exceptions.RequestException:
# If the API call fails for a stale cache, fall back to cached data when available.
certificates = self._load_cached_certificates(cache_file)
if certificates:
self.logger.logger.warning(f"API call failed for {domain}, using stale cache data ({len(certificates)} certificates)")
# Don't re-raise here, just use cached data
else:
# No cached data and API failed - re-raise for retry
raise
except requests.exceptions.RequestException as e:
# Network/API errors should be re-raised so core logic can retry
self.logger.logger.error(f"API query failed for {domain}: {e}")
raise e
except json.JSONDecodeError as e:
# JSON parsing errors should also be raised for retry
self.logger.logger.error(f"Failed to parse JSON response from crt.sh for {domain}: {e}")
raise e
# Check for cancellation after cache operations
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled after cache operations for domain: {domain}")
return []
if not certificates:
return []
return self._process_certificates_to_relationships(domain, certificates)
def _process_certificates_to_relationships(self, domain: str, certificates: List[Dict[str, Any]]) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: def _process_certificates_to_relationships(self, domain: str, certificates: List[Dict[str, Any]]) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
""" # This method works as-is.
Process certificates to relationships using existing logic.
This method contains the original processing logic from query_domain.
"""
relationships = [] relationships = []
if self._stop_event and self._stop_event.is_set(): return []
# Check for cancellation before processing
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh processing cancelled before processing for domain: {domain}")
return []
# Aggregate certificate data by domain
domain_certificates = {} domain_certificates = {}
all_discovered_domains = set() all_discovered_domains = set()
# Process certificates with cancellation checking
for i, cert_data in enumerate(certificates): for i, cert_data in enumerate(certificates):
# Check for cancellation every 5 certificates for faster response if i % 5 == 0 and self._stop_event and self._stop_event.is_set(): break
if i % 5 == 0 and self._stop_event and self._stop_event.is_set():
print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}")
break
cert_metadata = self._extract_certificate_metadata(cert_data) cert_metadata = self._extract_certificate_metadata(cert_data)
cert_domains = self._extract_domains_from_certificate(cert_data) cert_domains = self._extract_domains_from_certificate(cert_data)
# Add all domains from this certificate to our tracking
all_discovered_domains.update(cert_domains) all_discovered_domains.update(cert_domains)
for cert_domain in cert_domains: for cert_domain in cert_domains:
if not _is_valid_domain(cert_domain): if not _is_valid_domain(cert_domain): continue
continue
# Initialize domain certificate list if needed
if cert_domain not in domain_certificates: if cert_domain not in domain_certificates:
domain_certificates[cert_domain] = [] domain_certificates[cert_domain] = []
# Add this certificate to the domain's certificate list
domain_certificates[cert_domain].append(cert_metadata) domain_certificates[cert_domain].append(cert_metadata)
if self._stop_event and self._stop_event.is_set(): return []
# Final cancellation check before creating relationships
if self._stop_event and self._stop_event.is_set():
print(f"CrtSh query cancelled before relationship creation for domain: {domain}")
return []
# Create relationships from query domain to ALL discovered domains with stop checking
for i, discovered_domain in enumerate(all_discovered_domains): for i, discovered_domain in enumerate(all_discovered_domains):
if discovered_domain == domain: if discovered_domain == domain: continue
continue # Skip self-relationships if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): break
if not _is_valid_domain(discovered_domain): continue
# Check for cancellation every 10 relationships
if i % 10 == 0 and self._stop_event and self._stop_event.is_set():
print(f"CrtSh relationship creation cancelled for domain: {domain}")
break
if not _is_valid_domain(discovered_domain):
continue
# Get certificates for both domains
query_domain_certs = domain_certificates.get(domain, []) query_domain_certs = domain_certificates.get(domain, [])
discovered_domain_certs = domain_certificates.get(discovered_domain, []) discovered_domain_certs = domain_certificates.get(discovered_domain, [])
# Find shared certificates (for metadata purposes)
shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs) shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs)
# Calculate confidence based on relationship type and shared certificates
confidence = self._calculate_domain_relationship_confidence( confidence = self._calculate_domain_relationship_confidence(
domain, discovered_domain, shared_certificates, all_discovered_domains domain, discovered_domain, shared_certificates, all_discovered_domains
) )
# Create comprehensive raw data for the relationship
relationship_raw_data = { relationship_raw_data = {
'relationship_type': 'certificate_discovery', 'relationship_type': 'certificate_discovery',
'shared_certificates': shared_certificates, 'shared_certificates': shared_certificates,
@ -478,267 +432,82 @@ class CrtShProvider(BaseProvider):
discovered_domain: self._summarize_certificates(discovered_domain_certs) discovered_domain: self._summarize_certificates(discovered_domain_certs)
} }
} }
# Create domain -> domain relationship
relationships.append(( relationships.append((
domain, domain, discovered_domain, 'san_certificate', confidence, relationship_raw_data
discovered_domain,
'san_certificate',
confidence,
relationship_raw_data
)) ))
# Log the relationship discovery
self.log_relationship_discovery( self.log_relationship_discovery(
source_node=domain, source_node=domain, target_node=discovered_domain, relationship_type='san_certificate',
target_node=discovered_domain, confidence_score=confidence, raw_data=relationship_raw_data,
relationship_type='san_certificate',
confidence_score=confidence,
raw_data=relationship_raw_data,
discovery_method="certificate_transparency_analysis" discovery_method="certificate_transparency_analysis"
) )
return relationships return relationships
# --- All remaining helper methods are identical to the original and fully compatible ---
# They are included here for completeness.
def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Find certificates that are shared between two domain certificate lists.
Args:
certs1: First domain's certificates
certs2: Second domain's certificates
Returns:
List of shared certificate metadata
"""
shared = []
# Create a set of certificate IDs from the first list for quick lookup
cert1_ids = {cert.get('certificate_id') for cert in certs1 if cert.get('certificate_id')} cert1_ids = {cert.get('certificate_id') for cert in certs1 if cert.get('certificate_id')}
return [cert for cert in certs2 if cert.get('certificate_id') in cert1_ids]
# Find certificates in the second list that match
for cert in certs2:
if cert.get('certificate_id') in cert1_ids:
shared.append(cert)
return shared
def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]: def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]:
""" if not certificates: return {'total_certificates': 0, 'valid_certificates': 0, 'expired_certificates': 0, 'expires_soon_count': 0, 'unique_issuers': [], 'latest_certificate': None, 'has_valid_cert': False}
Create a summary of certificates for a domain.
Args:
certificates: List of certificate metadata
Returns:
Summary dictionary with aggregate statistics
"""
if not certificates:
return {
'total_certificates': 0,
'valid_certificates': 0,
'expired_certificates': 0,
'expires_soon_count': 0,
'unique_issuers': [],
'latest_certificate': None,
'has_valid_cert': False
}
valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid')) valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid'))
expired_count = len(certificates) - valid_count
expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon')) expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon'))
# Get unique issuers (using parsed organization names)
unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name'))) unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name')))
latest_cert, latest_date = None, None
# Find the most recent certificate
latest_cert = None
latest_date = None
for cert in certificates: for cert in certificates:
try: try:
if cert.get('not_before'): if cert.get('not_before'):
cert_date = self._parse_certificate_date(cert['not_before']) cert_date = self._parse_certificate_date(cert['not_before'])
if latest_date is None or cert_date > latest_date: if latest_date is None or cert_date > latest_date:
latest_date = cert_date latest_date, latest_cert = cert_date, cert
latest_cert = cert except Exception: continue
except Exception: return {'total_certificates': len(certificates), 'valid_certificates': valid_count, 'expired_certificates': len(certificates) - valid_count, 'expires_soon_count': expires_soon_count, 'unique_issuers': unique_issuers, 'latest_certificate': latest_cert, 'has_valid_cert': valid_count > 0, 'certificate_details': certificates}
continue
def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str,
                                              shared_certificates: List[Dict[str, Any]],
                                              all_discovered_domains: Set[str]) -> float:
    """
    Calculate a confidence score for the relationship between two domains.

    Starts from a high base confidence (the domains co-occurred in the same
    certificate query) and adds small bonuses for subdomain/parent context,
    shared certificates, current validity, and well-known issuers.

    Args:
        domain1: Source (query) domain.
        domain2: Target (discovered) domain.
        shared_certificates: Certificate metadata shared by both domains.
        all_discovered_domains: All domains found in this query run
            (not referenced in the current scoring; kept for interface
            compatibility).

    Returns:
        Confidence score clamped to the range [0.1, 1.0].
    """
    base_confidence = 0.9
    context_bonus = shared_bonus = validity_bonus = issuer_bonus = 0.0

    # Subdomains are the strongest signal; parent domains a weaker one.
    relationship_context = self._determine_relationship_context(domain2, domain1)
    if relationship_context == 'subdomain':
        context_bonus = 0.1
    elif relationship_context == 'parent_domain':
        context_bonus = 0.05

    if shared_certificates:
        # More shared certificates -> stronger relationship.
        if len(shared_certificates) >= 3:
            shared_bonus = 0.1
        elif len(shared_certificates) >= 2:
            shared_bonus = 0.05
        else:
            shared_bonus = 0.02

        # At least one currently valid shared certificate adds confidence.
        if any(cert.get('is_currently_valid') for cert in shared_certificates):
            validity_bonus = 0.05

        # A single match against a well-known CA is enough for the bonus.
        trusted_cas = ('let\'s encrypt', 'digicert', 'sectigo', 'globalsign')
        for cert in shared_certificates:
            if any(ca in cert.get('issuer_name', '').lower() for ca in trusted_cas):
                issuer_bonus = max(issuer_bonus, 0.03)
                break

    total = base_confidence + context_bonus + shared_bonus + validity_bonus + issuer_bonus
    return max(0.1, min(1.0, total))
def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str: def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str:
""" if cert_domain == query_domain: return 'exact_match'
Determine the context of the relationship between certificate domain and query domain. if cert_domain.endswith(f'.{query_domain}'): return 'subdomain'
if query_domain.endswith(f'.{cert_domain}'): return 'parent_domain'
Args:
cert_domain: Domain found in certificate
query_domain: Original query domain
Returns:
String describing the relationship context
"""
if cert_domain == query_domain:
return 'exact_match'
elif cert_domain.endswith(f'.{query_domain}'):
return 'subdomain'
elif query_domain.endswith(f'.{cert_domain}'):
return 'parent_domain'
else:
return 'related_domain' return 'related_domain'
def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
    """
    Query crt.sh for certificates containing the IP address.

    crt.sh does not effectively index certificates by IP address, so this
    provider intentionally returns no IP-based relationships.

    Args:
        ip: IP address to investigate.

    Returns:
        Empty list (IP-based certificate queries are not supported).
    """
    return []
def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]: def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]:
"""
Extract all domains from certificate data.
Args:
cert_data: Certificate data from crt.sh API
Returns:
Set of unique domain names found in the certificate
"""
domains = set() domains = set()
if cn := cert_data.get('common_name'):
# Extract from common name if cleaned := self._clean_domain_name(cn):
common_name = cert_data.get('common_name', '') domains.update(cleaned)
if common_name: if nv := cert_data.get('name_value'):
cleaned_cn = self._clean_domain_name(common_name) for line in nv.split('\n'):
if cleaned_cn: if cleaned := self._clean_domain_name(line.strip()):
domains.update(cleaned_cn) domains.update(cleaned)
# Extract from name_value field (contains SANs)
name_value = cert_data.get('name_value', '')
if name_value:
# Split by newlines and clean each domain
for line in name_value.split('\n'):
cleaned_domains = self._clean_domain_name(line.strip())
if cleaned_domains:
domains.update(cleaned_domains)
return domains return domains
def _clean_domain_name(self, domain_name: str) -> List[str]:
    """
    Clean and normalize a domain name taken from certificate data.

    Strips scheme, path and port, expands wildcard entries
    ('*.example.com' yields both the wildcard and its base domain),
    removes stray characters, and validates each candidate.

    Args:
        domain_name: Raw domain string from a certificate field.

    Returns:
        List of valid, normalized domain names (possibly empty).
    """
    if not domain_name:
        return []

    # Drop a leading scheme ('https://...') and any path component.
    domain = domain_name.strip().lower().split('://', 1)[-1].split('/', 1)[0]

    # Drop a trailing port, but leave multi-colon strings (IPv6) untouched.
    if domain.count(':') == 1:
        domain = domain.split(':', 1)[0]

    # A wildcard entry also implies its base domain.
    candidates = [domain, domain[2:]] if domain.startswith('*.') else [domain]

    final_domains = []
    for candidate in candidates:
        candidate = re.sub(r'[^\w\-\.]', '', candidate)
        if candidate and not candidate.startswith(('.', '-')) and not candidate.endswith(('.', '-')):
            final_domains.append(candidate)

    return [d for d in final_domains if _is_valid_domain(d)]

View File

@ -8,3 +8,4 @@ dnspython>=2.4.2
gunicorn gunicorn
redis redis
python-dotenv python-dotenv
psycopg2-binary