# --- Process-wide PostgreSQL connection pool for crt.sh ---
# The pool is a module global (created once per worker process) instead of an
# attribute of CrtShProvider, which keeps provider instances picklable.
#
# NOTE(review): the pool is built while this module is imported; with a
# minimum size of 1 it presumably connects right away, so an unreachable
# crt.sh could delay import by up to `connect_timeout` seconds -- confirm
# this is acceptable for worker start-up.
# NOTE(review): psycopg2 documents SimpleConnectionPool as not shareable
# across threads -- confirm each worker process queries from one thread.
_CRTSH_DB_SETTINGS = {
    'host': 'crt.sh',
    'port': 5432,
    'user': 'guest',
    'dbname': 'certwatch',
    'sslmode': 'prefer',
    'connect_timeout': 60,
}

# Stays None when the pool cannot be created; callers then use the HTTP API.
db_pool = None
try:
    db_pool = pool.SimpleConnectionPool(1, 5, **_CRTSH_DB_SETTINGS)
    # No provider instance exists at import time, so log through the shared
    # forensic logger.
    get_forensic_logger().logger.info(
        "crt.sh: Global PostgreSQL connection pool created successfully."
    )
except Exception as exc:
    # Deliberate best effort: any failure here only disables the DB path.
    get_forensic_logger().logger.warning(
        f"crt.sh: Failed to create global DB connection pool: {exc}. Will fall back to HTTP API."
    )
def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
    """Fetch raw certificate records for ``domain`` from crt.sh.

    The PostgreSQL interface is tried first when the module-level pool
    exists; the HTTP API is used when there is no pool or when the database
    query raises ``psycopg2.Error``.

    Args:
        domain: Domain name to look up.

    Returns:
        The list of raw certificate dictionaries produced by whichever
        backend answered.

    Raises:
        Whatever ``_query_crtsh_api`` raises; database errors are logged
        and trigger the HTTP fallback instead of propagating.
    """
    log = self.logger.logger

    # No pool means the DB path is unavailable; go straight to HTTP.
    if not db_pool:
        log.info(f"crt.sh: No DB connection pool. Using HTTP API for {domain}")
        return self._query_crtsh_api(domain)

    try:
        log.info(f"crt.sh: Attempting DB query for {domain}")
        return self._query_crtsh_db(domain)
    except psycopg2.Error as db_error:
        log.warning(f"crt.sh: DB query failed for {domain}: {db_error}. Falling back to HTTP API.")
        return self._query_crtsh_api(domain)
def _query_crtsh_db(self, domain: str) -> List[Dict[str, Any]]:
    """Query the crt.sh PostgreSQL database for raw certificate data.

    A connection is borrowed from the module-level ``db_pool`` and always
    handed back, whether or not the query succeeds. Every returned value is
    made JSON-friendly so the rows can be cached and processed like the
    records coming from the HTTP API.

    Args:
        domain: Domain name passed to the full-text search as a bound
            parameter (never interpolated into the SQL text).

    Returns:
        One dictionary per result row, keyed by column name. ``datetime``
        values are ISO-8601 strings and binary values are lowercase hex
        strings; everything else is returned as delivered by the driver.

    Raises:
        psycopg2.Error: If borrowing the connection or running the query
            fails (the caller uses this to fall back to the HTTP API).
    """

    def _to_json_safe(value: Any) -> Any:
        """Convert driver-specific values to plain JSON-compatible ones."""
        if isinstance(value, datetime):
            return value.isoformat()
        # psycopg2 hands bytea columns back as memoryview objects; left
        # as-is they would be written to the cache as "<memory at 0x...>".
        if isinstance(value, (bytes, bytearray, memoryview)):
            return bytes(value).hex()
        return value

    # NOTE(review): the HTTP API presumably returns `name_value` as
    # newline-separated names, while `identities(...)::text` looks like a
    # tsvector rendering -- confirm downstream parsing handles it.
    # NOTE(review): verify the default text search configuration matches
    # the rows crt.sh itself finds (its generated SQL appears to pass an
    # explicit 'certwatch' configuration to plainto_tsquery).
    query = """
        SELECT
            c.id,
            x509_serialnumber(c.certificate) as serial_number,
            x509_notbefore(c.certificate) as not_before,
            x509_notafter(c.certificate) as not_after,
            c.issuer_ca_id,
            ca.name as issuer_name,
            x509_commonname(c.certificate) as common_name,
            identities(c.certificate)::text as name_value
        FROM certificate c
        LEFT JOIN ca ON c.issuer_ca_id = ca.id
        WHERE identities(c.certificate) @@ plainto_tsquery(%s)
        ORDER BY c.id DESC
        LIMIT 5000;
    """

    # Borrow outside the try block: if this fails there is nothing to return.
    conn = db_pool.getconn()
    try:
        with conn.cursor() as cursor:
            cursor.execute(query, (domain,))
            columns = [desc[0] for desc in cursor.description]
            results = [
                {name: _to_json_safe(value) for name, value in zip(columns, row)}
                for row in cursor.fetchall()
            ]
    finally:
        # Hand the connection back even when the query raised.
        db_pool.putconn(conn)

    self.logger.logger.info(f"crt.sh: DB query for {domain} returned {len(results)} records.")
    return results
\ No newline at end of file