460 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			460 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# dnsrecon/providers/base_provider.py
 | 
						|
 | 
						|
import hashlib
import json
import os
import threading
import time
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple

import requests

from core.logger import get_forensic_logger
 | 
						|
 | 
						|
 | 
						|
class RateLimiter:
    """Minimal interval-based throttle for outbound API calls.

    Enforces a minimum spacing of ``60 / requests_per_minute`` seconds between
    consecutive calls to :meth:`wait_if_needed`.
    """

    def __init__(self, requests_per_minute: int):
        """
        Set up the throttle.

        Args:
            requests_per_minute: Upper bound on calls permitted per minute;
                converted into a minimum spacing between calls.
        """
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        # 0 stands for "no previous request", so the very first call never waits.
        self.last_request_time = 0

    def __getstate__(self):
        """Return a shallow copy of the instance dict; every field is picklable."""
        return dict(self.__dict__)

    def __setstate__(self, state):
        """Restore attributes from a pickled state dict."""
        self.__dict__.update(state)

    def wait_if_needed(self) -> None:
        """Block until ``min_interval`` seconds have passed since the previous call."""
        elapsed = time.time() - self.last_request_time
        remaining = self.min_interval - elapsed
        if remaining > 0:
            time.sleep(remaining)
        # Stamp after any sleep so the next call measures from now.
        self.last_request_time = time.time()
 | 
						|
 | 
						|
 | 
						|
class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.

    Concrete providers implement the abstract query/metadata methods. This base
    supplies session-specific configuration, per-thread HTTP sessions, rate
    limiting, an on-disk response cache, retry/backoff handling, cancellation
    through an optional stop event, and per-instance request statistics.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging (also prefixes cache file names)
            rate_limit: Requests per minute; only used when session_config is
                None (a session config supplies its own per-provider limit)
            timeout: Request timeout in seconds; only used when session_config
                is None (otherwise the config's default_timeout is used)
            session_config: Session-specific configuration; when None the
                global ``config`` object is used instead
        """
        if session_config is not None:
            self.config = session_config
            actual_rate_limit = self.config.get_rate_limit(name)
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_rate_limit = rate_limit
            actual_timeout = timeout

        self.name = name
        self.rate_limiter = RateLimiter(actual_rate_limit)
        self.timeout = actual_timeout
        # Per-thread storage for requests.Session objects (see `session`).
        self._local = threading.local()
        self.logger = get_forensic_logger()
        # Optional threading.Event installed through set_stop_event().
        self._stop_event = None

        # Caching configuration (per session config object).
        # NOTE(review): id() is only unique among *live* objects, so a later
        # config object may reuse the directory of a garbage-collected one.
        self.cache_dir = f'.cache/{id(self.config)}'
        self.cache_expiry = self.config.cache_expiry_hours * 3600
        # exist_ok avoids a check-then-create race when several providers that
        # share one config are constructed concurrently.
        os.makedirs(self.cache_dir, exist_ok=True)

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # threading.local and the stop event cannot be pickled; both are
        # recreated (empty) by __setstate__.
        for attr in ('_local', '_stop_event'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Return this thread's ``requests.Session``, creating it on first use.

        Sessions live in ``threading.local`` storage, so each thread gets its
        own session with the DNSRecon User-Agent header preset.
        """
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> List[Tuple[str, str, str, float, Dict[str, Any]]]:
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            List of tuples: (source_node, target_node, relationship_type, confidence, raw_data)
        """
        pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "",
                     max_retries: int = 3) -> Optional[requests.Response]:
        """
        Make a rate-limited, cached HTTP request with aggressive stop handling.

        A fresh on-disk cache entry is returned without touching the network.
        Otherwise the request is attempted up to ``max_retries + 1`` times.
        Connection errors, timeouts, HTTP 5xx and HTTP 429 are retried with
        backoff (see ``_should_retry``). The stop event is checked before every
        attempt and during every wait, so cancellation terminates promptly.

        Args:
            url: Target URL.
            method: "GET" (params sent as query string) or "POST" (params sent
                as JSON body); compared case-insensitively.
            params: Query parameters / JSON body; also part of the cache key.
            headers: Extra headers merged over the session defaults.
            target_indicator: Indicator under investigation, forwarded to the
                forensic request log.
            max_retries: Number of retries allowed after the first attempt.

        Returns:
            A response that passed ``raise_for_status`` (live or cached), or
            None if the request was cancelled.

        Raises:
            requests.exceptions.RequestException: the last request error once
                retries are exhausted or aborted by a stop request.
            ValueError: for an unsupported HTTP method.
            TypeError: if ``params`` is not JSON-serializable (cache key).
        """
        # Check for cancellation before starting
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        cache_path = self._cache_path_for(method, url, params)
        cached_response = self._load_cached_response(cache_path)
        if cached_response is not None:
            print(f"Returning cached response for: {url}")
            return cached_response

        # Determine effective max_retries based on stop signal
        effective_max_retries = 0 if self._is_stop_requested() else max_retries
        last_exception = None
        start_time = time.time()
        response = None
        error = None

        for attempt in range(effective_max_retries + 1):
            # AGGRESSIVE: Check for cancellation before each attempt
            if self._is_stop_requested():
                print(f"Request cancelled during attempt {attempt + 1}: {url}")
                return None

            # Apply rate limiting with cancellation awareness
            if not self._wait_with_cancellation_check():
                print(f"Request cancelled during rate limiting: {url}")
                return None

            # AGGRESSIVE: Final check before making HTTP request
            if self._is_stop_requested():
                print(f"Request cancelled before HTTP call: {url}")
                return None

            start_time = time.time()
            response = None
            error = None

            try:
                self.total_requests += 1

                # Prepare request
                request_headers = self.session.headers.copy()
                if headers:
                    request_headers.update(headers)

                print(f"Making {method} request to: {url} (attempt {attempt + 1})")

                # AGGRESSIVE: Use much shorter timeout if termination is requested
                request_timeout = self.timeout
                if self._is_stop_requested():
                    request_timeout = 2  # Max 2 seconds if termination requested
                    print(f"Stop requested - using short timeout: {request_timeout}s")

                # Make request
                if method.upper() == "GET":
                    response = self.session.get(
                        url,
                        params=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                elif method.upper() == "POST":
                    response = self.session.post(
                        url,
                        json=params,
                        headers=request_headers,
                        timeout=request_timeout
                    )
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                print(f"Response status: {response.status_code}")
                response.raise_for_status()
                self.successful_requests += 1

                # Success - log, cache, and return
                duration_ms = (time.time() - start_time) * 1000
                self.logger.log_api_request(
                    provider=self.name,
                    url=url,
                    method=method.upper(),
                    status_code=response.status_code,
                    response_size=len(response.content),
                    duration_ms=duration_ms,
                    error=None,
                    target_indicator=target_indicator
                )
                # Best effort: a cache failure must not turn this success into
                # a reported failure.
                self._store_cached_response(cache_path, response)
                return response

            except requests.exceptions.RequestException as e:
                # Fall back to the class name so an exception with an empty
                # message still records a non-empty error.
                error = str(e) or type(e).__name__
                self.failed_requests += 1
                print(f"Request failed (attempt {attempt + 1}): {error}")
                last_exception = e

                # AGGRESSIVE: Immediately abort retries if stop requested
                if self._is_stop_requested():
                    print(f"Stop requested - aborting retries for: {url}")
                    break

                if attempt < effective_max_retries and self._should_retry(e):
                    # `is not None` is essential: a Response for a 4xx/5xx
                    # status is falsy (Response.__bool__ mirrors `.ok`), so a
                    # truthiness test would never detect a 429 here.
                    if (isinstance(e, requests.exceptions.HTTPError)
                            and e.response is not None
                            and e.response.status_code == 429):
                        # Longer, exponentially growing backoff for rate limits
                        backoff_time = 10 * (2 ** attempt)
                        print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
                    else:
                        backoff_time = min(1.0, (2 ** attempt) * 0.5)  # Shorter backoff for other errors
                        print(f"Retrying in {backoff_time} seconds...")

                    if not self._sleep_with_cancellation_check(backoff_time):
                        print(f"Stop requested during backoff - aborting: {url}")
                        return None
                    continue
                break

            except Exception as e:
                error = f"Unexpected error: {str(e)}"
                self.failed_requests += 1
                print(f"Unexpected error: {error}")
                last_exception = e
                break

        # All attempts failed - log and re-raise the last error
        duration_ms = (time.time() - start_time) * 1000
        self.logger.log_api_request(
            provider=self.name,
            url=url,
            method=method.upper(),
            # Compare with None: error responses are falsy but still carry data.
            status_code=response.status_code if response is not None else None,
            response_size=len(response.content) if response is not None else None,
            duration_ms=duration_ms,
            error=error,
            target_indicator=target_indicator
        )

        if last_exception is not None:
            raise last_exception

        return None

    def _cache_path_for(self, method: str, url: str,
                        params: Optional[Dict[str, Any]]) -> str:
        """Return the cache file path for a request (stable across processes).

        hashlib is used instead of builtin hash(): str hashes are salted per
        process, which would make keys differ after pickling to a worker.
        """
        fingerprint = f'{method}:{url}:{json.dumps(params, sort_keys=True)}'
        digest = hashlib.sha256(fingerprint.encode('utf-8')).hexdigest()
        return os.path.join(self.cache_dir, f"{self.name}_{digest}.json")

    def _load_cached_response(self, cache_path: str) -> Optional[requests.Response]:
        """Rebuild a Response from a fresh cache entry, or return None on a miss.

        Missing, expired, unreadable or malformed entries are all treated as a
        miss, so a damaged file can never block a live request.
        """
        try:
            cache_age = time.time() - os.path.getmtime(cache_path)
        except OSError:
            return None  # no entry (or it disappeared concurrently)
        if cache_age >= self.cache_expiry:
            return None

        try:
            with open(cache_path, 'r', encoding='utf-8') as f:
                cached_data = json.load(f)
            response = requests.Response()
            response.status_code = cached_data['status_code']
            response._content = cached_data['content'].encode('utf-8')
            # requests exposes headers case-insensitively; keep that contract.
            response.headers = requests.structures.CaseInsensitiveDict(cached_data['headers'])
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            print(f"Ignoring unreadable cache entry {cache_path}: {e}")
            return None

        # The content was re-encoded as UTF-8 above; tell requests so `.text`
        # does not have to guess the codec.
        response.encoding = 'utf-8'
        return response

    def _store_cached_response(self, cache_path: str, response: requests.Response) -> None:
        """Persist a successful response to the cache (best effort).

        The entry is written to a temporary file and moved into place with
        os.replace, so concurrent readers never see a half-written file. Errors
        are reported and swallowed.
        """
        payload = {
            'status_code': response.status_code,
            'content': response.text,
            'headers': dict(response.headers)
        }
        tmp_path = f"{cache_path}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f)
            os.replace(tmp_path, cache_path)
        except (OSError, TypeError, ValueError) as e:
            print(f"Failed to write cache entry {cache_path}: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # deliberate: temp file may never have been created

    def _is_stop_requested(self) -> bool:
        """
        Return True when a stop event has been installed and is set.

        Only the local threading.Event from set_stop_event() is consulted here.
        """
        stop_event = getattr(self, '_stop_event', None)
        return bool(stop_event) and stop_event.is_set()

    def _wait_with_cancellation_check(self) -> bool:
        """
        Wait for rate limiting while aggressively checking for cancellation.

        Uses the rate limiter's ``last_request_time`` (wall-clock seconds) and
        updates it once the wait completes.

        Returns:
            False if cancelled during the wait, True otherwise.
        """
        time_since_last = time.time() - self.rate_limiter.last_request_time

        if time_since_last < self.rate_limiter.min_interval:
            sleep_time = self.rate_limiter.min_interval - time_since_last
            if not self._sleep_with_cancellation_check(sleep_time):
                return False

        self.rate_limiter.last_request_time = time.time()
        return True

    def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
        """
        Sleep for the specified time while aggressively checking for cancellation.

        Args:
            sleep_time: Time to sleep in seconds

        Returns:
            bool: True if sleep completed, False if cancelled
        """
        check_interval = 0.05  # Check every 50ms for aggressive responsiveness
        # Monotonic clock: immune to wall-clock adjustments during the sleep.
        deadline = time.monotonic() + sleep_time

        while True:
            remaining_time = deadline - time.monotonic()
            if remaining_time <= 0:
                break
            if self._is_stop_requested():
                return False
            # Recompute after the stop check (which may be slow) so a negative
            # value is never passed to time.sleep (that raises ValueError).
            remaining_time = deadline - time.monotonic()
            if remaining_time <= 0:
                break
            time.sleep(min(check_interval, remaining_time))

        return True

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
        """
        Determine if a request should be retried based on the exception.

        Args:
            exception: The request exception that occurred

        Returns:
            True for connection errors, timeouts, HTTP 5xx and HTTP 429.
        """
        # Retry on connection errors and timeouts
        if isinstance(exception, (requests.exceptions.ConnectionError,
                                  requests.exceptions.Timeout)):
            return True

        if isinstance(exception, requests.exceptions.HTTPError):
            # Must compare with None: error Responses are falsy.
            error_response = getattr(exception, 'response', None)
            if error_response is not None:
                # Retry on server errors (5xx) AND on rate-limiting errors (429)
                return error_response.status_code >= 500 or error_response.status_code == 429

        return False

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Increments ``total_relationships_found`` and forwards the details,
        tagged with this provider's name, to the forensic logger.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics; success_rate is
            a percentage (0 when no requests have been made).
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.rate_limiter.requests_per_minute
        }