# dnsrecon/providers/crtsh_provider.py
import json
import re
import os
from pathlib import Path
from typing import List, Dict, Any, Tuple, Set
from datetime import datetime, timezone

# New dependency required for this provider
try:
    import psycopg2
    import psycopg2.extras
    PSYCOPG2_AVAILABLE = True
except ImportError:
    PSYCOPG2_AVAILABLE = False

from .base_provider import BaseProvider
from utils.helpers import _is_valid_domain

# We use requests only to raise the same exception type for compatibility with core retry logic
import requests


class CrtShProvider(BaseProvider):
    """
    Provider for querying the crt.sh certificate transparency database via its
    public PostgreSQL endpoint.

    This version is designed to be a drop-in, high-performance replacement for
    the API-based provider. It preserves the same caching and data processing
    logic, and wraps database errors in requests.exceptions.RequestException so
    the core application's existing retry logic keeps working unchanged.
    """

    def __init__(self, name=None, session_config=None):
        """Initialize CrtShDB provider with session-specific configuration.

        Args:
            name: Unused; accepted for interface compatibility with other providers.
            session_config: Session-specific configuration forwarded to BaseProvider.
        """
        super().__init__(
            name="crtsh",
            rate_limit=0,   # No rate limit for direct DB access
            timeout=60,     # Increased timeout for potentially long DB queries
            session_config=session_config
        )
        # Database connection details for the public crt.sh PostgreSQL mirror.
        self.db_host = "crt.sh"
        self.db_port = 5432
        self.db_name = "certwatch"
        self.db_user = "guest"
        self._stop_event = None
        # Initialize cache directory (same as original provider)
        self.cache_dir = Path('cache') / 'crtsh'
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_name(self) -> str:
        """Return the provider name."""
        return "crtsh"

    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        return "crt.sh (DB)"

    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        return False

    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        return {'domains': True, 'ips': False}

    def is_available(self) -> bool:
        """
        Check if the provider can be used. Requires the psycopg2 library.
        """
        if not PSYCOPG2_AVAILABLE:
            self.logger.logger.warning("psycopg2 library not found. CrtShDBProvider is unavailable. "
                                       "Please run 'pip install psycopg2-binary'.")
            return False
        return True

    def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
        """
        Query the crt.sh PostgreSQL database for raw certificate data.

        Returns:
            A list of dicts, one per certificate, shaped like the JSON API output.

        Raises:
            requests.exceptions.RequestException: wraps any psycopg2 error so the
                existing core retry logic (which catches RequestException) applies.
        """
        conn = None
        certificates = []
        # SQL Query to find all certificate IDs related to the domain (including
        # subdomains), then retrieve comprehensive details for each certificate,
        # mimicking the JSON API structure.
        sql_query = """
            WITH certificates_of_interest AS (
                SELECT DISTINCT ci.certificate_id
                FROM certificate_identity ci
                WHERE ci.name_value ILIKE %(domain_wildcard)s
                   OR ci.name_value = %(domain)s
            )
            SELECT
                c.id,
                c.serial_number,
                c.not_before,
                c.not_after,
                (SELECT min(entry_timestamp) FROM ct_log_entry cle WHERE cle.certificate_id = c.id) as entry_timestamp,
                ca.id as issuer_ca_id,
                ca.name as issuer_name,
                (SELECT array_to_string(array_agg(DISTINCT ci.name_value), E'\n') FROM certificate_identity ci WHERE ci.certificate_id = c.id) as name_value,
                (SELECT name_value FROM certificate_identity ci WHERE ci.certificate_id = c.id AND ci.name_type = 'commonName' LIMIT 1) as common_name
            FROM certificate c
            JOIN ca ON c.issuer_ca_id = ca.id
            WHERE c.id IN (SELECT certificate_id FROM certificates_of_interest);
        """
        try:
            conn = psycopg2.connect(
                dbname=self.db_name,
                user=self.db_user,
                host=self.db_host,
                port=self.db_port,
                connect_timeout=self.timeout
            )
            # RealDictCursor yields dict-like rows keyed by column name.
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                cursor.execute(sql_query, {'domain': domain, 'domain_wildcard': f'%.{domain}'})
                results = cursor.fetchall()
                certificates = [dict(row) for row in results]
            self.logger.logger.info(f"crt.sh DB query for '{domain}' returned {len(certificates)} certificates.")
        except psycopg2.Error as e:
            self.logger.logger.error(f"PostgreSQL query failed for {domain}: {e}")
            # Raise a RequestException to be compatible with the existing retry
            # logic in the core application.
            raise requests.exceptions.RequestException(f"PostgreSQL query failed: {e}") from e
        finally:
            if conn:
                conn.close()
        return certificates

    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query crt.sh for certificates containing the domain with caching support.

        Cache states:
            "fresh"     -> serve from cache, no DB query.
            "not_found" -> query DB; cache the result only if non-empty.
            "stale"     -> query DB and merge into cache; on DB failure fall back
                           to the stale cached data when any exists.

        Raises:
            requests.exceptions.RequestException: on DB failure with no usable
                cache, so core logic can retry.
            json.JSONDecodeError: on unparseable cached JSON.
        """
        if not _is_valid_domain(domain):
            return []
        if self._stop_event and self._stop_event.is_set():
            return []

        cache_file = self._get_cache_file_path(domain)
        cache_status = self._get_cache_status(cache_file)
        certificates = []

        try:
            if cache_status == "fresh":
                certificates = self._load_cached_certificates(cache_file)
                self.logger.logger.info(f"Using cached data for {domain} ({len(certificates)} certificates)")
            elif cache_status == "not_found":
                # Fresh query from DB, create new cache
                certificates = self._query_crtsh(domain)
                if certificates:
                    self._create_cache_file(cache_file, domain, self._serialize_certs_for_cache(certificates))
                else:
                    self.logger.logger.info(f"No certificates found for {domain}, not caching")
            elif cache_status == "stale":
                try:
                    new_certificates = self._query_crtsh(domain)
                    if new_certificates:
                        certificates = self._append_to_cache(cache_file, self._serialize_certs_for_cache(new_certificates))
                    else:
                        certificates = self._load_cached_certificates(cache_file)
                except requests.exceptions.RequestException:
                    # Best-effort fallback: serve stale data rather than fail.
                    certificates = self._load_cached_certificates(cache_file)
                    if certificates:
                        self.logger.logger.warning(f"DB query failed for {domain}, using stale cache data.")
                    else:
                        raise
        except requests.exceptions.RequestException as e:
            # Re-raise so core logic can retry. Bare `raise` preserves the
            # original traceback (fix: was `raise e`, which appends this frame).
            self.logger.logger.error(f"DB query failed for {domain}: {e}")
            raise
        except json.JSONDecodeError as e:
            # JSON parsing errors from cache should also be surfaced.
            self.logger.logger.error(f"Failed to parse JSON from cache for {domain}: {e}")
            raise

        if self._stop_event and self._stop_event.is_set():
            return []
        if not certificates:
            return []
        return self._process_certificates_to_relationships(domain, certificates)

    def _serialize_certs_for_cache(self, certificates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Serialize certificate data for JSON caching, converting datetime objects
        (as returned by the DB driver) to timezone-aware ISO-8601 strings.
        """
        serialized_certs = []
        for cert in certificates:
            serialized_cert = cert.copy()
            for key in ['not_before', 'not_after', 'entry_timestamp']:
                if isinstance(serialized_cert.get(key), datetime):
                    # Ensure datetime is timezone-aware before converting;
                    # naive values are assumed to be UTC -- TODO confirm against DB schema.
                    dt_obj = serialized_cert[key]
                    if dt_obj.tzinfo is None:
                        dt_obj = dt_obj.replace(tzinfo=timezone.utc)
                    serialized_cert[key] = dt_obj.isoformat()
            serialized_certs.append(serialized_cert)
        return serialized_certs

    # --- All methods below are copied directly from the original CrtShProvider ---
    # They are compatible because _query_crtsh returns data in the same format
    # as the original _query_crtsh_api method. A small adjustment is made to
    # _parse_certificate_date to handle datetime objects directly from the DB.
def _get_cache_file_path(self, domain: str) -> Path: """Generate cache file path for a domain.""" safe_domain = domain.replace('.', '_').replace('/', '_').replace('\\', '_') return self.cache_dir / f"{safe_domain}.json" def _get_cache_status(self, cache_file_path: Path) -> str: """Check cache status for a domain.""" if not cache_file_path.exists(): return "not_found" try: with open(cache_file_path, 'r') as f: cache_data = json.load(f) last_query_str = cache_data.get("last_upstream_query") if not last_query_str: return "stale" last_query = datetime.fromisoformat(last_query_str.replace('Z', '+00:00')) hours_since_query = (datetime.now(timezone.utc) - last_query).total_seconds() / 3600 cache_timeout = self.config.cache_timeout_hours if hours_since_query < cache_timeout: return "fresh" else: return "stale" except (json.JSONDecodeError, ValueError, KeyError) as e: self.logger.logger.warning(f"Invalid cache file format for {cache_file_path}: {e}") return "stale" def _load_cached_certificates(self, cache_file_path: Path) -> List[Dict[str, Any]]: """Load certificates from cache file.""" try: with open(cache_file_path, 'r') as f: cache_data = json.load(f) return cache_data.get('certificates', []) except (json.JSONDecodeError, FileNotFoundError, KeyError) as e: self.logger.logger.error(f"Failed to load cached certificates from {cache_file_path}: {e}") return [] def _create_cache_file(self, cache_file_path: Path, domain: str, certificates: List[Dict[str, Any]]) -> None: """Create new cache file with certificates.""" try: cache_data = { "domain": domain, "first_cached": datetime.now(timezone.utc).isoformat(), "last_upstream_query": datetime.now(timezone.utc).isoformat(), "upstream_query_count": 1, "certificates": certificates } cache_file_path.parent.mkdir(parents=True, exist_ok=True) with open(cache_file_path, 'w') as f: json.dump(cache_data, f, separators=(',', ':')) self.logger.logger.info(f"Created cache file for {domain} with {len(certificates)} certificates") except 
Exception as e: self.logger.logger.warning(f"Failed to create cache file for {domain}: {e}") def _append_to_cache(self, cache_file_path: Path, new_certificates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Append new certificates to existing cache and return all certificates.""" try: with open(cache_file_path, 'r') as f: cache_data = json.load(f) existing_ids = {cert.get('id') for cert in cache_data.get('certificates', [])} added_count = 0 for cert in new_certificates: cert_id = cert.get('id') if cert_id and cert_id not in existing_ids: cache_data['certificates'].append(cert) existing_ids.add(cert_id) added_count += 1 cache_data['last_upstream_query'] = datetime.now(timezone.utc).isoformat() cache_data['upstream_query_count'] = cache_data.get('upstream_query_count', 0) + 1 with open(cache_file_path, 'w') as f: json.dump(cache_data, f, separators=(',', ':')) total_certs = len(cache_data['certificates']) self.logger.logger.info(f"Appended {added_count} new certificates to cache. Total: {total_certs}") return cache_data['certificates'] except Exception as e: self.logger.logger.warning(f"Failed to append to cache: {e}") return new_certificates def _parse_issuer_organization(self, issuer_dn: str) -> str: """Parse the issuer Distinguished Name to extract just the organization name.""" if not issuer_dn: return issuer_dn try: components = [comp.strip() for comp in issuer_dn.split(',')] for component in components: if component.startswith('O='): org_name = component[2:].strip() if org_name.startswith('"') and org_name.endswith('"'): org_name = org_name[1:-1] return org_name return issuer_dn except Exception as e: self.logger.logger.debug(f"Failed to parse issuer DN '{issuer_dn}': {e}") return issuer_dn def _parse_certificate_date(self, date_input: Any) -> datetime: """ Parse certificate date from various formats (string from cache, datetime from DB). 
""" if isinstance(date_input, datetime): # If it's already a datetime object from the DB, just ensure it's UTC if date_input.tzinfo is None: return date_input.replace(tzinfo=timezone.utc) return date_input date_string = str(date_input) if not date_string: raise ValueError("Empty date string") try: if 'Z' in date_string: return datetime.fromisoformat(date_string.replace('Z', '+00:00')) # Handle standard ISO format with or without timezone dt = datetime.fromisoformat(date_string) if dt.tzinfo is None: return dt.replace(tzinfo=timezone.utc) return dt except ValueError as e: try: # Fallback for other formats return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc) except Exception: raise ValueError(f"Unable to parse date: {date_string}") from e def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool: """Check if a certificate is currently valid based on its expiry date.""" try: not_after_str = cert_data.get('not_after') if not not_after_str: return False not_after_date = self._parse_certificate_date(not_after_str) not_before_str = cert_data.get('not_before') now = datetime.now(timezone.utc) is_not_expired = not_after_date > now if not_before_str: not_before_date = self._parse_certificate_date(not_before_str) is_not_before_valid = not_before_date <= now return is_not_expired and is_not_before_valid return is_not_expired except Exception as e: self.logger.logger.debug(f"Certificate validity check failed: {e}") return False def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]: # This method works as-is. 
raw_issuer_name = cert_data.get('issuer_name', '') parsed_issuer_name = self._parse_issuer_organization(raw_issuer_name) metadata = { 'certificate_id': cert_data.get('id'), 'serial_number': cert_data.get('serial_number'), 'issuer_name': parsed_issuer_name, 'issuer_ca_id': cert_data.get('issuer_ca_id'), 'common_name': cert_data.get('common_name'), 'not_before': cert_data.get('not_before'), 'not_after': cert_data.get('not_after'), 'entry_timestamp': cert_data.get('entry_timestamp'), 'source': 'crt.sh (DB)' } try: if metadata['not_before'] and metadata['not_after']: not_before = self._parse_certificate_date(metadata['not_before']) not_after = self._parse_certificate_date(metadata['not_after']) metadata['validity_period_days'] = (not_after - not_before).days metadata['is_currently_valid'] = self._is_cert_valid(cert_data) metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30 metadata['not_before'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC') metadata['not_after'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC') except Exception as e: self.logger.logger.debug(f"Error computing certificate metadata: {e}") metadata['is_currently_valid'] = False metadata['expires_soon'] = False return metadata def _process_certificates_to_relationships(self, domain: str, certificates: List[Dict[str, Any]]) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: # This method works as-is. 
relationships = [] if self._stop_event and self._stop_event.is_set(): return [] domain_certificates = {} all_discovered_domains = set() for i, cert_data in enumerate(certificates): if i % 5 == 0 and self._stop_event and self._stop_event.is_set(): break cert_metadata = self._extract_certificate_metadata(cert_data) cert_domains = self._extract_domains_from_certificate(cert_data) all_discovered_domains.update(cert_domains) for cert_domain in cert_domains: if not _is_valid_domain(cert_domain): continue if cert_domain not in domain_certificates: domain_certificates[cert_domain] = [] domain_certificates[cert_domain].append(cert_metadata) if self._stop_event and self._stop_event.is_set(): return [] for i, discovered_domain in enumerate(all_discovered_domains): if discovered_domain == domain: continue if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): break if not _is_valid_domain(discovered_domain): continue query_domain_certs = domain_certificates.get(domain, []) discovered_domain_certs = domain_certificates.get(discovered_domain, []) shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs) confidence = self._calculate_domain_relationship_confidence( domain, discovered_domain, shared_certificates, all_discovered_domains ) relationship_raw_data = { 'relationship_type': 'certificate_discovery', 'shared_certificates': shared_certificates, 'total_shared_certs': len(shared_certificates), 'discovery_context': self._determine_relationship_context(discovered_domain, domain), 'domain_certificates': { domain: self._summarize_certificates(query_domain_certs), discovered_domain: self._summarize_certificates(discovered_domain_certs) } } relationships.append(( domain, discovered_domain, 'san_certificate', confidence, relationship_raw_data )) self.log_relationship_discovery( source_node=domain, target_node=discovered_domain, relationship_type='san_certificate', confidence_score=confidence, raw_data=relationship_raw_data, 
discovery_method="certificate_transparency_analysis" ) return relationships # --- All remaining helper methods are identical to the original and fully compatible --- # They are included here for completeness. def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]: cert1_ids = {cert.get('certificate_id') for cert in certs1 if cert.get('certificate_id')} return [cert for cert in certs2 if cert.get('certificate_id') in cert1_ids] def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]: if not certificates: return {'total_certificates': 0, 'valid_certificates': 0, 'expired_certificates': 0, 'expires_soon_count': 0, 'unique_issuers': [], 'latest_certificate': None, 'has_valid_cert': False} valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid')) expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon')) unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name'))) latest_cert, latest_date = None, None for cert in certificates: try: if cert.get('not_before'): cert_date = self._parse_certificate_date(cert['not_before']) if latest_date is None or cert_date > latest_date: latest_date, latest_cert = cert_date, cert except Exception: continue return {'total_certificates': len(certificates), 'valid_certificates': valid_count, 'expired_certificates': len(certificates) - valid_count, 'expires_soon_count': expires_soon_count, 'unique_issuers': unique_issuers, 'latest_certificate': latest_cert, 'has_valid_cert': valid_count > 0, 'certificate_details': certificates} def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str, shared_certificates: List[Dict[str, Any]], all_discovered_domains: Set[str]) -> float: base_confidence, context_bonus, shared_bonus, validity_bonus, issuer_bonus = 0.9, 0.0, 0.0, 0.0, 0.0 relationship_context = self._determine_relationship_context(domain2, 
domain1) if relationship_context == 'subdomain': context_bonus = 0.1 elif relationship_context == 'parent_domain': context_bonus = 0.05 if shared_certificates: if len(shared_certificates) >= 3: shared_bonus = 0.1 elif len(shared_certificates) >= 2: shared_bonus = 0.05 else: shared_bonus = 0.02 if any(cert.get('is_currently_valid') for cert in shared_certificates): validity_bonus = 0.05 for cert in shared_certificates: if any(ca in cert.get('issuer_name', '').lower() for ca in ['let\'s encrypt', 'digicert', 'sectigo', 'globalsign']): issuer_bonus = max(issuer_bonus, 0.03) break return max(0.1, min(1.0, base_confidence + context_bonus + shared_bonus + validity_bonus + issuer_bonus)) def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str: if cert_domain == query_domain: return 'exact_match' if cert_domain.endswith(f'.{query_domain}'): return 'subdomain' if query_domain.endswith(f'.{cert_domain}'): return 'parent_domain' return 'related_domain' def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]: return [] def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]: domains = set() if cn := cert_data.get('common_name'): if cleaned := self._clean_domain_name(cn): domains.update(cleaned) if nv := cert_data.get('name_value'): for line in nv.split('\n'): if cleaned := self._clean_domain_name(line.strip()): domains.update(cleaned) return domains def _clean_domain_name(self, domain_name: str) -> List[str]: if not domain_name: return [] domain = domain_name.strip().lower().split('://', 1)[-1].split('/', 1)[0] if ':' in domain and not domain.count(':') > 1: domain = domain.split(':', 1)[0] cleaned_domains = [domain, domain[2:]] if domain.startswith('*.') else [domain] final_domains = [] for d in cleaned_domains: d = re.sub(r'[^\w\-\.]', '', d) if d and not d.startswith(('.', '-')) and not d.endswith(('.', '-')): final_domains.append(d) return [d for d in final_domains if _is_valid_domain(d)]