# dnsrecon-reduced/core/rate_limiter.py

import logging
import time
import uuid


class GlobalRateLimiter:
    """Sliding-window rate limiter backed by Redis sorted sets.

    Every rate-limited key is stored in the sorted set ``rate_limit:<key>``;
    each recorded request is one member whose score is the ``time.time()``
    value at which it was recorded.

    The limiter fails open: any exception raised by the Redis client is
    logged and treated as "not limited", so callers never loop forever on a
    broken backend.

    Note:
        The count check and the insert are separate Redis calls, so
        concurrent callers may briefly exceed ``limit``.
    """

    # Prefix of every Redis key owned by this limiter.
    _KEY_PREFIX = "rate_limit:"
    # Minimum number of seconds between two stale-entry purges of one key.
    _CLEANUP_INTERVAL = 10
    # Maximum number of keys passed to a single DEL in cleanup_all_rate_limits().
    _DELETE_BATCH = 500

    def __init__(self, redis_client):
        """Create a limiter that stores its state through ``redis_client``.

        Args:
            redis_client: Client object exposing the redis-py style commands
                used by this class (``zremrangebyscore``, ``zcount``,
                ``zrangebyscore``, ``pipeline``, ``delete``, ``scan_iter``).
        """
        self.redis = redis_client
        self.logger = logging.getLogger('dnsrecon.rate_limiter')
        # rate_key -> time.time() of the last purge; throttles how often
        # ZREMRANGEBYSCORE is issued for a given key.
        self._last_cleanup = {}

    def is_rate_limited(self, key, limit, period):
        """Check whether ``key`` is rate limited and record the request if not.

        Args:
            key: Rate limit key (e.g., provider name).
            limit: Maximum requests allowed per ``period``; zero or a
                negative value disables limiting.
            period: Length of the sliding window in seconds.

        Returns:
            bool: True if the request must be rejected, False otherwise
            (including when Redis raised an error).
        """
        if limit <= 0:
            # Rate limit of 0 or negative means no limiting.
            return False

        now = time.time()
        rate_key = f"{self._KEY_PREFIX}{key}"
        window_start = now - period

        try:
            # Purge expired entries, but at most once per _CLEANUP_INTERVAL
            # seconds per key to avoid a write on every call.
            last_cleanup = self._last_cleanup.get(rate_key)
            if last_cleanup is None or now - last_cleanup > self._CLEANUP_INTERVAL:
                removed_count = self.redis.zremrangebyscore(rate_key, 0, window_start)
                self._last_cleanup[rate_key] = now
                if removed_count:
                    self.logger.debug(
                        "Rate limiter cleaned up %s old entries for %s",
                        removed_count, key,
                    )

            # Count only entries inside the window. Because the purge above is
            # throttled, the set may still hold expired entries; counting them
            # (as a plain ZCARD would) wrongly limits the caller. The "("
            # prefix makes the lower bound exclusive, mirroring the inclusive
            # upper bound used by the purge.
            current_count = self.redis.zcount(rate_key, f"({window_start}", "+inf")
            if current_count >= limit:
                self.logger.debug(
                    "Rate limited: %s has %s/%s requests in period",
                    key, current_count, limit,
                )
                return True

            # A bare timestamp is not a unique member: two requests recorded
            # at the same clock reading would overwrite each other and be
            # counted once. The random suffix keeps every request distinct.
            member = f"{now}:{uuid.uuid4().hex}"
            # EXPIRE with 0 deletes the key at once, so clamp sub-second
            # periods to a TTL of at least one second.
            ttl = max(1, int(period * 2))

            try:
                # Queue both commands in one pipeline so they run together.
                pipe = self.redis.pipeline()
                pipe.zadd(rate_key, {member: now})
                pipe.expire(rate_key, ttl)
                pipe.execute()
            except Exception as e:
                # Don't block the request if we can't record it.
                self.logger.warning(
                    "Failed to record rate limit entry for %s: %s", key, e
                )

            return False

        except Exception as e:
            # Fail open on Redis errors so callers don't retry forever.
            self.logger.error("Rate limiter error for %s: %s", key, e)
            return False

    def get_rate_limit_status(self, key, limit, period):
        """Return a read-only snapshot of the rate limit state for ``key``.

        Only entries inside the current window are considered, and
        ``is_limited`` follows the same rule as :meth:`is_rate_limited`
        (a ``limit`` of zero or less never limits).

        Args:
            key: Rate limit key (e.g., provider name).
            limit: Maximum requests allowed per ``period``.
            period: Length of the sliding window in seconds.

        Returns:
            dict: ``key``, ``current_count``, ``limit``, ``period``,
            ``is_limited`` and ``time_to_reset`` (seconds until the oldest
            in-window entry leaves the window, 0 if there is none). On a
            Redis error the counts are zeroed and an ``error`` entry holds
            the exception text.
        """
        now = time.time()
        rate_key = f"{self._KEY_PREFIX}{key}"
        # Exclusive lower bound: entries at or before it have expired.
        window_min = f"({now - period}"

        try:
            current_count = self.redis.zcount(rate_key, window_min, "+inf")

            # Oldest entry still inside the window determines the reset time.
            oldest_entries = self.redis.zrangebyscore(
                rate_key, window_min, "+inf", start=0, num=1, withscores=True
            )
            time_to_reset = 0
            if oldest_entries:
                oldest_time = oldest_entries[0][1]
                time_to_reset = max(0, period - (now - oldest_time))

            return {
                'key': key,
                'current_count': current_count,
                'limit': limit,
                'period': period,
                'is_limited': limit > 0 and current_count >= limit,
                'time_to_reset': time_to_reset,
            }
        except Exception as e:
            self.logger.error("Failed to get rate limit status for %s: %s", key, e)
            return {
                'key': key,
                'current_count': 0,
                'limit': limit,
                'period': period,
                'is_limited': False,
                'time_to_reset': 0,
                'error': str(e),
            }

    def reset_rate_limit(self, key):
        """Delete all recorded requests for ``key`` (useful for debugging).

        Args:
            key: Rate limit key (e.g., provider name).

        Returns:
            bool: True if the delete command succeeded, False if Redis raised.
        """
        rate_key = f"{self._KEY_PREFIX}{key}"
        try:
            deleted = self.redis.delete(rate_key)
            # Forget the purge timestamp along with the data it refers to.
            self._last_cleanup.pop(rate_key, None)
            self.logger.info("Reset rate limit for %s (deleted: %s)", key, deleted)
            return True
        except Exception as e:
            self.logger.error("Failed to reset rate limit for %s: %s", key, e)
            return False

    def cleanup_all_rate_limits(self):
        """Delete every ``rate_limit:*`` key (useful for maintenance).

        Keys are discovered incrementally with SCAN instead of KEYS so a
        large keyspace does not block the Redis server, and are deleted in
        batches of ``_DELETE_BATCH``.

        Returns:
            int: Number of keys deleted, or 0 if Redis raised an error.
        """
        try:
            deleted = 0
            found_any = False
            batch = []
            for redis_key in self.redis.scan_iter(
                match=f"{self._KEY_PREFIX}*", count=self._DELETE_BATCH
            ):
                found_any = True
                batch.append(redis_key)
                if len(batch) >= self._DELETE_BATCH:
                    deleted += self.redis.delete(*batch)
                    batch = []
            if batch:
                deleted += self.redis.delete(*batch)

            # All tracked keys are gone, so their purge timestamps are stale.
            self._last_cleanup.clear()

            if found_any:
                self.logger.info("Cleaned up %s rate limit keys", deleted)
            return deleted
        except Exception as e:
            self.logger.error("Failed to cleanup rate limits: %s", e)
            return 0