diff --git a/providers/base_provider.py b/providers/base_provider.py index 2bb086e..76ebd10 100644 --- a/providers/base_provider.py +++ b/providers/base_provider.py @@ -172,22 +172,17 @@ class BaseProvider(ABC): def make_request(self, url: str, method: str = "GET", params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, - target_indicator: str = "", - max_retries: int = 3) -> Optional[requests.Response]: + target_indicator: str = "") -> Optional[requests.Response]: """ - Make a rate-limited HTTP request with aggressive stop signal handling. - Terminates immediately when stop is requested, including during retries. + Make a rate-limited HTTP request. """ - # Check for cancellation before starting if self._is_stop_requested(): print(f"Request cancelled before start: {url}") return None - # Create a unique cache key cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json" cache_path = os.path.join(self.cache_dir, cache_key) - # Check cache if os.path.exists(cache_path): cache_age = time.time() - os.path.getmtime(cache_path) if cache_age < self.cache_expiry: @@ -200,143 +195,76 @@ class BaseProvider(ABC): response.headers = cached_data['headers'] return response - # Determine effective max_retries based on stop signal - effective_max_retries = 0 if self._is_stop_requested() else max_retries - last_exception = None + self.rate_limiter.wait_if_needed() - for attempt in range(effective_max_retries + 1): - # AGGRESSIVE: Check for cancellation before each attempt - if self._is_stop_requested(): - print(f"Request cancelled during attempt {attempt + 1}: {url}") - return None + start_time = time.time() + response = None + error = None - # Apply rate limiting with cancellation awareness - if not self._wait_with_cancellation_check(): - print(f"Request cancelled during rate limiting: {url}") - return None + try: + self.total_requests += 1 - # AGGRESSIVE: Final check before making HTTP request - if self._is_stop_requested(): - print(f"Request cancelled before HTTP call: {url}") - return None + request_headers = dict(self.session.headers).copy() + if headers: + request_headers.update(headers) - start_time = time.time() - response = None - error = None + print(f"Making {method} request to: {url}") - try: - self.total_requests += 1 - - # Prepare request - request_headers = dict(self.session.headers).copy() - if headers: - request_headers.update(headers) - - print(f"Making {method} request to: {url} (attempt {attempt + 1})") - - # AGGRESSIVE: Use much shorter timeout if termination is requested - request_timeout = self.timeout - if self._is_stop_requested(): - request_timeout = 2 # Max 2 seconds if termination requested - print(f"Stop requested - using short timeout: {request_timeout}s") - - # Make request - if method.upper() == "GET": - response = self.session.get( - url, - params=params, - headers=request_headers, - timeout=request_timeout - ) - elif method.upper() == "POST": - response = self.session.post( - url, - json=params, - headers=request_headers, - timeout=request_timeout - ) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - print(f"Response status: {response.status_code}") - response.raise_for_status() - self.successful_requests += 1 - - # Success - log, cache, and return - duration_ms = (time.time() - start_time) * 1000 - self.logger.log_api_request( - provider=self.name, - url=url, - method=method.upper(), - status_code=response.status_code, - response_size=len(response.content), - duration_ms=duration_ms, - error=None, - target_indicator=target_indicator + if method.upper() == "GET": + response = self.session.get( + url, + params=params, + headers=request_headers, + timeout=self.timeout ) - # Cache the successful response to disk - with open(cache_path, 'w') as f: - json.dump({ - 'status_code': response.status_code, - 'content': response.text, - 'headers': dict(response.headers) - }, f) - return response + elif method.upper() == "POST": + response = self.session.post( + url, + json=params, + headers=request_headers, + timeout=self.timeout + ) + else: + raise ValueError(f"Unsupported HTTP method: {method}") - except requests.exceptions.RequestException as e: - error = str(e) - self.failed_requests += 1 - print(f"Request failed (attempt {attempt + 1}): {error}") - last_exception = e - - # AGGRESSIVE: Immediately abort retries if stop requested - if self._is_stop_requested(): - print(f"Stop requested - aborting retries for: {url}") - break - - # Check if we should retry (but only if stop not requested) - if attempt < effective_max_retries and self._should_retry(e): - # Use a longer, more respectful backoff for 429 errors - if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429: - # Start with a 10-second backoff and increase exponentially - backoff_time = 10 * (2 ** attempt) - print(f"Rate limit hit. Retrying in {backoff_time} seconds...") - else: - backoff_time = min(1.0, (2 ** attempt) * 0.5) # Shorter backoff for other errors - print(f"Retrying in {backoff_time} seconds...") - - # AGGRESSIVE: Much shorter backoff and more frequent checking - if not self._sleep_with_cancellation_check(backoff_time): - print(f"Stop requested during backoff - aborting: {url}") - return None - continue - else: - break + print(f"Response status: {response.status_code}") + response.raise_for_status() + self.successful_requests += 1 + + duration_ms = (time.time() - start_time) * 1000 + self.logger.log_api_request( + provider=self.name, + url=url, + method=method.upper(), + status_code=response.status_code, + response_size=len(response.content), + duration_ms=duration_ms, + error=None, + target_indicator=target_indicator + ) + with open(cache_path, 'w') as f: + json.dump({ + 'status_code': response.status_code, + 'content': response.text, + 'headers': dict(response.headers) + }, f) + return response - except Exception as e: - error = f"Unexpected error: {str(e)}" - self.failed_requests += 1 - print(f"Unexpected error: {error}") - last_exception = e - break - - # All attempts failed - log and return None - duration_ms = (time.time() - start_time) * 1000 - self.logger.log_api_request( - provider=self.name, - url=url, - method=method.upper(), - status_code=response.status_code if response else None, - response_size=len(response.content) if response else None, - duration_ms=duration_ms, - error=error, - target_indicator=target_indicator - ) - - if error and last_exception: - raise last_exception - - return None + except requests.exceptions.RequestException as e: + error = str(e) + self.failed_requests += 1 + duration_ms = (time.time() - start_time) * 1000 + self.logger.log_api_request( + provider=self.name, + url=url, + method=method.upper(), + status_code=response.status_code if response else None, + response_size=len(response.content) if response else None, + duration_ms=duration_ms, + error=error, + target_indicator=target_indicator + ) + raise e def _is_stop_requested(self) -> bool: """ @@ -346,43 +274,6 @@ class BaseProvider(ABC): return True return False - def _wait_with_cancellation_check(self) -> bool: - """ - Wait for rate limiting while aggressively checking for cancellation. - Returns False if cancelled during wait. - """ - current_time = time.time() - time_since_last = current_time - self.rate_limiter.last_request_time - - if time_since_last < self.rate_limiter.min_interval: - sleep_time = self.rate_limiter.min_interval - time_since_last - if not self._sleep_with_cancellation_check(sleep_time): - return False - - self.rate_limiter.last_request_time = time.time() - return True - - def _sleep_with_cancellation_check(self, sleep_time: float) -> bool: - """ - Sleep for the specified time while aggressively checking for cancellation. - - Args: - sleep_time: Time to sleep in seconds - - Returns: - bool: True if sleep completed, False if cancelled - """ - sleep_start = time.time() - check_interval = 0.05 # Check every 50ms for aggressive responsiveness - - while time.time() - sleep_start < sleep_time: - if self._is_stop_requested(): - return False - remaining_time = sleep_time - (time.time() - sleep_start) - time.sleep(min(check_interval, remaining_time)) - - return True - def set_stop_event(self, stop_event: threading.Event) -> None: """ Set the stop event for this provider to enable cancellation. @@ -392,28 +283,6 @@ class BaseProvider(ABC): """ self._stop_event = stop_event - def _should_retry(self, exception: requests.exceptions.RequestException) -> bool: - """ - Determine if a request should be retried based on the exception. - - Args: - exception: The request exception that occurred - - Returns: - True if the request should be retried - """ - # Retry on connection errors and timeouts - if isinstance(exception, (requests.exceptions.ConnectionError, - requests.exceptions.Timeout)): - return True - - if isinstance(exception, requests.exceptions.HTTPError): - if hasattr(exception, 'response') and exception.response: - # Retry on server errors (5xx) AND on rate-limiting errors (429) - return exception.response.status_code >= 500 or exception.response.status_code == 429 - - return False - def log_relationship_discovery(self, source_node: str, target_node: str, relationship_type: str, confidence_score: float, diff --git a/providers/crtsh_provider.py b/providers/crtsh_provider.py index 1ff248c..56548da 100644 --- a/providers/crtsh_provider.py +++ b/providers/crtsh_provider.py @@ -1,8 +1,4 @@ -""" -Certificate Transparency provider using crt.sh. -Discovers domain relationships through certificate SAN analysis with comprehensive certificate tracking. -Stores certificates as metadata on domain nodes rather than creating certificate nodes. -""" +# dnsrecon/providers/crtsh_provider.py import json import re @@ -182,7 +178,7 @@ class CrtShProvider(BaseProvider): try: # Query crt.sh for certificates url = f"{self.base_url}?q={quote(domain)}&output=json" - response = self.make_request(url, target_indicator=domain, max_retries=3) + response = self.make_request(url, target_indicator=domain) if not response or response.status_code != 200: return [] diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index 7b80d3c..579a994 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -1,7 +1,4 @@ -""" -Shodan provider for DNSRecon. -Discovers IP relationships and infrastructure context through Shodan API. -""" +# dnsrecon/providers/shodan_provider.py import json from typing import List, Dict, Any, Tuple