# dnsrecon/providers/base_provider.py
"""Abstract base class shared by all DNSRecon data providers."""

import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import requests

from core.logger import get_forensic_logger
from core.rate_limiter import GlobalRateLimiter
from core.provider_result import ProviderResult


class BaseProvider(ABC):
    """
    Abstract base class for all DNSRecon data providers.

    Supports session-specific configuration and requires subclasses to
    return standardized ProviderResult objects from their query methods.

    Besides the abstract provider interface, this class supplies:
      * a lazily created, per-thread ``requests.Session`` (see ``session``),
      * ``make_request`` for GET/POST calls with forensic logging,
      * cooperative cancellation through an optional stop event,
      * simple per-instance request/relationship counters,
      * pickling support that drops the thread-local and stop-event state.
    """

    def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
        """
        Initialize base provider with session-specific configuration.

        Args:
            name: Provider name for logging
            rate_limit: Requests per minute limit. Accepted for backwards
                compatibility; this constructor does not store it. The
                value reported by ``get_statistics`` is read from
                ``self.config.get_rate_limit``.
            timeout: Request timeout in seconds. Only used when no
                ``session_config`` is given; otherwise
                ``session_config.default_timeout`` takes precedence.
            session_config: Session-specific configuration
        """
        # Use session config if provided, otherwise fall back to global config
        if session_config is not None:
            self.config = session_config
            actual_timeout = self.config.default_timeout
        else:
            # Fallback to global config for backwards compatibility
            from config import config as global_config
            self.config = global_config
            actual_timeout = timeout

        self.name = name
        self.timeout = actual_timeout
        # Holds the per-thread requests.Session created by the `session` property
        self._local = threading.local()
        self.logger = get_forensic_logger()
        # Optional cancellation signal, installed via set_stop_event()
        self._stop_event = None

        # Statistics (per provider instance)
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_relationships_found = 0

    def __getstate__(self):
        """Prepare BaseProvider for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # '_local' holds the HTTP session and '_stop_event' a threading object;
        # neither is carried across pickling (both are rebuilt in __setstate__).
        for attr in ('_local', '_stop_event'):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        """Restore BaseProvider after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_local' attribute and stop event
        self._local = threading.local()
        self._stop_event = None

    @property
    def session(self):
        """Return the calling thread's requests.Session, creating it on first use."""
        if not hasattr(self._local, 'session'):
            self._local.session = requests.Session()
            self._local.session.headers.update({
                'User-Agent': 'DNSRecon/1.0 (Passive Reconnaissance Tool)'
            })
        return self._local.session

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider name."""
        pass

    @abstractmethod
    def get_display_name(self) -> str:
        """Return the provider display name for the UI."""
        pass

    @abstractmethod
    def requires_api_key(self) -> bool:
        """Return True if the provider requires an API key."""
        pass

    @abstractmethod
    def get_eligibility(self) -> Dict[str, bool]:
        """Return a dictionary indicating if the provider can query domains and/or IPs."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available and properly configured."""
        pass

    @abstractmethod
    def query_domain(self, domain: str) -> "ProviderResult":
        """
        Query the provider for information about a domain.

        Args:
            domain: Domain to investigate

        Returns:
            ProviderResult containing standardized attributes and relationships
        """
        pass

    @abstractmethod
    def query_ip(self, ip: str) -> "ProviderResult":
        """
        Query the provider for information about an IP address.

        Args:
            ip: IP address to investigate

        Returns:
            ProviderResult containing standardized attributes and relationships
        """
        pass

    def make_request(self, url: str, method: str = "GET",
                     params: Optional[Dict[str, Any]] = None,
                     headers: Optional[Dict[str, str]] = None,
                     target_indicator: str = "") -> Optional["requests.Response"]:
        """
        Make an HTTP request through the per-thread session and log it.

        The response is returned without raising for HTTP error status
        codes; individual providers should handle status codes themselves.
        Only 2xx responses are counted as successful in the statistics.

        NOTE(review): no rate limiting is performed inside this method;
        presumably it is enforced by the caller - confirm.

        Args:
            url: Target URL
            method: "GET" or "POST" (case-insensitive)
            params: Query parameters for GET, JSON body for POST
            headers: Extra headers merged over the session defaults
            target_indicator: Passed through to the forensic logger

        Returns:
            The response object, or None if a stop was requested before
            the request started.

        Raises:
            ValueError: If ``method`` is neither GET nor POST. Raised before
                any statistics are touched.
            requests.exceptions.RequestException: Re-raised after being
                counted as a failed request and logged.
        """
        if self._is_stop_requested():
            print(f"Request cancelled before start: {url}")
            return None

        # Validate up front so a bad method does not count as a request
        http_method = method.upper()
        if http_method not in ("GET", "POST"):
            raise ValueError(f"Unsupported HTTP method: {method}")

        start_time = time.time()
        response = None
        response_size = None

        try:
            self.total_requests += 1

            request_headers = dict(self.session.headers)
            if headers:
                request_headers.update(headers)

            print(f"Making {method} request to: {url}")

            if http_method == "GET":
                response = self.session.get(
                    url,
                    params=params,
                    headers=request_headers,
                    timeout=self.timeout
                )
            else:
                response = self.session.post(
                    url,
                    json=params,
                    headers=request_headers,
                    timeout=self.timeout
                )

            print(f"Response status: {response.status_code}")

            # Read the body before updating counters so a failure while
            # reading it is only ever counted once (as a failed request).
            response_size = len(response.content)

            # Don't raise for HTTP error status codes; let individual
            # providers handle them. Only count 2xx responses as successful.
            if 200 <= response.status_code < 300:
                self.successful_requests += 1
            else:
                self.failed_requests += 1

            duration_ms = (time.time() - start_time) * 1000
            self.logger.log_api_request(
                provider=self.name,
                url=url,
                method=http_method,
                status_code=response.status_code,
                response_size=response_size,
                duration_ms=duration_ms,
                error=None,
                target_indicator=target_indicator
            )

            return response

        except requests.exceptions.RequestException as e:
            self.failed_requests += 1
            duration_ms = (time.time() - start_time) * 1000
            self.logger.log_api_request(
                provider=self.name,
                url=url,
                method=http_method,
                # Compare against None explicitly: a Response is falsy
                # whenever its status is an error (truthiness mirrors .ok).
                status_code=response.status_code if response is not None else None,
                # Use the size captured above instead of touching
                # response.content again inside the error handler.
                response_size=response_size,
                duration_ms=duration_ms,
                error=str(e),
                target_indicator=target_indicator
            )
            raise

    def _is_stop_requested(self) -> bool:
        """
        Return True if a stop event has been installed and is set.

        Only the local event installed via ``set_stop_event`` is consulted.
        """
        stop_event = getattr(self, '_stop_event', None)
        if stop_event and stop_event.is_set():
            return True
        return False

    def set_stop_event(self, stop_event: threading.Event) -> None:
        """
        Set the stop event for this provider to enable cancellation.

        Args:
            stop_event: Threading event to signal cancellation
        """
        self._stop_event = stop_event

    def log_relationship_discovery(self, source_node: str, target_node: str,
                                   relationship_type: str,
                                   confidence_score: float,
                                   raw_data: Dict[str, Any],
                                   discovery_method: str) -> None:
        """
        Log discovery of a new relationship.

        Increments this provider's relationship counter and forwards the
        details to the forensic logger, tagged with this provider's name.

        Args:
            source_node: Source node identifier
            target_node: Target node identifier
            relationship_type: Type of relationship
            confidence_score: Confidence score
            raw_data: Raw data from provider
            discovery_method: Method used for discovery
        """
        self.total_relationships_found += 1

        self.logger.log_relationship_discovery(
            source_node=source_node,
            target_node=target_node,
            relationship_type=relationship_type,
            confidence_score=confidence_score,
            provider=self.name,
            raw_data=raw_data,
            discovery_method=discovery_method
        )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get provider statistics.

        Returns:
            Dictionary containing provider performance metrics. The
            ``success_rate`` is a percentage, or 0 when no requests have
            been made; ``rate_limit`` is read from the active config.
        """
        return {
            'name': self.name,
            'total_requests': self.total_requests,
            'successful_requests': self.successful_requests,
            'failed_requests': self.failed_requests,
            'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
            'relationships_found': self.total_relationships_found,
            'rate_limit': self.config.get_rate_limit(self.name)
        }