try to implement websockets
This commit is contained in:
@@ -53,7 +53,7 @@ class BaseProvider(ABC):
|
||||
def __getstate__(self):
|
||||
"""Prepare BaseProvider for pickling by excluding unpicklable objects."""
|
||||
state = self.__dict__.copy()
|
||||
# Exclude the unpickleable '_local' attribute and stop event
|
||||
# Exclude the unpickleable '_local' attribute (which holds the session) and stop event
|
||||
unpicklable_attrs = ['_local', '_stop_event']
|
||||
for attr in unpicklable_attrs:
|
||||
if attr in state:
|
||||
|
||||
@@ -26,6 +26,7 @@ class CorrelationProvider(BaseProvider):
|
||||
'cert_common_name',
|
||||
'cert_validity_period_days',
|
||||
'cert_issuer_name',
|
||||
'cert_serial_number',
|
||||
'cert_entry_timestamp',
|
||||
'cert_not_before',
|
||||
'cert_not_after',
|
||||
|
||||
@@ -2,38 +2,17 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
import psycopg2
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Set, Optional
|
||||
from urllib.parse import quote
|
||||
from datetime import datetime, timezone
|
||||
import requests
|
||||
from psycopg2 import pool
|
||||
|
||||
from .base_provider import BaseProvider
|
||||
from core.provider_result import ProviderResult
|
||||
from utils.helpers import _is_valid_domain
|
||||
from core.logger import get_forensic_logger
|
||||
|
||||
# --- Process-wide PostgreSQL connection pool for crt.sh ---
# The pool lives at module level rather than on a CrtShProvider instance, so
# it is never part of a provider's pickled state. It is built once when this
# module is imported (i.e. once per worker process).
db_pool = None
try:
    db_pool = psycopg2.pool.SimpleConnectionPool(
        1, 5,  # minimum / maximum number of pooled connections
        host='crt.sh', port=5432,
        user='guest', dbname='certwatch',
        sslmode='prefer', connect_timeout=60,
    )
    # No provider instance exists yet at import time, so use the shared logger.
    get_forensic_logger().logger.info("crt.sh: Global PostgreSQL connection pool created successfully.")
except Exception as pool_error:
    # Best effort: with db_pool left as None, callers use the HTTP API instead.
    get_forensic_logger().logger.warning(f"crt.sh: Failed to create global DB connection pool: {pool_error}. Will fall back to HTTP API.")
|
||||
|
||||
|
||||
class CrtShProvider(BaseProvider):
|
||||
"""
|
||||
@@ -136,51 +115,42 @@ class CrtShProvider(BaseProvider):
|
||||
|
||||
result = ProviderResult()
|
||||
|
||||
try:
|
||||
if cache_status == "fresh":
|
||||
result = self._load_from_cache(cache_file)
|
||||
self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
|
||||
if cache_status == "fresh":
|
||||
result = self._load_from_cache(cache_file)
|
||||
self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}")
|
||||
|
||||
else: # "stale" or "not_found"
|
||||
# Query the API for the latest certificates
|
||||
new_raw_certs = self._query_crtsh_api(domain)
|
||||
|
||||
else: # "stale" or "not_found"
|
||||
# Query the API for the latest certificates
|
||||
new_raw_certs = self._query_crtsh(domain)
|
||||
if self._stop_event and self._stop_event.is_set():
|
||||
return ProviderResult()
|
||||
|
||||
# Combine with old data if cache is stale
|
||||
if cache_status == "stale":
|
||||
old_raw_certs = self._load_raw_data_from_cache(cache_file)
|
||||
combined_certs = old_raw_certs + new_raw_certs
|
||||
|
||||
if self._stop_event and self._stop_event.is_set():
|
||||
return ProviderResult()
|
||||
|
||||
# Combine with old data if cache is stale
|
||||
if cache_status == "stale":
|
||||
old_raw_certs = self._load_raw_data_from_cache(cache_file)
|
||||
combined_certs = old_raw_certs + new_raw_certs
|
||||
|
||||
# Deduplicate the combined list
|
||||
seen_ids = set()
|
||||
unique_certs = []
|
||||
for cert in combined_certs:
|
||||
cert_id = cert.get('id')
|
||||
if cert_id not in seen_ids:
|
||||
unique_certs.append(cert)
|
||||
seen_ids.add(cert_id)
|
||||
|
||||
raw_certificates_to_process = unique_certs
|
||||
self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
|
||||
else: # "not_found"
|
||||
raw_certificates_to_process = new_raw_certs
|
||||
# Deduplicate the combined list
|
||||
seen_ids = set()
|
||||
unique_certs = []
|
||||
for cert in combined_certs:
|
||||
cert_id = cert.get('id')
|
||||
if cert_id not in seen_ids:
|
||||
unique_certs.append(cert)
|
||||
seen_ids.add(cert_id)
|
||||
|
||||
# FIXED: Process certificates to create proper domain and CA nodes
|
||||
result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
|
||||
self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
|
||||
raw_certificates_to_process = unique_certs
|
||||
self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}")
|
||||
else: # "not_found"
|
||||
raw_certificates_to_process = new_raw_certs
|
||||
|
||||
# FIXED: Process certificates to create proper domain and CA nodes
|
||||
result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process)
|
||||
self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)")
|
||||
|
||||
# Save the new result and the raw data to the cache
|
||||
self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)
|
||||
|
||||
except (requests.exceptions.RequestException, psycopg2.Error) as e:
|
||||
self.logger.logger.error(f"Upstream query failed for {domain}: {e}")
|
||||
if cache_status != "not_found":
|
||||
result = self._load_from_cache(cache_file)
|
||||
self.logger.logger.warning(f"Using stale cache for {domain} due to API failure.")
|
||||
else:
|
||||
raise e # Re-raise if there's no cache to fall back on
|
||||
# Save the new result and the raw data to the cache
|
||||
self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain)
|
||||
|
||||
return result
|
||||
|
||||
@@ -277,58 +247,6 @@ class CrtShProvider(BaseProvider):
|
||||
json.dump(cache_data, f, separators=(',', ':'), default=str)
|
||||
except Exception as e:
|
||||
self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}")
|
||||
|
||||
def _query_crtsh(self, domain: str) -> List[Dict[str, Any]]:
    """Fetch raw crt.sh certificate records for ``domain``.

    Uses the module-level PostgreSQL pool when one exists; if the pool is
    missing, or the database query raises ``psycopg2.Error``, the HTTP API
    is queried instead.
    """
    global db_pool

    # No pool was created at import time: go straight to the HTTP API.
    if not db_pool:
        self.logger.logger.info(f"crt.sh: No DB connection pool. Using HTTP API for {domain}")
        return self._query_crtsh_api(domain)

    try:
        self.logger.logger.info(f"crt.sh: Attempting DB query for {domain}")
        return self._query_crtsh_db(domain)
    except psycopg2.Error as db_error:
        # Database failures are not fatal; retry through the HTTP API.
        self.logger.logger.warning(f"crt.sh: DB query failed for {domain}: {db_error}. Falling back to HTTP API.")
        return self._query_crtsh_api(domain)
|
||||
|
||||
def _query_crtsh_db(self, domain: str) -> List[Dict[str, Any]]:
    """Run the certificate lookup for ``domain`` against the crt.sh database.

    Borrows a connection from the module-level pool, executes a parameterized
    query limited to 5000 rows (newest certificate id first) and returns one
    dict per row keyed by column name. Truthy ``not_before`` / ``not_after``
    values are converted with ``.isoformat()``. The connection is always
    handed back to the pool, even when the query raises.
    """
    global db_pool
    connection = db_pool.getconn()
    try:
        with connection.cursor() as cur:
            sql = """
                SELECT
                    c.id,
                    x509_serialnumber(c.certificate) as serial_number,
                    x509_notbefore(c.certificate) as not_before,
                    x509_notafter(c.certificate) as not_after,
                    c.issuer_ca_id,
                    ca.name as issuer_name,
                    x509_commonname(c.certificate) as common_name,
                    identities(c.certificate)::text as name_value
                FROM certificate c
                LEFT JOIN ca ON c.issuer_ca_id = ca.id
                WHERE identities(c.certificate) @@ plainto_tsquery(%s)
                ORDER BY c.id DESC
                LIMIT 5000;
            """
            # The domain is passed as a bound parameter, never interpolated.
            cur.execute(sql, (domain,))

            column_names = [col[0] for col in cur.description]
            records = []
            for raw_row in cur.fetchall():
                record = dict(zip(column_names, raw_row))
                # Convert the validity bounds via isoformat() when present.
                for date_key in ('not_before', 'not_after'):
                    if record.get(date_key):
                        record[date_key] = record[date_key].isoformat()
                records.append(record)
            self.logger.logger.info(f"crt.sh: DB query for {domain} returned {len(records)} records.")
            return records
    finally:
        # Return the connection to the pool on success and on failure alike.
        db_pool.putconn(connection)
|
||||
|
||||
def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]:
|
||||
"""Query crt.sh API for raw certificate data."""
|
||||
|
||||
@@ -27,6 +27,21 @@ class DNSProvider(BaseProvider):
|
||||
self.resolver.timeout = 5
|
||||
self.resolver.lifetime = 10
|
||||
|
||||
def __getstate__(self):
    """Return a picklable snapshot of this provider's attributes.

    The ``resolver`` attribute is left out of the snapshot; the instance
    itself keeps its resolver because only a shallow copy is modified.
    """
    picklable = dict(self.__dict__)
    # Drop the resolver if present; a missing key is not an error.
    picklable.pop('resolver', None)
    return picklable
|
||||
|
||||
def __setstate__(self, state):
    """Restore the object after unpickling.

    Args:
        state: Attribute dict produced by ``__getstate__``; it does not
            contain the ``resolver`` attribute.
    """
    self.__dict__.update(state)
    # The resolver is not part of the pickled state, so build a new one.
    self.resolver = resolver.Resolver()
    self.resolver.timeout = 5
    # Apply the same lifetime used where the resolver is first configured;
    # without it a restored provider would silently fall back to the
    # library's default lifetime and behave differently from a fresh one.
    self.resolver.lifetime = 10
|
||||
|
||||
def get_name(self) -> str:
    """Identify this provider by its short name, ``"dns"``."""
    return "dns"
|
||||
|
||||
@@ -36,6 +36,15 @@ class ShodanProvider(BaseProvider):
|
||||
self.cache_dir = Path('cache') / 'shodan'
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def __getstate__(self):
    """Return the picklable state for this provider.

    Delegates entirely to the parent class; nothing provider-specific is
    added to or removed from the state here.
    """
    return super().__getstate__()
|
||||
|
||||
def __setstate__(self, state):
    """Rebuild this provider from a pickled ``state``.

    Delegates entirely to the parent class; no provider-specific
    attributes are re-created here.
    """
    super().__setstate__(state)
|
||||
|
||||
def _check_api_connection(self) -> bool:
|
||||
"""
|
||||
FIXED: Lazy connection checking - only test when actually needed.
|
||||
|
||||
Reference in New Issue
Block a user