try to implement websockets

This commit is contained in:
overcuriousity
2025-09-20 14:17:17 +02:00
parent 3ee23c9d05
commit 75a595c9cb
11 changed files with 116 additions and 251 deletions

View File

@@ -53,7 +53,7 @@ class BaseProvider(ABC):
def __getstate__(self):
"""Prepare BaseProvider for pickling by excluding unpicklable objects."""
state = self.__dict__.copy()
# Exclude the unpickleable '_local' attribute and stop event
# Exclude the unpickleable '_local' attribute (which holds the session) and stop event
unpicklable_attrs = ['_local', '_stop_event']
for attr in unpicklable_attrs:
if attr in state:

View File

@@ -26,6 +26,7 @@ class CorrelationProvider(BaseProvider):
'cert_common_name',
'cert_validity_period_days',
'cert_issuer_name',
'cert_serial_number',
'cert_entry_timestamp',
'cert_not_before',
'cert_not_after',

View File

@@ -2,38 +2,17 @@
import json
import re
import psycopg2
from pathlib import Path
from typing import List, Dict, Any, Set, Optional
from urllib.parse import quote
from datetime import datetime, timezone
import requests
from psycopg2 import pool
from .base_provider import BaseProvider
from core.provider_result import ProviderResult
from utils.helpers import _is_valid_domain
from core.logger import get_forensic_logger
# --- Global Instance for PostgreSQL Connection Pool ---
# This pool will be created once per worker process and is not part of the
# CrtShProvider instance, thus avoiding pickling errors.
db_pool = None
try:
db_pool = psycopg2.pool.SimpleConnectionPool(
1, 5,
host='crt.sh',
port=5432,
user='guest',
dbname='certwatch',
sslmode='prefer',
connect_timeout=60
)
# Use a generic logger here as this is at the module level
get_forensic_logger().logger.info("crt.sh: Global PostgreSQL connection pool created successfully.")
except Exception as e:
get_forensic_logger().logger.warning(f"crt.sh: Failed to create global DB connection pool: {e}. Will fall back to HTTP API.")
class CrtShProvider(BaseProvider):
"""
@@ -136,51 +115,42 @@ class CrtShProvider(BaseProvider):
result = ProviderResult()
try:
if cache_status == "fresh":
result = self._load_from_cache(cache_file)
self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
if cache_status == "fresh":
result = self._load_from_cache(cache_file)
self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
else: # "stale" or "not_found"
# Query the API for the latest certificates
new_raw_certs = self._query_crtsh_api(domain)
else: # "stale" or "not_found"
# Query the API for the latest certificates
new_raw_certs = self._query_crtsh(domain)
if self._stop_event and self._stop_event.is_set():
return ProviderResult()
# Combine with old data if cache is stale
if cache_status == "stale":
old_raw_certs = self._load_raw_data_from_cache(cache_file)
combined_certs = old_raw_certs + new_raw_certs
if self._stop_event and self._stop_event.is_set():
return ProviderResult()
# Combine with old data if cache is stale
if cache_status == "stale":
old_raw_certs = self._load_raw_data_from_cache(cache_file)
combined_certs = old_raw_certs + new_raw_certs
# Deduplicate the combined list
seen_ids = set()
unique_certs = []
for cert in combined_certs:
cert_id = cert.get('id')
if cert_id not in seen_ids:
unique_certs.append(cert)
seen_ids.add(cert_id)
raw_certificates_to_process = unique_certs
self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
else: # "not_found"
raw_certificates_to_process = new_raw_certs
# Deduplicate the combined list
seen_ids = set()
unique_certs = []
for cert in combined_certs:
cert_id = cert.get('id')
if cert_id not in seen_ids:
unique_certs.append(cert)
seen_ids.add(cert_id)
# FIXED: Process certificates to create proper domain and CA nodes
result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
raw_certificates_to_process = unique_certs
self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
else: # "not_found"
raw_certificates_to_process = new_raw_certs
# FIXED: Process certificates to create proper domain and CA nodes
result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
# Save the new result and the raw data to the cache
self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)
except (requests.exceptions.RequestException, psycopg2.Error) as e:
self.logger.logger.error(f"Upstream query failed for {domain}: {e}")
if cache_status != "not_found":
result = self._load_from_cache(cache_file)
self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.")
else:
raise e # Re-raise if there's no cache to fall back on
# Save the new result and the raw data to the cache
self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)
return result
@@ -277,58 +247,6 @@ class CrtShProvider(BaseProvider):
json.dump(cache_data, f, separators=(',', ':'), default=str)
except Exception as e:
self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}")
def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
"""Query crt.sh, trying the database first and falling back to the API."""
global db_pool
if db_pool:
try:
self.logger.logger.info(f"crt.sh: Attempting DB query for {domain}")
return self._query_crtsh_db(domain)
except psycopg2.Error as e:
self.logger.logger.warning(f"crt.sh: DB query failed for {domain}: {e}. Falling back to HTTP API.")
return self._query_crtsh_api(domain)
else:
self.logger.logger.info(f"crt.sh: No DB connection pool. Using HTTP API for {domain}")
return self._query_crtsh_api(domain)
def _query_crtsh_db(self, domain: str) -> List[Dict[str, Any]]:
"""Query crt.sh database for raw certificate data."""
global db_pool
conn = db_pool.getconn()
try:
with conn.cursor() as cursor:
query = """
SELECT
c.id,
x509_serialnumber(c.certificate) as serial_number,
x509_notbefore(c.certificate) as not_before,
x509_notafter(c.certificate) as not_after,
c.issuer_ca_id,
ca.name as issuer_name,
x509_commonname(c.certificate) as common_name,
identities(c.certificate)::text as name_value
FROM certificate c
LEFT JOIN ca ON c.issuer_ca_id = ca.id
WHERE identities(c.certificate) @@ plainto_tsquery(%s)
ORDER BY c.id DESC
LIMIT 5000;
"""
cursor.execute(query, (domain,))
results = []
columns = [desc[0] for desc in cursor.description]
for row in cursor.fetchall():
row_dict = dict(zip(columns, row))
if row_dict.get('not_before'):
row_dict['not_before'] = row_dict['not_before'].isoformat()
if row_dict.get('not_after'):
row_dict['not_after'] = row_dict['not_after'].isoformat()
results.append(row_dict)
self.logger.logger.info(f"crt.sh: DB query for {domain} returned {len(results)} records.")
return results
finally:
db_pool.putconn(conn)
def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]:
"""Query crt.sh API for raw certificate data."""

View File

@@ -27,6 +27,21 @@ class DNSProvider(BaseProvider):
self.resolver.timeout = 5
self.resolver.lifetime = 10
def __getstate__(self):
"""Prepare the object for pickling."""
state = self.__dict__.copy()
# Remove the unpickleable 'resolver' attribute
if 'resolver' in state:
del state['resolver']
return state
def __setstate__(self, state):
"""Restore the object after unpickling."""
self.__dict__.update(state)
# Re-initialize the 'resolver' attribute
self.resolver = resolver.Resolver()
self.resolver.timeout = 5
def get_name(self) -> str:
"""Return the provider name."""
return "dns"

View File

@@ -36,6 +36,15 @@ class ShodanProvider(BaseProvider):
self.cache_dir = Path('cache') / 'shodan'
self.cache_dir.mkdir(parents=True, exist_ok=True)
def __getstate__(self):
"""Prepare the object for pickling."""
state = super().__getstate__()
return state
def __setstate__(self, state):
"""Restore the object after unpickling."""
super().__setstate__(state)
def _check_api_connection(self) -> bool:
"""
FIXED: Lazy connection checking - only test when actually needed.