overcuriousity 2025-09-14 17:40:18 +02:00
parent 949fbdbb45
commit 7fe7ca41ba
3 changed files with 67 additions and 205 deletions

View File

@@ -172,22 +172,17 @@ class BaseProvider(ABC):
     def make_request(self, url: str, method: str = "GET",
                      params: Optional[Dict[str, Any]] = None,
                      headers: Optional[Dict[str, str]] = None,
-                     target_indicator: str = "",
-                     max_retries: int = 3) -> Optional[requests.Response]:
+                     target_indicator: str = "") -> Optional[requests.Response]:
         """
-        Make a rate-limited HTTP request with aggressive stop signal handling.
-        Terminates immediately when stop is requested, including during retries.
+        Make a rate-limited HTTP request.
         """
-        # Check for cancellation before starting
         if self._is_stop_requested():
             print(f"Request cancelled before start: {url}")
             return None
 
-        # Create a unique cache key
         cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
         cache_path = os.path.join(self.cache_dir, cache_key)
 
-        # Check cache
         if os.path.exists(cache_path):
             cache_age = time.time() - os.path.getmtime(cache_path)
             if cache_age < self.cache_expiry:
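
A note on the cache key that this refactor keeps: Python salts str hashing per process, so hash() produces a different key on every run, and cache files written by earlier sessions are never re-hit. A stable digest would avoid that; a minimal sketch with the same inputs (hashlib is stdlib, and this helper is a suggestion, not part of the commit):

import hashlib
import json

def stable_cache_key(provider_name, method, url, params):
    # sha256 is deterministic across runs, unlike the salted built-in hash()
    raw = f"{method}:{url}:{json.dumps(params, sort_keys=True)}"
    return f"{provider_name}_{hashlib.sha256(raw.encode()).hexdigest()[:16]}.json"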
@@ -200,143 +195,76 @@ class BaseProvider(ABC):
                     response.headers = cached_data['headers']
                     return response
 
-        # Determine effective max_retries based on stop signal
-        effective_max_retries = 0 if self._is_stop_requested() else max_retries
-        last_exception = None
-
-        for attempt in range(effective_max_retries + 1):
-            # AGGRESSIVE: Check for cancellation before each attempt
-            if self._is_stop_requested():
-                print(f"Request cancelled during attempt {attempt + 1}: {url}")
-                return None
-
-            # Apply rate limiting with cancellation awareness
-            if not self._wait_with_cancellation_check():
-                print(f"Request cancelled during rate limiting: {url}")
-                return None
-
-            # AGGRESSIVE: Final check before making HTTP request
-            if self._is_stop_requested():
-                print(f"Request cancelled before HTTP call: {url}")
-                return None
-
-            start_time = time.time()
-            response = None
-            error = None
-
-            try:
-                self.total_requests += 1
-
-                # Prepare request
-                request_headers = dict(self.session.headers).copy()
-                if headers:
-                    request_headers.update(headers)
-
-                print(f"Making {method} request to: {url} (attempt {attempt + 1})")
-
-                # AGGRESSIVE: Use much shorter timeout if termination is requested
-                request_timeout = self.timeout
-                if self._is_stop_requested():
-                    request_timeout = 2  # Max 2 seconds if termination requested
-                    print(f"Stop requested - using short timeout: {request_timeout}s")
-
-                # Make request
-                if method.upper() == "GET":
-                    response = self.session.get(
-                        url,
-                        params=params,
-                        headers=request_headers,
-                        timeout=request_timeout
-                    )
-                elif method.upper() == "POST":
-                    response = self.session.post(
-                        url,
-                        json=params,
-                        headers=request_headers,
-                        timeout=request_timeout
-                    )
-                else:
-                    raise ValueError(f"Unsupported HTTP method: {method}")
-
-                print(f"Response status: {response.status_code}")
-                response.raise_for_status()
-                self.successful_requests += 1
-
-                # Success - log, cache, and return
-                duration_ms = (time.time() - start_time) * 1000
-                self.logger.log_api_request(
-                    provider=self.name,
-                    url=url,
-                    method=method.upper(),
-                    status_code=response.status_code,
-                    response_size=len(response.content),
-                    duration_ms=duration_ms,
-                    error=None,
-                    target_indicator=target_indicator
-                )
-
-                # Cache the successful response to disk
-                with open(cache_path, 'w') as f:
-                    json.dump({
-                        'status_code': response.status_code,
-                        'content': response.text,
-                        'headers': dict(response.headers)
-                    }, f)
-
-                return response
-
-            except requests.exceptions.RequestException as e:
-                error = str(e)
-                self.failed_requests += 1
-                print(f"Request failed (attempt {attempt + 1}): {error}")
-                last_exception = e
-
-                # AGGRESSIVE: Immediately abort retries if stop requested
-                if self._is_stop_requested():
-                    print(f"Stop requested - aborting retries for: {url}")
-                    break
-
-                # Check if we should retry (but only if stop not requested)
-                if attempt < effective_max_retries and self._should_retry(e):
-                    # Use a longer, more respectful backoff for 429 errors
-                    if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429:
-                        # Start with a 10-second backoff and increase exponentially
-                        backoff_time = 10 * (2 ** attempt)
-                        print(f"Rate limit hit. Retrying in {backoff_time} seconds...")
-                    else:
-                        backoff_time = min(1.0, (2 ** attempt) * 0.5)  # Shorter backoff for other errors
-                        print(f"Retrying in {backoff_time} seconds...")
-
-                    # AGGRESSIVE: Much shorter backoff and more frequent checking
-                    if not self._sleep_with_cancellation_check(backoff_time):
-                        print(f"Stop requested during backoff - aborting: {url}")
-                        return None
-                    continue
-                else:
-                    break
-
-            except Exception as e:
-                error = f"Unexpected error: {str(e)}"
-                self.failed_requests += 1
-                print(f"Unexpected error: {error}")
-                last_exception = e
-                break
-
-        # All attempts failed - log and return None
-        duration_ms = (time.time() - start_time) * 1000
-        self.logger.log_api_request(
-            provider=self.name,
-            url=url,
-            method=method.upper(),
-            status_code=response.status_code if response else None,
-            response_size=len(response.content) if response else None,
-            duration_ms=duration_ms,
-            error=error,
-            target_indicator=target_indicator
-        )
-
-        if error and last_exception:
-            raise last_exception
-
-        return None
+        self.rate_limiter.wait_if_needed()
+
+        start_time = time.time()
+        response = None
+        error = None
+
+        try:
+            self.total_requests += 1
+
+            request_headers = dict(self.session.headers).copy()
+            if headers:
+                request_headers.update(headers)
+
+            print(f"Making {method} request to: {url}")
+
+            if method.upper() == "GET":
+                response = self.session.get(
+                    url,
+                    params=params,
+                    headers=request_headers,
+                    timeout=self.timeout
+                )
+            elif method.upper() == "POST":
+                response = self.session.post(
+                    url,
+                    json=params,
+                    headers=request_headers,
+                    timeout=self.timeout
+                )
+            else:
+                raise ValueError(f"Unsupported HTTP method: {method}")
+
+            print(f"Response status: {response.status_code}")
+            response.raise_for_status()
+            self.successful_requests += 1
+
+            duration_ms = (time.time() - start_time) * 1000
+            self.logger.log_api_request(
+                provider=self.name,
+                url=url,
+                method=method.upper(),
+                status_code=response.status_code,
+                response_size=len(response.content),
+                duration_ms=duration_ms,
+                error=None,
+                target_indicator=target_indicator
+            )
+
+            with open(cache_path, 'w') as f:
+                json.dump({
+                    'status_code': response.status_code,
+                    'content': response.text,
+                    'headers': dict(response.headers)
+                }, f)
+
+            return response
+
+        except requests.exceptions.RequestException as e:
+            error = str(e)
+            self.failed_requests += 1
+            duration_ms = (time.time() - start_time) * 1000
+            self.logger.log_api_request(
+                provider=self.name,
+                url=url,
+                method=method.upper(),
+                status_code=response.status_code if response else None,
+                response_size=len(response.content) if response else None,
+                duration_ms=duration_ms,
+                error=error,
+                target_indicator=target_indicator
+            )
+            raise e
 
     def _is_stop_requested(self) -> bool:
         """
@@ -346,43 +274,6 @@ class BaseProvider(ABC):
             return True
         return False
 
-    def _wait_with_cancellation_check(self) -> bool:
-        """
-        Wait for rate limiting while aggressively checking for cancellation.
-        Returns False if cancelled during wait.
-        """
-        current_time = time.time()
-        time_since_last = current_time - self.rate_limiter.last_request_time
-
-        if time_since_last < self.rate_limiter.min_interval:
-            sleep_time = self.rate_limiter.min_interval - time_since_last
-            if not self._sleep_with_cancellation_check(sleep_time):
-                return False
-
-        self.rate_limiter.last_request_time = time.time()
-        return True
-
-    def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
-        """
-        Sleep for the specified time while aggressively checking for cancellation.
-
-        Args:
-            sleep_time: Time to sleep in seconds
-
-        Returns:
-            bool: True if sleep completed, False if cancelled
-        """
-        sleep_start = time.time()
-        check_interval = 0.05  # Check every 50ms for aggressive responsiveness
-
-        while time.time() - sleep_start < sleep_time:
-            if self._is_stop_requested():
-                return False
-
-            remaining_time = sleep_time - (time.time() - sleep_start)
-            time.sleep(min(check_interval, remaining_time))
-
-        return True
-
     def set_stop_event(self, stop_event: threading.Event) -> None:
         """
         Set the stop event for this provider to enable cancellation.
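
With _wait_with_cancellation_check and _sleep_with_cancellation_check removed, make_request leans entirely on self.rate_limiter.wait_if_needed() for pacing. The RateLimiter class itself is not part of this diff; the following is a hypothetical reconstruction of the interface the new code assumes, based on the last_request_time and min_interval attributes the deleted helper manipulated:

import time

class RateLimiter:
    # Hypothetical sketch - the real class lives elsewhere in the repository.
    def __init__(self, requests_per_minute: float = 60.0):
        self.min_interval = 60.0 / requests_per_minute
        self.last_request_time = 0.0

    def wait_if_needed(self) -> None:
        # Sleep just long enough to honor the minimum interval between requests.
        elapsed = time.time() - self.last_request_time
        if elapsed < self.min_interval:
            time.sleep(self.min_interval - elapsed)
        self.last_request_time = time.time()

Unlike the deleted helper, a plain sleep like this cannot be interrupted by the stop event, which is the responsiveness trade-off this commit accepts.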
@@ -392,28 +283,6 @@
         """
         self._stop_event = stop_event
 
-    def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
-        """
-        Determine if a request should be retried based on the exception.
-
-        Args:
-            exception: The request exception that occurred
-
-        Returns:
-            True if the request should be retried
-        """
-        # Retry on connection errors and timeouts
-        if isinstance(exception, (requests.exceptions.ConnectionError,
-                                  requests.exceptions.Timeout)):
-            return True
-
-        if isinstance(exception, requests.exceptions.HTTPError):
-            if hasattr(exception, 'response') and exception.response:
-                # Retry on server errors (5xx) AND on rate-limiting errors (429)
-                return exception.response.status_code >= 500 or exception.response.status_code == 429
-
-        return False
-
     def log_relationship_discovery(self, source_node: str, target_node: str,
                                    relationship_type: str,
                                    confidence_score: float,

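One behavioral change worth flagging before the call sites below: make_request now raises requests.exceptions.RequestException on the first failed attempt instead of retrying up to three times first, and it returns None only when a stop was requested before the request started. A caller that wants a plain None-on-failure result can wrap the call; a minimal sketch, where provider stands for any BaseProvider subclass:

import requests

def fetch_or_none(provider, url, domain):
    # Hypothetical helper: collapse transport and HTTP errors into None,
    # matching the "if not response" check in crtsh_provider below.
    try:
        return provider.make_request(url, target_indicator=domain)
    except requests.exceptions.RequestException:
        return None
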
View File: dnsrecon/providers/crtsh_provider.py

@@ -1,8 +1,4 @@
-"""
-Certificate Transparency provider using crt.sh.
-Discovers domain relationships through certificate SAN analysis with comprehensive certificate tracking.
-Stores certificates as metadata on domain nodes rather than creating certificate nodes.
-"""
+# dnsrecon/providers/crtsh_provider.py
 
 import json
 import re
@@ -182,7 +178,7 @@ class CrtShProvider(BaseProvider):
         try:
             # Query crt.sh for certificates
             url = f"{self.base_url}?q={quote(domain)}&output=json"
-            response = self.make_request(url, target_indicator=domain, max_retries=3)
+            response = self.make_request(url, target_indicator=domain)
 
             if not response or response.status_code != 200:
                 return []
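
With the max_retries parameter gone, any retrying the crt.sh provider still needs must live at the call site. A sketch of that pattern, assuming the single-attempt make_request above; the helper name and backoff values are illustrative only:

import time
import requests

def make_request_with_retries(provider, url, domain, attempts=3):
    # Hypothetical wrapper: exponential backoff, re-raising the last
    # error once all attempts are exhausted.
    for attempt in range(attempts):
        try:
            return provider.make_request(url, target_indicator=domain)
        except requests.exceptions.RequestException:
            if attempt == attempts - 1:
                raise
            time.sleep(2 ** attempt)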

View File: dnsrecon/providers/shodan_provider.py

@@ -1,7 +1,4 @@
-"""
-Shodan provider for DNSRecon.
-Discovers IP relationships and infrastructure context through Shodan API.
-"""
+# dnsrecon/providers/shodan_provider.py
 
 import json
 from typing import List, Dict, Any, Tuple