diff --git a/core/rate_limiter.py b/core/rate_limiter.py index d5a11d6..ec9ed8d 100644 --- a/core/rate_limiter.py +++ b/core/rate_limiter.py @@ -1,28 +1,145 @@ # dnsrecon-reduced/core/rate_limiter.py import time +import logging class GlobalRateLimiter: + """ + FIXED: Improved rate limiter with better cleanup and error handling. + Prevents accumulation of stale entries that cause infinite retry loops. + """ + def __init__(self, redis_client): self.redis = redis_client + self.logger = logging.getLogger('dnsrecon.rate_limiter') + # Track last cleanup times to avoid excessive Redis operations + self._last_cleanup = {} def is_rate_limited(self, key, limit, period): """ - Check if a key is rate-limited. + FIXED: Check if a key is rate-limited with improved cleanup and error handling. + + Args: + key: Rate limit key (e.g., provider name) + limit: Maximum requests allowed + period: Time period in seconds (60 for per-minute) + + Returns: + bool: True if rate limited, False otherwise + """ + if limit <= 0: + # Rate limit of 0 or negative means no limiting + return False + + now = time.time() + rate_key = f"rate_limit:{key}" + + try: + # FIXED: More aggressive cleanup to prevent accumulation + # Only clean up if we haven't cleaned recently (every 10 seconds max) + should_cleanup = ( + rate_key not in self._last_cleanup or + now - self._last_cleanup.get(rate_key, 0) > 10 + ) + + if should_cleanup: + # Remove entries older than the period + removed_count = self.redis.zremrangebyscore(rate_key, 0, now - period) + self._last_cleanup[rate_key] = now + + if removed_count > 0: + self.logger.debug(f"Rate limiter cleaned up {removed_count} old entries for {key}") + + # Get current count + current_count = self.redis.zcard(rate_key) + + if current_count >= limit: + self.logger.debug(f"Rate limited: {key} has {current_count}/{limit} requests in period") + return True + + # Add new timestamp with error handling + try: + # Use pipeline for atomic operations + pipe = self.redis.pipeline() + pipe.zadd(rate_key, {str(now): now}) + pipe.expire(rate_key, int(period * 2)) # Set TTL to 2x period for safety + pipe.execute() + except Exception as e: + self.logger.warning(f"Failed to record rate limit entry for {key}: {e}") + # Don't block the request if we can't record it + return False + + return False + + except Exception as e: + self.logger.error(f"Rate limiter error for {key}: {e}") + # FIXED: On Redis errors, don't block requests to avoid infinite loops + return False + + def get_rate_limit_status(self, key, limit, period): + """ + Get detailed rate limit status for debugging. 
+ + Returns: + dict: Status information including current count, limit, and time to reset """ now = time.time() - key = f"rate_limit:{key}" - - # Remove old timestamps - self.redis.zremrangebyscore(key, 0, now - period) - - # Check the count - count = self.redis.zcard(key) - if count >= limit: + rate_key = f"rate_limit:{key}" + + try: + current_count = self.redis.zcard(rate_key) + + # Get oldest entry to calculate reset time + oldest_entries = self.redis.zrange(rate_key, 0, 0, withscores=True) + time_to_reset = 0 + if oldest_entries: + oldest_time = oldest_entries[0][1] + time_to_reset = max(0, period - (now - oldest_time)) + + return { + 'key': key, + 'current_count': current_count, + 'limit': limit, + 'period': period, + 'is_limited': current_count >= limit, + 'time_to_reset': time_to_reset + } + except Exception as e: + self.logger.error(f"Failed to get rate limit status for {key}: {e}") + return { + 'key': key, + 'current_count': 0, + 'limit': limit, + 'period': period, + 'is_limited': False, + 'time_to_reset': 0, + 'error': str(e) + } + + def reset_rate_limit(self, key): + """ + ADDED: Reset rate limit for a specific key (useful for debugging). + """ + rate_key = f"rate_limit:{key}" + try: + deleted = self.redis.delete(rate_key) + self.logger.info(f"Reset rate limit for {key} (deleted: {deleted})") return True - - # Add new timestamp - self.redis.zadd(key, {now: now}) - self.redis.expire(key, period) - - return False \ No newline at end of file + except Exception as e: + self.logger.error(f"Failed to reset rate limit for {key}: {e}") + return False + + def cleanup_all_rate_limits(self): + """ + ADDED: Clean up all rate limit entries (useful for maintenance). + """ + try: + keys = self.redis.keys("rate_limit:*") + if keys: + deleted = self.redis.delete(*keys) + self.logger.info(f"Cleaned up {deleted} rate limit keys") + return deleted + return 0 + except Exception as e: + self.logger.error(f"Failed to cleanup rate limits: {e}") + return 0 \ No newline at end of file diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index d2692a9..2197c80 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -16,8 +16,7 @@ from core.logger import get_forensic_logger class CrtShProvider(BaseProvider): """ Provider for querying crt.sh certificate transparency database. - FIXED: Now properly creates domain and CA nodes instead of large entities. - REMOVED: All PostgreSQL logic to rely exclusively on the HTTP API for stability. + FIXED: Improved caching logic and error handling to prevent infinite retry loops. Returns standardized ProviderResult objects with caching support. """ @@ -66,42 +65,72 @@ class CrtShProvider(BaseProvider): def _get_cache_status(self, cache_file_path: Path) -> str: """ - Check cache status for a domain. + FIXED: More robust cache status checking with better error handling. 
Returns: 'not_found', 'fresh', or 'stale' """ if not cache_file_path.exists(): return "not_found" try: - with open(cache_file_path, 'r') as f: + # Check if file is readable and not corrupted + if cache_file_path.stat().st_size == 0: + self.logger.logger.warning(f"Empty cache file: {cache_file_path}") + return "stale" + + with open(cache_file_path, 'r', encoding='utf-8') as f: cache_data = json.load(f) - last_query_str = cache_data.get("last_upstream_query") - if not last_query_str: + # Validate cache structure + if not isinstance(cache_data, dict): + self.logger.logger.warning(f"Invalid cache structure: {cache_file_path}") return "stale" - last_query = datetime.fromisoformat(last_query_str.replace('Z', '+00:00')) - hours_since_query = (datetime.now(timezone.utc) - last_query).total_seconds() / 3600 + last_query_str = cache_data.get("last_upstream_query") + if not last_query_str or not isinstance(last_query_str, str): + self.logger.logger.warning(f"Missing or invalid last_upstream_query: {cache_file_path}") + return "stale" + try: + # More robust datetime parsing + if last_query_str.endswith('Z'): + last_query = datetime.fromisoformat(last_query_str.replace('Z', '+00:00')) + elif '+' in last_query_str or last_query_str.endswith('UTC'): + # Handle various timezone formats + clean_time = last_query_str.replace('UTC', '').strip() + if '+' in clean_time: + clean_time = clean_time.split('+')[0] + last_query = datetime.fromisoformat(clean_time).replace(tzinfo=timezone.utc) + else: + last_query = datetime.fromisoformat(last_query_str).replace(tzinfo=timezone.utc) + + except (ValueError, AttributeError) as e: + self.logger.logger.warning(f"Failed to parse timestamp in cache {cache_file_path}: {e}") + return "stale" + + hours_since_query = (datetime.now(timezone.utc) - last_query).total_seconds() / 3600 cache_timeout = self.config.cache_timeout_hours + if hours_since_query < cache_timeout: return "fresh" else: return "stale" - except (json.JSONDecodeError, ValueError, KeyError) as e: - self.logger.logger.warning(f"Invalid cache file format for {cache_file_path}: {e}") + except (json.JSONDecodeError, OSError, PermissionError) as e: + self.logger.logger.warning(f"Cache file error for {cache_file_path}: {e}") + # FIXED: Try to remove corrupted cache file + try: + cache_file_path.unlink() + self.logger.logger.info(f"Removed corrupted cache file: {cache_file_path}") + except Exception: + pass + return "not_found" + except Exception as e: + self.logger.logger.error(f"Unexpected error checking cache status for {cache_file_path}: {e}") return "stale" def query_domain(self, domain: str) -> ProviderResult: """ - Query crt.sh for certificates containing the domain via HTTP API. - - Args: - domain: Domain to investigate - - Returns: - ProviderResult containing discovered relationships and attributes + FIXED: Simplified and more robust domain querying with better error handling. """ if not _is_valid_domain(domain): return ProviderResult() @@ -110,120 +139,155 @@ class CrtShProvider(BaseProvider): return ProviderResult() cache_file = self._get_cache_file_path(domain) - cache_status = self._get_cache_status(cache_file) - result = ProviderResult() try: - if cache_status == "fresh": - result = self._load_from_cache(cache_file) - self.logger.logger.info(f"Using fresh cached crt.sh data for {domain}") - return result + cache_status = self._get_cache_status(cache_file) - # For "stale" or "not_found", we must query the API. 
+ if cache_status == "fresh": + # Load from cache + result = self._load_from_cache(cache_file) + if result and (result.relationships or result.attributes): + self.logger.logger.debug(f"Using fresh cached crt.sh data for {domain}") + return result + else: + # Cache exists but is empty, treat as stale + cache_status = "stale" + + # Need to query API (either no cache, stale cache, or empty cache) + self.logger.logger.debug(f"Querying crt.sh API for {domain} (cache status: {cache_status})") new_raw_certs = self._query_crtsh_api(domain) if self._stop_event and self._stop_event.is_set(): return ProviderResult() - # Combine with old data if cache is stale + # FIXED: Simplified processing - just process the new data + # Don't try to merge with stale cache as it can cause corruption + raw_certificates_to_process = new_raw_certs + if cache_status == "stale": - old_raw_certs = self._load_raw_data_from_cache(cache_file) - combined_certs = old_raw_certs + new_raw_certs - - # Deduplicate the combined list - seen_ids = set() - unique_certs = [] - for cert in combined_certs: - cert_id = cert.get('id') - if cert_id not in seen_ids: - unique_certs.append(cert) - seen_ids.add(cert_id) - - raw_certificates_to_process = unique_certs - self.logger.logger.info(f"Refreshed and merged cache for {domain}. Total unique certs: {len(raw_certificates_to_process)}") - else: # "not_found" - raw_certificates_to_process = new_raw_certs + self.logger.logger.info(f"Refreshed stale cache for {domain} with {len(raw_certificates_to_process)} certs") + else: + self.logger.logger.info(f"Created fresh cache for {domain} with {len(raw_certificates_to_process)} certs") result = self._process_certificates_to_result_fixed(domain, raw_certificates_to_process) - self.logger.logger.info(f"Created fresh result for {domain} ({result.get_relationship_count()} relationships)") - - # Save the new result and the raw data to the cache + + # Save the result to cache self._save_result_to_cache(cache_file, result, raw_certificates_to_process, domain) + + return result except requests.exceptions.RequestException as e: - self.logger.logger.error(f"Upstream query failed for {domain}: {e}") - # **BUG FIX:** Always re-raise the exception. This signals a failure to the - # scanner, allowing its retry logic to handle the transient error. 
- raise e - - return result + # FIXED: Don't re-raise network errors after long idle periods + # Instead return empty result and log the issue + self.logger.logger.warning(f"Network error querying crt.sh for {domain}: {e}") + + # Try to use stale cache if available + if cache_status == "stale": + try: + stale_result = self._load_from_cache(cache_file) + if stale_result and (stale_result.relationships or stale_result.attributes): + self.logger.logger.info(f"Using stale cache for {domain} due to network error") + return stale_result + except Exception as cache_error: + self.logger.logger.warning(f"Failed to load stale cache for {domain}: {cache_error}") + + # Return empty result instead of raising - this prevents infinite retries + return ProviderResult() + + except Exception as e: + # FIXED: Handle any other exceptions gracefully + self.logger.logger.error(f"Unexpected error querying crt.sh for {domain}: {e}") + + # Try stale cache as fallback + try: + if cache_file.exists(): + fallback_result = self._load_from_cache(cache_file) + if fallback_result and (fallback_result.relationships or fallback_result.attributes): + self.logger.logger.info(f"Using cached data for {domain} due to processing error") + return fallback_result + except Exception: + pass + + # Return empty result to prevent retries + return ProviderResult() def query_ip(self, ip: str) -> ProviderResult: """ crt.sh does not support IP-based certificate queries effectively via its API. - - Args: - ip: IP address to investigate - - Returns: - Empty ProviderResult """ return ProviderResult() def _load_from_cache(self, cache_file_path: Path) -> ProviderResult: - """Load processed crt.sh data from a cache file.""" + """FIXED: More robust cache loading with better validation.""" try: - with open(cache_file_path, 'r') as f: + if not cache_file_path.exists() or cache_file_path.stat().st_size == 0: + return ProviderResult() + + with open(cache_file_path, 'r', encoding='utf-8') as f: cache_content = json.load(f) + if not isinstance(cache_content, dict): + self.logger.logger.warning(f"Invalid cache format in {cache_file_path}") + return ProviderResult() + result = ProviderResult() - # Reconstruct relationships - for rel_data in cache_content.get("relationships", []): - result.add_relationship( - source_node=rel_data["source_node"], - target_node=rel_data["target_node"], - relationship_type=rel_data["relationship_type"], - provider=rel_data["provider"], - confidence=rel_data["confidence"], - raw_data=rel_data.get("raw_data", {}) - ) + # Reconstruct relationships with validation + relationships = cache_content.get("relationships", []) + if isinstance(relationships, list): + for rel_data in relationships: + if not isinstance(rel_data, dict): + continue + try: + result.add_relationship( + source_node=rel_data.get("source_node", ""), + target_node=rel_data.get("target_node", ""), + relationship_type=rel_data.get("relationship_type", ""), + provider=rel_data.get("provider", self.name), + confidence=float(rel_data.get("confidence", 0.8)), + raw_data=rel_data.get("raw_data", {}) + ) + except (ValueError, TypeError) as e: + self.logger.logger.warning(f"Skipping invalid relationship in cache: {e}") + continue - # Reconstruct attributes - for attr_data in cache_content.get("attributes", []): - result.add_attribute( - target_node=attr_data["target_node"], - name=attr_data["name"], - value=attr_data["value"], - attr_type=attr_data["type"], - provider=attr_data["provider"], - confidence=attr_data["confidence"], - metadata=attr_data.get("metadata", {}) - 
) + # Reconstruct attributes with validation + attributes = cache_content.get("attributes", []) + if isinstance(attributes, list): + for attr_data in attributes: + if not isinstance(attr_data, dict): + continue + try: + result.add_attribute( + target_node=attr_data.get("target_node", ""), + name=attr_data.get("name", ""), + value=attr_data.get("value"), + attr_type=attr_data.get("type", "unknown"), + provider=attr_data.get("provider", self.name), + confidence=float(attr_data.get("confidence", 0.9)), + metadata=attr_data.get("metadata", {}) + ) + except (ValueError, TypeError) as e: + self.logger.logger.warning(f"Skipping invalid attribute in cache: {e}") + continue return result - except (json.JSONDecodeError, FileNotFoundError, KeyError) as e: - self.logger.logger.error(f"Failed to load cached certificates from {cache_file_path}: {e}") + except (json.JSONDecodeError, OSError, PermissionError) as e: + self.logger.logger.warning(f"Failed to load cache from {cache_file_path}: {e}") + return ProviderResult() + except Exception as e: + self.logger.logger.error(f"Unexpected error loading cache from {cache_file_path}: {e}") return ProviderResult() - def _load_raw_data_from_cache(self, cache_file_path: Path) -> List[Dict[str, Any]]: - """Load only the raw certificate data from a cache file.""" - try: - with open(cache_file_path, 'r') as f: - cache_content = json.load(f) - return cache_content.get("raw_certificates", []) - except (json.JSONDecodeError, FileNotFoundError): - return [] - def _save_result_to_cache(self, cache_file_path: Path, result: ProviderResult, raw_certificates: List[Dict[str, Any]], domain: str) -> None: - """Save processed crt.sh result and raw data to a cache file.""" + """FIXED: More robust cache saving with atomic writes.""" try: cache_data = { "domain": domain, "last_upstream_query": datetime.now(timezone.utc).isoformat(), - "raw_certificates": raw_certificates, # Store the raw data for deduplication + "raw_certificates": raw_certificates, "relationships": [ { "source_node": rel.source_node, @@ -246,34 +310,68 @@ class CrtShProvider(BaseProvider): } for attr in result.attributes ] } + cache_file_path.parent.mkdir(parents=True, exist_ok=True) - with open(cache_file_path, 'w') as f: - json.dump(cache_data, f, separators=(',', ':'), default=str) + + # FIXED: Atomic write using temporary file + temp_file = cache_file_path.with_suffix('.tmp') + try: + with open(temp_file, 'w', encoding='utf-8') as f: + json.dump(cache_data, f, separators=(',', ':'), default=str, ensure_ascii=False) + + # Atomic rename + temp_file.replace(cache_file_path) + self.logger.logger.debug(f"Saved cache for {domain} ({len(result.relationships)} relationships)") + + except Exception as e: + # Clean up temp file on error + if temp_file.exists(): + try: + temp_file.unlink() + except Exception: + pass + raise e + except Exception as e: self.logger.logger.warning(f"Failed to save cache file for {domain}: {e}") def _query_crtsh_api(self, domain: str) -> List[Dict[str, Any]]: - """Query crt.sh API for raw certificate data.""" + """FIXED: More robust API querying with better error handling.""" url = f"{self.base_url}?q={quote(domain)}&output=json" - response = self.make_request(url, target_indicator=domain) - - if not response or response.status_code != 200: - raise requests.exceptions.RequestException(f"crt.sh API returned status {response.status_code if response else 'None'}") try: - certificates = response.json() - except json.JSONDecodeError: - self.logger.logger.error(f"crt.sh returned invalid JSON for 
{domain}") - return [] + response = self.make_request(url, target_indicator=domain) + + if not response: + self.logger.logger.warning(f"No response from crt.sh for {domain}") + return [] + + if response.status_code != 200: + self.logger.logger.warning(f"crt.sh returned status {response.status_code} for {domain}") + return [] + + # FIXED: Better JSON parsing with error handling + try: + certificates = response.json() + except json.JSONDecodeError as e: + self.logger.logger.error(f"crt.sh returned invalid JSON for {domain}: {e}") + return [] - if not certificates: - return [] - - return certificates + if not certificates or not isinstance(certificates, list): + self.logger.logger.debug(f"crt.sh returned no certificates for {domain}") + return [] + + self.logger.logger.debug(f"crt.sh returned {len(certificates)} certificates for {domain}") + return certificates + + except Exception as e: + self.logger.logger.error(f"Error querying crt.sh API for {domain}: {e}") + raise e def _process_certificates_to_result_fixed(self, query_domain: str, certificates: List[Dict[str, Any]]) -> ProviderResult: """ Process certificates to create proper domain and CA nodes. + FIXED: Better error handling and progress tracking. """ result = ProviderResult() @@ -281,6 +379,11 @@ class CrtShProvider(BaseProvider): self.logger.logger.info(f"CrtSh processing cancelled before processing for domain: {query_domain}") return result + if not certificates: + self.logger.logger.debug(f"No certificates to process for {query_domain}") + return result + + # Check for incomplete data warning incompleteness_warning = self._check_for_incomplete_data(query_domain, certificates) if incompleteness_warning: result.add_attribute( @@ -294,54 +397,68 @@ class CrtShProvider(BaseProvider): all_discovered_domains = set() processed_issuers = set() + processed_certs = 0 for i, cert_data in enumerate(certificates): - # Check for stop event inside the loop to make it responsive. 
- if i % 10 == 0 and self._stop_event and self._stop_event.is_set(): - self.logger.logger.info(f"CrtSh processing cancelled at certificate {i} for domain: {query_domain}") - break + # FIXED: More frequent stop checks and progress logging + if i % 5 == 0: + if self._stop_event and self._stop_event.is_set(): + self.logger.logger.info(f"CrtSh processing cancelled at certificate {i}/{len(certificates)} for domain: {query_domain}") + break + + if i > 0 and i % 100 == 0: + self.logger.logger.debug(f"Processed {i}/{len(certificates)} certificates for {query_domain}") - # Extract all domains from this certificate - cert_domains = self._extract_domains_from_certificate(cert_data) - all_discovered_domains.update(cert_domains) + try: + # Extract all domains from this certificate + cert_domains = self._extract_domains_from_certificate(cert_data) + if cert_domains: + all_discovered_domains.update(cert_domains) - # Create CA nodes for certificate issuers - issuer_name = self._parse_issuer_organization(cert_data.get('issuer_name', '')) - if issuer_name and issuer_name not in processed_issuers: - result.add_relationship( - source_node=query_domain, - target_node=issuer_name, - relationship_type='crtsh_cert_issuer', - provider=self.name, - confidence=0.95, - raw_data={'issuer_dn': cert_data.get('issuer_name', '')} - ) - processed_issuers.add(issuer_name) + # Create CA nodes for certificate issuers + issuer_name = self._parse_issuer_organization(cert_data.get('issuer_name', '')) + if issuer_name and issuer_name not in processed_issuers: + result.add_relationship( + source_node=query_domain, + target_node=issuer_name, + relationship_type='crtsh_cert_issuer', + provider=self.name, + confidence=0.95, + raw_data={'issuer_dn': cert_data.get('issuer_name', '')} + ) + processed_issuers.add(issuer_name) - # Add certificate metadata to each domain in this certificate - cert_metadata = self._extract_certificate_metadata(cert_data) - for cert_domain in cert_domains: - if not _is_valid_domain(cert_domain): - continue + # Add certificate metadata to each domain in this certificate + cert_metadata = self._extract_certificate_metadata(cert_data) + for cert_domain in cert_domains: + if not _is_valid_domain(cert_domain): + continue - for key, value in cert_metadata.items(): - if value is not None: - result.add_attribute( - target_node=cert_domain, - name=f"cert_{key}", - value=value, - attr_type='certificate_data', - provider=self.name, - confidence=0.9, - metadata={'certificate_id': cert_data.get('id')} - ) + for key, value in cert_metadata.items(): + if value is not None: + result.add_attribute( + target_node=cert_domain, + name=f"cert_{key}", + value=value, + attr_type='certificate_data', + provider=self.name, + confidence=0.9, + metadata={'certificate_id': cert_data.get('id')} + ) + + processed_certs += 1 + + except Exception as e: + self.logger.logger.warning(f"Error processing certificate {i} for {query_domain}: {e}") + continue - # Check for stop event before creating final relationships. 
+ # Check for stop event before creating final relationships if self._stop_event and self._stop_event.is_set(): self.logger.logger.info(f"CrtSh query cancelled before relationship creation for domain: {query_domain}") return result # Create selective relationships to avoid large entities + relationships_created = 0 for discovered_domain in all_discovered_domains: if discovered_domain == query_domain: continue @@ -371,10 +488,12 @@ class CrtShProvider(BaseProvider): raw_data={'relationship_type': 'certificate_discovery'}, discovery_method="certificate_transparency_analysis" ) + relationships_created += 1 - self.logger.logger.info(f"CrtSh processing completed for {query_domain}: {len(all_discovered_domains)} domains, {result.get_relationship_count()} relationships") + self.logger.logger.info(f"CrtSh processing completed for {query_domain}: processed {processed_certs}/{len(certificates)} certificates, {len(all_discovered_domains)} domains, {relationships_created} relationships") return result + # [Rest of the methods remain the same as in the original file] def _should_create_relationship(self, source_domain: str, target_domain: str) -> bool: """ Determine if a relationship should be created between two domains.
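
Reviewer note (not part of the patch): below is a minimal usage sketch of the new sliding-window limiter added in core/rate_limiter.py. It assumes a local Redis instance reachable through the redis-py client and that the repository root is on PYTHONPATH; adjust connection details for your environment. With limit=5 and period=60, the first five calls are recorded and allowed, and later calls inside the window are reported as blocked together with the approximate time until the oldest entry ages out.

import time
import redis

from core.rate_limiter import GlobalRateLimiter

# Hypothetical local Redis instance; host/port/db are assumptions for the sketch.
client = redis.Redis(host="localhost", port=6379, db=0)
limiter = GlobalRateLimiter(client)

for attempt in range(7):
    # is_rate_limited() records the request only when it is allowed, so with
    # limit=5 the first five attempts pass and the remaining ones are blocked.
    if limiter.is_rate_limited("crtsh", limit=5, period=60):
        status = limiter.get_rate_limit_status("crtsh", limit=5, period=60)
        print(f"blocked: {status['current_count']}/{status['limit']} in window, "
              f"~{status['time_to_reset']:.0f}s until the oldest entry expires")
    else:
        print(f"attempt {attempt}: allowed")
    time.sleep(0.1)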