This commit is contained in:
overcuriousity 2025-09-14 17:40:18 +02:00
parent 949fbdbb45
commit 7fe7ca41ba
3 changed files with 67 additions and 205 deletions

View File

@ -172,22 +172,17 @@ class BaseProvider(ABC):
def make_request(self, url: str, method: str = "GET", def make_request(self, url: str, method: str = "GET",
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
target_indicator: str = "", target_indicator: str = "") -> Optional[requests.Response]:
max_retries: int = 3) -> Optional[requests.Response]:
""" """
Make a rate-limited HTTP request with aggressive stop signal handling. Make a rate-limited HTTP request.
Terminates immediately when stop is requested, including during retries.
""" """
# Check for cancellation before starting
if self._is_stop_requested(): if self._is_stop_requested():
print(f"Request cancelled before start: {url}") print(f"Request cancelled before start: {url}")
return None return None
# Create a unique cache key
cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json" cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
cache_path = os.path.join(self.cache_dir, cache_key) cache_path = os.path.join(self.cache_dir, cache_key)
# Check cache
if os.path.exists(cache_path): if os.path.exists(cache_path):
cache_age = time.time() - os.path.getmtime(cache_path) cache_age = time.time() - os.path.getmtime(cache_path)
if cache_age < self.cache_expiry: if cache_age < self.cache_expiry:
@ -200,143 +195,76 @@ class BaseProvider(ABC):
response.headers = cached_data['headers'] response.headers = cached_data['headers']
return response return response
# Determine effective max_retries based on stop signal self.rate_limiter.wait_if_needed()
effective_max_retries = 0 if self._is_stop_requested() else max_retries
last_exception = None
for attempt in range(effective_max_retries + 1): start_time = time.time()
# AGGRESSIVE: Check for cancellation before each attempt response = None
if self._is_stop_requested(): error = None
print(f"Request cancelled during attempt {attempt + 1}: {url}")
return None
# Apply rate limiting with cancellation awareness try:
if not self._wait_with_cancellation_check(): self.total_requests += 1
print(f"Request cancelled during rate limiting: {url}")
return None
# AGGRESSIVE: Final check before making HTTP request request_headers = dict(self.session.headers).copy()
if self._is_stop_requested(): if headers:
print(f"Request cancelled before HTTP call: {url}") request_headers.update(headers)
return None
start_time = time.time() print(f"Making {method} request to: {url}")
response = None
error = None
try: if method.upper() == "GET":
self.total_requests += 1 response = self.session.get(
url,
# Prepare request params=params,
request_headers = dict(self.session.headers).copy() headers=request_headers,
if headers: timeout=self.timeout
request_headers.update(headers)
print(f"Making {method} request to: {url} (attempt {attempt + 1})")
# AGGRESSIVE: Use much shorter timeout if termination is requested
request_timeout = self.timeout
if self._is_stop_requested():
request_timeout = 2 # Max 2 seconds if termination requested
print(f"Stop requested - using short timeout: {request_timeout}s")
# Make request
if method.upper() == "GET":
response = self.session.get(
url,
params=params,
headers=request_headers,
timeout=request_timeout
)
elif method.upper() == "POST":
response = self.session.post(
url,
json=params,
headers=request_headers,
timeout=request_timeout
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
print(f"Response status: {response.status_code}")
response.raise_for_status()
self.successful_requests += 1
# Success - log, cache, and return
duration_ms = (time.time() - start_time) * 1000
self.logger.log_api_request(
provider=self.name,
url=url,
method=method.upper(),
status_code=response.status_code,
response_size=len(response.content),
duration_ms=duration_ms,
error=None,
target_indicator=target_indicator
) )
# Cache the successful response to disk elif method.upper() == "POST":
with open(cache_path, 'w') as f: response = self.session.post(
json.dump({ url,
'status_code': response.status_code, json=params,
'content': response.text, headers=request_headers,
'headers': dict(response.headers) timeout=self.timeout
}, f) )
return response else:
raise ValueError(f"Unsupported HTTP method: {method}")
except requests.exceptions.RequestException as e: print(f"Response status: {response.status_code}")
error = str(e) response.raise_for_status()
self.failed_requests += 1 self.successful_requests += 1
print(f"Request failed (attempt {attempt + 1}): {error}")
last_exception = e duration_ms = (time.time() - start_time) * 1000
self.logger.log_api_request(
# AGGRESSIVE: Immediately abort retries if stop requested provider=self.name,
if self._is_stop_requested(): url=url,
print(f"Stop requested - aborting retries for: {url}") method=method.upper(),
break status_code=response.status_code,
response_size=len(response.content),
# Check if we should retry (but only if stop not requested) duration_ms=duration_ms,
if attempt < effective_max_retries and self._should_retry(e): error=None,
# Use a longer, more respectful backoff for 429 errors target_indicator=target_indicator
if isinstance(e, requests.exceptions.HTTPError) and e.response and e.response.status_code == 429: )
# Start with a 10-second backoff and increase exponentially with open(cache_path, 'w') as f:
backoff_time = 10 * (2 ** attempt) json.dump({
print(f"Rate limit hit. Retrying in {backoff_time} seconds...") 'status_code': response.status_code,
else: 'content': response.text,
backoff_time = min(1.0, (2 ** attempt) * 0.5) # Shorter backoff for other errors 'headers': dict(response.headers)
print(f"Retrying in {backoff_time} seconds...") }, f)
return response
# AGGRESSIVE: Much shorter backoff and more frequent checking
if not self._sleep_with_cancellation_check(backoff_time):
print(f"Stop requested during backoff - aborting: {url}")
return None
continue
else:
break
except Exception as e: except requests.exceptions.RequestException as e:
error = f"Unexpected error: {str(e)}" error = str(e)
self.failed_requests += 1 self.failed_requests += 1
print(f"Unexpected error: {error}") duration_ms = (time.time() - start_time) * 1000
last_exception = e self.logger.log_api_request(
break provider=self.name,
url=url,
# All attempts failed - log and return None method=method.upper(),
duration_ms = (time.time() - start_time) * 1000 status_code=response.status_code if response else None,
self.logger.log_api_request( response_size=len(response.content) if response else None,
provider=self.name, duration_ms=duration_ms,
url=url, error=error,
method=method.upper(), target_indicator=target_indicator
status_code=response.status_code if response else None, )
response_size=len(response.content) if response else None, raise e
duration_ms=duration_ms,
error=error,
target_indicator=target_indicator
)
if error and last_exception:
raise last_exception
return None
def _is_stop_requested(self) -> bool: def _is_stop_requested(self) -> bool:
""" """
@ -346,43 +274,6 @@ class BaseProvider(ABC):
return True return True
return False return False
def _wait_with_cancellation_check(self) -> bool:
"""
Wait for rate limiting while aggressively checking for cancellation.
Returns False if cancelled during wait.
"""
current_time = time.time()
time_since_last = current_time - self.rate_limiter.last_request_time
if time_since_last < self.rate_limiter.min_interval:
sleep_time = self.rate_limiter.min_interval - time_since_last
if not self._sleep_with_cancellation_check(sleep_time):
return False
self.rate_limiter.last_request_time = time.time()
return True
def _sleep_with_cancellation_check(self, sleep_time: float) -> bool:
"""
Sleep for the specified time while aggressively checking for cancellation.
Args:
sleep_time: Time to sleep in seconds
Returns:
bool: True if sleep completed, False if cancelled
"""
sleep_start = time.time()
check_interval = 0.05 # Check every 50ms for aggressive responsiveness
while time.time() - sleep_start < sleep_time:
if self._is_stop_requested():
return False
remaining_time = sleep_time - (time.time() - sleep_start)
time.sleep(min(check_interval, remaining_time))
return True
def set_stop_event(self, stop_event: threading.Event) -> None: def set_stop_event(self, stop_event: threading.Event) -> None:
""" """
Set the stop event for this provider to enable cancellation. Set the stop event for this provider to enable cancellation.
@ -392,28 +283,6 @@ class BaseProvider(ABC):
""" """
self._stop_event = stop_event self._stop_event = stop_event
def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
"""
Determine if a request should be retried based on the exception.
Args:
exception: The request exception that occurred
Returns:
True if the request should be retried
"""
# Retry on connection errors and timeouts
if isinstance(exception, (requests.exceptions.ConnectionError,
requests.exceptions.Timeout)):
return True
if isinstance(exception, requests.exceptions.HTTPError):
if hasattr(exception, 'response') and exception.response:
# Retry on server errors (5xx) AND on rate-limiting errors (429)
return exception.response.status_code >= 500 or exception.response.status_code == 429
return False
def log_relationship_discovery(self, source_node: str, target_node: str, def log_relationship_discovery(self, source_node: str, target_node: str,
relationship_type: str, relationship_type: str,
confidence_score: float, confidence_score: float,

View File

@ -1,8 +1,4 @@
""" # dnsrecon/providers/crtsh_provider.py
Certificate Transparency provider using crt.sh.
Discovers domain relationships through certificate SAN analysis with comprehensive certificate tracking.
Stores certificates as metadata on domain nodes rather than creating certificate nodes.
"""
import json import json
import re import re
@ -182,7 +178,7 @@ class CrtShProvider(BaseProvider):
try: try:
# Query crt.sh for certificates # Query crt.sh for certificates
url = f"{self.base_url}?q={quote(domain)}&output=json" url = f"{self.base_url}?q={quote(domain)}&output=json"
response = self.make_request(url, target_indicator=domain, max_retries=3) response = self.make_request(url, target_indicator=domain)
if not response or response.status_code != 200: if not response or response.status_code != 200:
return [] return []

View File

@ -1,7 +1,4 @@
""" # dnsrecon/providers/shodan_provider.py
Shodan provider for DNSRecon.
Discovers IP relationships and infrastructure context through Shodan API.
"""
import json import json
from typing import List, Dict, Any, Tuple from typing import List, Dict, Any, Tuple