it
This commit is contained in:
@@ -7,10 +7,9 @@ import os
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from core.logger import get_forensic_logger
|
||||
from core.graph_manager import NodeType, RelationshipType
|
||||
from core.graph_manager import RelationshipType
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
@@ -42,36 +41,52 @@ class RateLimiter:
|
||||
class BaseProvider(ABC):
|
||||
"""
|
||||
Abstract base class for all DNSRecon data providers.
|
||||
Provides common functionality and defines the provider interface.
|
||||
Now supports session-specific configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30):
|
||||
def __init__(self, name: str, rate_limit: int = 60, timeout: int = 30, session_config=None):
|
||||
"""
|
||||
Initialize base provider.
|
||||
Initialize base provider with session-specific configuration.
|
||||
|
||||
Args:
|
||||
name: Provider name for logging
|
||||
rate_limit: Requests per minute limit
|
||||
rate_limit: Requests per minute limit (default override)
|
||||
timeout: Request timeout in seconds
|
||||
session_config: Session-specific configuration
|
||||
"""
|
||||
# Use session config if provided, otherwise fall back to global config
|
||||
if session_config is not None:
|
||||
self.config = session_config
|
||||
actual_rate_limit = self.config.get_rate_limit(name)
|
||||
actual_timeout = self.config.default_timeout
|
||||
else:
|
||||
# Fallback to global config for backwards compatibility
|
||||
from config import config as global_config
|
||||
self.config = global_config
|
||||
actual_rate_limit = rate_limit
|
||||
actual_timeout = timeout
|
||||
|
||||
self.name = name
|
||||
self.rate_limiter = RateLimiter(rate_limit)
|
||||
self.timeout = timeout
|
||||
self.rate_limiter = RateLimiter(actual_rate_limit)
|
||||
self.timeout = actual_timeout
|
||||
self._local = threading.local()
|
||||
self.logger = get_forensic_logger()
|
||||
self._stop_event = None
|
||||
|
||||
# Caching configuration
|
||||
self.cache_dir = '.cache'
|
||||
# Caching configuration (per session)
|
||||
self.cache_dir = f'.cache/{id(self.config)}' # Unique cache per session config
|
||||
self.cache_expiry = 12 * 3600 # 12 hours in seconds
|
||||
if not os.path.exists(self.cache_dir):
|
||||
os.makedirs(self.cache_dir)
|
||||
|
||||
# Statistics
|
||||
# Statistics (per provider instance)
|
||||
self.total_requests = 0
|
||||
self.successful_requests = 0
|
||||
self.failed_requests = 0
|
||||
self.total_relationships_found = 0
|
||||
|
||||
print(f"Initialized {name} provider with session-specific config (rate: {actual_rate_limit}/min)")
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
if not hasattr(self._local, 'session'):
|
||||
@@ -118,136 +133,174 @@ class BaseProvider(ABC):
|
||||
pass
|
||||
|
||||
def make_request(self, url: str, method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
target_indicator: str = "",
|
||||
max_retries: int = 3) -> Optional[requests.Response]:
|
||||
"""
|
||||
Make a rate-limited HTTP request with forensic logging and retry logic.
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
target_indicator: str = "",
|
||||
max_retries: int = 3) -> Optional[requests.Response]:
|
||||
"""
|
||||
Make a rate-limited HTTP request with forensic logging and retry logic.
|
||||
Now supports cancellation via stop_event from scanner.
|
||||
"""
|
||||
# Check for cancellation before starting
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
print(f"Request cancelled before start: {url}")
|
||||
return None
|
||||
|
||||
Args:
|
||||
url: Request URL
|
||||
method: HTTP method
|
||||
params: Query parameters
|
||||
headers: Additional headers
|
||||
target_indicator: The indicator being investigated
|
||||
max_retries: Maximum number of retry attempts
|
||||
# Create a unique cache key
|
||||
cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
|
||||
cache_path = os.path.join(self.cache_dir, cache_key)
|
||||
|
||||
Returns:
|
||||
Response object or None if request failed
|
||||
"""
|
||||
# Create a unique cache key
|
||||
cache_key = f"{self.name}_{hash(f'{method}:{url}:{json.dumps(params, sort_keys=True)}')}.json"
|
||||
cache_path = os.path.join(self.cache_dir, cache_key)
|
||||
|
||||
# Check cache
|
||||
if os.path.exists(cache_path):
|
||||
cache_age = time.time() - os.path.getmtime(cache_path)
|
||||
if cache_age < self.cache_expiry:
|
||||
print(f"Returning cached response for: {url}")
|
||||
with open(cache_path, 'r') as f:
|
||||
cached_data = json.load(f)
|
||||
response = requests.Response()
|
||||
response.status_code = cached_data['status_code']
|
||||
response._content = cached_data['content'].encode('utf-8')
|
||||
response.headers = cached_data['headers']
|
||||
return response
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
# Apply rate limiting
|
||||
self.rate_limiter.wait_if_needed()
|
||||
|
||||
start_time = time.time()
|
||||
response = None
|
||||
error = None
|
||||
|
||||
try:
|
||||
self.total_requests += 1
|
||||
|
||||
# Prepare request
|
||||
request_headers = self.session.headers.copy()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
print(f"Making {method} request to: {url} (attempt {attempt + 1})")
|
||||
|
||||
# Make request
|
||||
if method.upper() == "GET":
|
||||
response = self.session.get(
|
||||
url,
|
||||
params=params,
|
||||
headers=request_headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = self.session.post(
|
||||
url,
|
||||
json=params,
|
||||
headers=request_headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
print(f"Response status: {response.status_code}")
|
||||
response.raise_for_status()
|
||||
self.successful_requests += 1
|
||||
|
||||
# Success - log, cache, and return
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self.logger.log_api_request(
|
||||
provider=self.name,
|
||||
url=url,
|
||||
method=method.upper(),
|
||||
status_code=response.status_code,
|
||||
response_size=len(response.content),
|
||||
duration_ms=duration_ms,
|
||||
error=None,
|
||||
target_indicator=target_indicator
|
||||
)
|
||||
# Cache the successful response to disk
|
||||
with open(cache_path, 'w') as f:
|
||||
json.dump({
|
||||
'status_code': response.status_code,
|
||||
'content': response.text,
|
||||
'headers': dict(response.headers)
|
||||
}, f)
|
||||
# Check cache
|
||||
if os.path.exists(cache_path):
|
||||
cache_age = time.time() - os.path.getmtime(cache_path)
|
||||
if cache_age < self.cache_expiry:
|
||||
print(f"Returning cached response for: {url}")
|
||||
with open(cache_path, 'r') as f:
|
||||
cached_data = json.load(f)
|
||||
response = requests.Response()
|
||||
response.status_code = cached_data['status_code']
|
||||
response._content = cached_data['content'].encode('utf-8')
|
||||
response.headers = cached_data['headers']
|
||||
return response
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error = str(e)
|
||||
self.failed_requests += 1
|
||||
print(f"Request failed (attempt {attempt + 1}): {error}")
|
||||
|
||||
# Check if we should retry
|
||||
if attempt < max_retries and self._should_retry(e):
|
||||
backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s
|
||||
print(f"Retrying in {backoff_time} seconds...")
|
||||
time.sleep(backoff_time)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
for attempt in range(max_retries + 1):
|
||||
# Check for cancellation before each attempt
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
print(f"Request cancelled during attempt {attempt + 1}: {url}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
error = f"Unexpected error: {str(e)}"
|
||||
self.failed_requests += 1
|
||||
print(f"Unexpected error: {error}")
|
||||
# Apply rate limiting (but reduce wait time if cancellation is requested)
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
break
|
||||
|
||||
self.rate_limiter.wait_if_needed()
|
||||
|
||||
# Check again after rate limiting
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
print(f"Request cancelled after rate limiting: {url}")
|
||||
return None
|
||||
|
||||
start_time = time.time()
|
||||
response = None
|
||||
error = None
|
||||
|
||||
try:
|
||||
self.total_requests += 1
|
||||
|
||||
# Prepare request
|
||||
request_headers = self.session.headers.copy()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
print(f"Making {method} request to: {url} (attempt {attempt + 1})")
|
||||
|
||||
# Use shorter timeout if termination is requested
|
||||
request_timeout = self.timeout
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
request_timeout = min(5, self.timeout) # Max 5 seconds if termination requested
|
||||
|
||||
# Make request
|
||||
if method.upper() == "GET":
|
||||
response = self.session.get(
|
||||
url,
|
||||
params=params,
|
||||
headers=request_headers,
|
||||
timeout=request_timeout
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = self.session.post(
|
||||
url,
|
||||
json=params,
|
||||
headers=request_headers,
|
||||
timeout=request_timeout
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
print(f"Response status: {response.status_code}")
|
||||
response.raise_for_status()
|
||||
self.successful_requests += 1
|
||||
|
||||
# Success - log, cache, and return
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self.logger.log_api_request(
|
||||
provider=self.name,
|
||||
url=url,
|
||||
method=method.upper(),
|
||||
status_code=response.status_code,
|
||||
response_size=len(response.content),
|
||||
duration_ms=duration_ms,
|
||||
error=None,
|
||||
target_indicator=target_indicator
|
||||
)
|
||||
# Cache the successful response to disk
|
||||
with open(cache_path, 'w') as f:
|
||||
json.dump({
|
||||
'status_code': response.status_code,
|
||||
'content': response.text,
|
||||
'headers': dict(response.headers)
|
||||
}, f)
|
||||
return response
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error = str(e)
|
||||
self.failed_requests += 1
|
||||
print(f"Request failed (attempt {attempt + 1}): {error}")
|
||||
|
||||
# Check for cancellation before retrying
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
print(f"Request cancelled, not retrying: {url}")
|
||||
break
|
||||
|
||||
# Check if we should retry
|
||||
if attempt < max_retries and self._should_retry(e):
|
||||
backoff_time = (2 ** attempt) * 1 # Exponential backoff: 1s, 2s, 4s
|
||||
print(f"Retrying in {backoff_time} seconds...")
|
||||
|
||||
# Shorter backoff if termination is requested
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
backoff_time = min(0.5, backoff_time)
|
||||
|
||||
# Sleep with cancellation checking
|
||||
sleep_start = time.time()
|
||||
while time.time() - sleep_start < backoff_time:
|
||||
if hasattr(self, '_stop_event') and self._stop_event and self._stop_event.is_set():
|
||||
print(f"Request cancelled during backoff: {url}")
|
||||
return None
|
||||
time.sleep(0.1) # Check every 100ms
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
# All attempts failed - log and return None
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self.logger.log_api_request(
|
||||
provider=self.name,
|
||||
url=url,
|
||||
method=method.upper(),
|
||||
status_code=response.status_code if response else None,
|
||||
response_size=len(response.content) if response else None,
|
||||
duration_ms=duration_ms,
|
||||
error=error,
|
||||
target_indicator=target_indicator
|
||||
)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
error = f"Unexpected error: {str(e)}"
|
||||
self.failed_requests += 1
|
||||
print(f"Unexpected error: {error}")
|
||||
break
|
||||
|
||||
# All attempts failed - log and return None
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self.logger.log_api_request(
|
||||
provider=self.name,
|
||||
url=url,
|
||||
method=method.upper(),
|
||||
status_code=response.status_code if response else None,
|
||||
response_size=len(response.content) if response else None,
|
||||
duration_ms=duration_ms,
|
||||
error=error,
|
||||
target_indicator=target_indicator
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def set_stop_event(self, stop_event: threading.Event) -> None:
|
||||
"""
|
||||
Set the stop event for this provider to enable cancellation.
|
||||
|
||||
Args:
|
||||
stop_event: Threading event to signal cancellation
|
||||
"""
|
||||
self._stop_event = stop_event
|
||||
|
||||
def _should_retry(self, exception: requests.exceptions.RequestException) -> bool:
|
||||
"""
|
||||
@@ -314,90 +367,4 @@ class BaseProvider(ABC):
|
||||
'success_rate': (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0,
|
||||
'relationships_found': self.total_relationships_found,
|
||||
'rate_limit': self.rate_limiter.requests_per_minute
|
||||
}
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""Reset provider statistics."""
|
||||
self.total_requests = 0
|
||||
self.successful_requests = 0
|
||||
self.failed_requests = 0
|
||||
self.total_relationships_found = 0
|
||||
|
||||
def _extract_domain_from_url(self, url: str) -> Optional[str]:
|
||||
"""
|
||||
Extract domain from URL.
|
||||
|
||||
Args:
|
||||
url: URL string
|
||||
|
||||
Returns:
|
||||
Domain name or None if extraction fails
|
||||
"""
|
||||
try:
|
||||
# Remove protocol
|
||||
if '://' in url:
|
||||
url = url.split('://', 1)[1]
|
||||
|
||||
# Remove path
|
||||
if '/' in url:
|
||||
url = url.split('/', 1)[0]
|
||||
|
||||
# Remove port
|
||||
if ':' in url:
|
||||
url = url.split(':', 1)[0]
|
||||
|
||||
return url.lower()
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _is_valid_domain(self, domain: str) -> bool:
|
||||
"""
|
||||
Basic domain validation.
|
||||
|
||||
Args:
|
||||
domain: Domain string to validate
|
||||
|
||||
Returns:
|
||||
True if domain appears valid
|
||||
"""
|
||||
if not domain or len(domain) > 253:
|
||||
return False
|
||||
|
||||
# Check for valid characters and structure
|
||||
parts = domain.split('.')
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
|
||||
for part in parts:
|
||||
if not part or len(part) > 63:
|
||||
return False
|
||||
if not part.replace('-', '').replace('_', '').isalnum():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_valid_ip(self, ip: str) -> bool:
|
||||
"""
|
||||
Basic IP address validation.
|
||||
|
||||
Args:
|
||||
ip: IP address string to validate
|
||||
|
||||
Returns:
|
||||
True if IP appears valid
|
||||
"""
|
||||
try:
|
||||
parts = ip.split('.')
|
||||
if len(parts) != 4:
|
||||
return False
|
||||
|
||||
for part in parts:
|
||||
num = int(part)
|
||||
if not 0 <= num <= 255:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Certificate Transparency provider using crt.sh.
|
||||
Discovers domain relationships through certificate SAN analysis.
|
||||
Discovers domain relationships through certificate SAN analysis with comprehensive certificate tracking.
|
||||
Stores certificates as metadata on domain nodes rather than creating certificate nodes.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -10,23 +11,26 @@ from urllib.parse import quote
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from .base_provider import BaseProvider
|
||||
from utils.helpers import _is_valid_domain
|
||||
from core.graph_manager import RelationshipType
|
||||
|
||||
|
||||
class CrtShProvider(BaseProvider):
|
||||
"""
|
||||
Provider for querying crt.sh certificate transparency database.
|
||||
Discovers domain relationships through certificate Subject Alternative Names (SANs).
|
||||
Now uses session-specific configuration and caching.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CrtSh provider with appropriate rate limiting."""
|
||||
def __init__(self, session_config=None):
|
||||
"""Initialize CrtSh provider with session-specific configuration."""
|
||||
super().__init__(
|
||||
name="crtsh",
|
||||
rate_limit=60, # Be respectful to the free service
|
||||
timeout=30
|
||||
rate_limit=60,
|
||||
timeout=15,
|
||||
session_config=session_config
|
||||
)
|
||||
self.base_url = "https://crt.sh/"
|
||||
self._stop_event = None
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
@@ -40,31 +44,128 @@ class CrtShProvider(BaseProvider):
|
||||
"""
|
||||
return True
|
||||
|
||||
def _parse_certificate_date(self, date_string: str) -> datetime:
|
||||
"""
|
||||
Parse certificate date from crt.sh format.
|
||||
|
||||
Args:
|
||||
date_string: Date string from crt.sh API
|
||||
|
||||
Returns:
|
||||
Parsed datetime object in UTC
|
||||
"""
|
||||
if not date_string:
|
||||
raise ValueError("Empty date string")
|
||||
|
||||
try:
|
||||
# Handle various possible formats from crt.sh
|
||||
if date_string.endswith('Z'):
|
||||
return datetime.fromisoformat(date_string[:-1]).replace(tzinfo=timezone.utc)
|
||||
elif '+' in date_string or date_string.endswith('UTC'):
|
||||
# Handle timezone-aware strings
|
||||
date_string = date_string.replace('UTC', '').strip()
|
||||
if '+' in date_string:
|
||||
date_string = date_string.split('+')[0]
|
||||
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Assume UTC if no timezone specified
|
||||
return datetime.fromisoformat(date_string).replace(tzinfo=timezone.utc)
|
||||
except Exception as e:
|
||||
# Fallback: try parsing without timezone info and assume UTC
|
||||
try:
|
||||
return datetime.strptime(date_string[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc)
|
||||
except Exception:
|
||||
raise ValueError(f"Unable to parse date: {date_string}") from e
|
||||
|
||||
def _is_cert_valid(self, cert_data: Dict[str, Any]) -> bool:
|
||||
"""Check if a certificate is currently valid."""
|
||||
"""
|
||||
Check if a certificate is currently valid based on its expiry date.
|
||||
|
||||
Args:
|
||||
cert_data: Certificate data from crt.sh
|
||||
|
||||
Returns:
|
||||
True if certificate is currently valid (not expired)
|
||||
"""
|
||||
try:
|
||||
not_after_str = cert_data.get('not_after')
|
||||
if not_after_str:
|
||||
# Append 'Z' to indicate UTC if it's not present
|
||||
if not not_after_str.endswith('Z'):
|
||||
not_after_str += 'Z'
|
||||
not_after_date = datetime.fromisoformat(not_after_str.replace('Z', '+00:00'))
|
||||
return not_after_date > datetime.now(timezone.utc)
|
||||
except Exception:
|
||||
if not not_after_str:
|
||||
return False
|
||||
|
||||
not_after_date = self._parse_certificate_date(not_after_str)
|
||||
not_before_str = cert_data.get('not_before')
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Check if certificate is within valid date range
|
||||
is_not_expired = not_after_date > now
|
||||
|
||||
if not_before_str:
|
||||
not_before_date = self._parse_certificate_date(not_before_str)
|
||||
is_not_before_valid = not_before_date <= now
|
||||
return is_not_expired and is_not_before_valid
|
||||
|
||||
return is_not_expired
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.debug(f"Certificate validity check failed: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def _extract_certificate_metadata(self, cert_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract comprehensive metadata from certificate data.
|
||||
|
||||
Args:
|
||||
cert_data: Raw certificate data from crt.sh
|
||||
|
||||
Returns:
|
||||
Comprehensive certificate metadata dictionary
|
||||
"""
|
||||
metadata = {
|
||||
'certificate_id': cert_data.get('id'),
|
||||
'serial_number': cert_data.get('serial_number'),
|
||||
'issuer_name': cert_data.get('issuer_name'),
|
||||
'issuer_ca_id': cert_data.get('issuer_ca_id'),
|
||||
'common_name': cert_data.get('common_name'),
|
||||
'not_before': cert_data.get('not_before'),
|
||||
'not_after': cert_data.get('not_after'),
|
||||
'entry_timestamp': cert_data.get('entry_timestamp'),
|
||||
'source': 'crt.sh'
|
||||
}
|
||||
|
||||
# Add computed fields
|
||||
try:
|
||||
if metadata['not_before'] and metadata['not_after']:
|
||||
not_before = self._parse_certificate_date(metadata['not_before'])
|
||||
not_after = self._parse_certificate_date(metadata['not_after'])
|
||||
|
||||
metadata['validity_period_days'] = (not_after - not_before).days
|
||||
metadata['is_currently_valid'] = self._is_cert_valid(cert_data)
|
||||
metadata['expires_soon'] = (not_after - datetime.now(timezone.utc)).days <= 30
|
||||
|
||||
# Add human-readable dates
|
||||
metadata['not_before_formatted'] = not_before.strftime('%Y-%m-%d %H:%M:%S UTC')
|
||||
metadata['not_after_formatted'] = not_after.strftime('%Y-%m-%d %H:%M:%S UTC')
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.debug(f"Error computing certificate metadata: {e}")
|
||||
metadata['is_currently_valid'] = False
|
||||
metadata['expires_soon'] = False
|
||||
|
||||
return metadata
|
||||
|
||||
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
Query crt.sh for certificates containing the domain.
|
||||
|
||||
Args:
|
||||
domain: Domain to investigate
|
||||
|
||||
Returns:
|
||||
List of relationships discovered from certificate analysis
|
||||
Creates domain-to-domain relationships and stores certificate data as metadata.
|
||||
Now supports early termination via stop_event.
|
||||
"""
|
||||
if not self._is_valid_domain(domain):
|
||||
if not _is_valid_domain(domain):
|
||||
return []
|
||||
|
||||
# Check for cancellation before starting
|
||||
if self._stop_event and self._stop_event.is_set():
|
||||
print(f"CrtSh query cancelled before start for domain: {domain}")
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
@@ -72,56 +173,113 @@ class CrtShProvider(BaseProvider):
|
||||
try:
|
||||
# Query crt.sh for certificates
|
||||
url = f"{self.base_url}?q={quote(domain)}&output=json"
|
||||
response = self.make_request(url, target_indicator=domain)
|
||||
response = self.make_request(url, target_indicator=domain, max_retries=1) # Reduce retries for faster cancellation
|
||||
|
||||
if not response or response.status_code != 200:
|
||||
return []
|
||||
|
||||
# Check for cancellation after request
|
||||
if self._stop_event and self._stop_event.is_set():
|
||||
print(f"CrtSh query cancelled after request for domain: {domain}")
|
||||
return []
|
||||
|
||||
certificates = response.json()
|
||||
|
||||
if not certificates:
|
||||
return []
|
||||
|
||||
# Process certificates to extract relationships
|
||||
discovered_subdomains = {}
|
||||
# Check for cancellation before processing
|
||||
if self._stop_event and self._stop_event.is_set():
|
||||
print(f"CrtSh query cancelled before processing for domain: {domain}")
|
||||
return []
|
||||
|
||||
for cert_data in certificates:
|
||||
# Aggregate certificate data by domain
|
||||
domain_certificates = {}
|
||||
all_discovered_domains = set()
|
||||
|
||||
# Process certificates and group by domain (with cancellation checks)
|
||||
for i, cert_data in enumerate(certificates):
|
||||
# Check for cancellation every 10 certificates
|
||||
if i % 10 == 0 and self._stop_event and self._stop_event.is_set():
|
||||
print(f"CrtSh processing cancelled at certificate {i} for domain: {domain}")
|
||||
break
|
||||
|
||||
cert_metadata = self._extract_certificate_metadata(cert_data)
|
||||
cert_domains = self._extract_domains_from_certificate(cert_data)
|
||||
is_valid = self._is_cert_valid(cert_data)
|
||||
|
||||
# Add all domains from this certificate to our tracking
|
||||
for cert_domain in cert_domains:
|
||||
if not _is_valid_domain(cert_domain):
|
||||
continue
|
||||
|
||||
all_discovered_domains.add(cert_domain)
|
||||
|
||||
# Initialize domain certificate list if needed
|
||||
if cert_domain not in domain_certificates:
|
||||
domain_certificates[cert_domain] = []
|
||||
|
||||
# Add this certificate to the domain's certificate list
|
||||
domain_certificates[cert_domain].append(cert_metadata)
|
||||
|
||||
# Final cancellation check before creating relationships
|
||||
if self._stop_event and self._stop_event.is_set():
|
||||
print(f"CrtSh query cancelled before relationship creation for domain: {domain}")
|
||||
return []
|
||||
|
||||
for subdomain in cert_domains:
|
||||
if self._is_valid_domain(subdomain) and subdomain != domain:
|
||||
if subdomain not in discovered_subdomains:
|
||||
discovered_subdomains[subdomain] = {'has_valid_cert': False, 'issuers': set()}
|
||||
|
||||
if is_valid:
|
||||
discovered_subdomains[subdomain]['has_valid_cert'] = True
|
||||
|
||||
issuer = cert_data.get('issuer_name')
|
||||
if issuer:
|
||||
discovered_subdomains[subdomain]['issuers'].add(issuer)
|
||||
# Create relationships from query domain to ALL discovered domains
|
||||
for discovered_domain in all_discovered_domains:
|
||||
if discovered_domain == domain:
|
||||
continue # Skip self-relationships
|
||||
|
||||
# Check for cancellation during relationship creation
|
||||
if self._stop_event and self._stop_event.is_set():
|
||||
print(f"CrtSh relationship creation cancelled for domain: {domain}")
|
||||
break
|
||||
|
||||
# Create relationships from the discovered subdomains
|
||||
for subdomain, data in discovered_subdomains.items():
|
||||
raw_data = {
|
||||
'has_valid_cert': data['has_valid_cert'],
|
||||
'issuers': list(data['issuers']),
|
||||
'source': 'crt.sh'
|
||||
if not _is_valid_domain(discovered_domain):
|
||||
continue
|
||||
|
||||
# Get certificates for both domains
|
||||
query_domain_certs = domain_certificates.get(domain, [])
|
||||
discovered_domain_certs = domain_certificates.get(discovered_domain, [])
|
||||
|
||||
# Find shared certificates (for metadata purposes)
|
||||
shared_certificates = self._find_shared_certificates(query_domain_certs, discovered_domain_certs)
|
||||
|
||||
# Calculate confidence based on relationship type and shared certificates
|
||||
confidence = self._calculate_domain_relationship_confidence(
|
||||
domain, discovered_domain, shared_certificates, all_discovered_domains
|
||||
)
|
||||
|
||||
# Create comprehensive raw data for the relationship
|
||||
relationship_raw_data = {
|
||||
'relationship_type': 'certificate_discovery',
|
||||
'shared_certificates': shared_certificates,
|
||||
'total_shared_certs': len(shared_certificates),
|
||||
'discovery_context': self._determine_relationship_context(discovered_domain, domain),
|
||||
'domain_certificates': {
|
||||
domain: self._summarize_certificates(query_domain_certs),
|
||||
discovered_domain: self._summarize_certificates(discovered_domain_certs)
|
||||
}
|
||||
}
|
||||
|
||||
# Create domain -> domain relationship
|
||||
relationships.append((
|
||||
domain,
|
||||
subdomain,
|
||||
discovered_domain,
|
||||
RelationshipType.SAN_CERTIFICATE,
|
||||
RelationshipType.SAN_CERTIFICATE.default_confidence,
|
||||
raw_data
|
||||
confidence,
|
||||
relationship_raw_data
|
||||
))
|
||||
|
||||
# Log the relationship discovery
|
||||
self.log_relationship_discovery(
|
||||
source_node=domain,
|
||||
target_node=subdomain,
|
||||
target_node=discovered_domain,
|
||||
relationship_type=RelationshipType.SAN_CERTIFICATE,
|
||||
confidence_score=RelationshipType.SAN_CERTIFICATE.default_confidence,
|
||||
raw_data=raw_data,
|
||||
discovery_method="certificate_san_analysis"
|
||||
confidence_score=confidence,
|
||||
raw_data=relationship_raw_data,
|
||||
discovery_method="certificate_transparency_analysis"
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -130,6 +288,165 @@ class CrtShProvider(BaseProvider):
|
||||
self.logger.logger.error(f"Error querying crt.sh for {domain}: {e}")
|
||||
|
||||
return relationships
|
||||
|
||||
def _find_shared_certificates(self, certs1: List[Dict[str, Any]], certs2: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Find certificates that are shared between two domain certificate lists.
|
||||
|
||||
Args:
|
||||
certs1: First domain's certificates
|
||||
certs2: Second domain's certificates
|
||||
|
||||
Returns:
|
||||
List of shared certificate metadata
|
||||
"""
|
||||
shared = []
|
||||
|
||||
# Create a set of certificate IDs from the first list for quick lookup
|
||||
cert1_ids = {cert.get('certificate_id') for cert in certs1 if cert.get('certificate_id')}
|
||||
|
||||
# Find certificates in the second list that match
|
||||
for cert in certs2:
|
||||
if cert.get('certificate_id') in cert1_ids:
|
||||
shared.append(cert)
|
||||
|
||||
return shared
|
||||
|
||||
def _summarize_certificates(self, certificates: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a summary of certificates for a domain.
|
||||
|
||||
Args:
|
||||
certificates: List of certificate metadata
|
||||
|
||||
Returns:
|
||||
Summary dictionary with aggregate statistics
|
||||
"""
|
||||
if not certificates:
|
||||
return {
|
||||
'total_certificates': 0,
|
||||
'valid_certificates': 0,
|
||||
'expired_certificates': 0,
|
||||
'expires_soon_count': 0,
|
||||
'unique_issuers': [],
|
||||
'latest_certificate': None,
|
||||
'has_valid_cert': False
|
||||
}
|
||||
|
||||
valid_count = sum(1 for cert in certificates if cert.get('is_currently_valid'))
|
||||
expired_count = len(certificates) - valid_count
|
||||
expires_soon_count = sum(1 for cert in certificates if cert.get('expires_soon'))
|
||||
|
||||
# Get unique issuers
|
||||
unique_issuers = list(set(cert.get('issuer_name') for cert in certificates if cert.get('issuer_name')))
|
||||
|
||||
# Find the most recent certificate
|
||||
latest_cert = None
|
||||
latest_date = None
|
||||
|
||||
for cert in certificates:
|
||||
try:
|
||||
if cert.get('not_before'):
|
||||
cert_date = self._parse_certificate_date(cert['not_before'])
|
||||
if latest_date is None or cert_date > latest_date:
|
||||
latest_date = cert_date
|
||||
latest_cert = cert
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {
|
||||
'total_certificates': len(certificates),
|
||||
'valid_certificates': valid_count,
|
||||
'expired_certificates': expired_count,
|
||||
'expires_soon_count': expires_soon_count,
|
||||
'unique_issuers': unique_issuers,
|
||||
'latest_certificate': latest_cert,
|
||||
'has_valid_cert': valid_count > 0,
|
||||
'certificate_details': certificates # Full details for forensic analysis
|
||||
}
|
||||
|
||||
def _calculate_domain_relationship_confidence(self, domain1: str, domain2: str,
|
||||
shared_certificates: List[Dict[str, Any]],
|
||||
all_discovered_domains: Set[str]) -> float:
|
||||
"""
|
||||
Calculate confidence score for domain relationship based on various factors.
|
||||
|
||||
Args:
|
||||
domain1: Source domain (query domain)
|
||||
domain2: Target domain (discovered domain)
|
||||
shared_certificates: List of shared certificate metadata
|
||||
all_discovered_domains: All domains discovered in this query
|
||||
|
||||
Returns:
|
||||
Confidence score between 0.0 and 1.0
|
||||
"""
|
||||
base_confidence = RelationshipType.SAN_CERTIFICATE.default_confidence
|
||||
|
||||
# Adjust confidence based on domain relationship context
|
||||
relationship_context = self._determine_relationship_context(domain2, domain1)
|
||||
|
||||
if relationship_context == 'exact_match':
|
||||
context_bonus = 0.0 # This shouldn't happen, but just in case
|
||||
elif relationship_context == 'subdomain':
|
||||
context_bonus = 0.1 # High confidence for subdomains
|
||||
elif relationship_context == 'parent_domain':
|
||||
context_bonus = 0.05 # Medium confidence for parent domains
|
||||
else:
|
||||
context_bonus = 0.0 # Related domains get base confidence
|
||||
|
||||
# Adjust confidence based on shared certificates
|
||||
if shared_certificates:
|
||||
shared_count = len(shared_certificates)
|
||||
if shared_count >= 3:
|
||||
shared_bonus = 0.1
|
||||
elif shared_count >= 2:
|
||||
shared_bonus = 0.05
|
||||
else:
|
||||
shared_bonus = 0.02
|
||||
|
||||
# Additional bonus for valid shared certificates
|
||||
valid_shared = sum(1 for cert in shared_certificates if cert.get('is_currently_valid'))
|
||||
if valid_shared > 0:
|
||||
validity_bonus = 0.05
|
||||
else:
|
||||
validity_bonus = 0.0
|
||||
else:
|
||||
# Even without shared certificates, domains found in the same query have some relationship
|
||||
shared_bonus = 0.0
|
||||
validity_bonus = 0.0
|
||||
|
||||
# Adjust confidence based on certificate issuer reputation (if shared certificates exist)
|
||||
issuer_bonus = 0.0
|
||||
if shared_certificates:
|
||||
for cert in shared_certificates:
|
||||
issuer = cert.get('issuer_name', '').lower()
|
||||
if any(trusted_ca in issuer for trusted_ca in ['let\'s encrypt', 'digicert', 'sectigo', 'globalsign']):
|
||||
issuer_bonus = max(issuer_bonus, 0.03)
|
||||
break
|
||||
|
||||
# Calculate final confidence
|
||||
final_confidence = base_confidence + context_bonus + shared_bonus + validity_bonus + issuer_bonus
|
||||
return max(0.1, min(1.0, final_confidence)) # Clamp between 0.1 and 1.0
|
||||
|
||||
def _determine_relationship_context(self, cert_domain: str, query_domain: str) -> str:
|
||||
"""
|
||||
Determine the context of the relationship between certificate domain and query domain.
|
||||
|
||||
Args:
|
||||
cert_domain: Domain found in certificate
|
||||
query_domain: Original query domain
|
||||
|
||||
Returns:
|
||||
String describing the relationship context
|
||||
"""
|
||||
if cert_domain == query_domain:
|
||||
return 'exact_match'
|
||||
elif cert_domain.endswith(f'.{query_domain}'):
|
||||
return 'subdomain'
|
||||
elif query_domain.endswith(f'.{cert_domain}'):
|
||||
return 'parent_domain'
|
||||
else:
|
||||
return 'related_domain'
|
||||
|
||||
def query_ip(self, ip: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
@@ -143,7 +460,6 @@ class CrtShProvider(BaseProvider):
|
||||
Empty list (crt.sh doesn't support IP-based certificate queries effectively)
|
||||
"""
|
||||
# crt.sh doesn't effectively support IP-based certificate queries
|
||||
# This would require parsing certificate details for IP SANs, which is complex
|
||||
return []
|
||||
|
||||
def _extract_domains_from_certificate(self, cert_data: Dict[str, Any]) -> Set[str]:
|
||||
@@ -162,7 +478,7 @@ class CrtShProvider(BaseProvider):
|
||||
common_name = cert_data.get('common_name', '')
|
||||
if common_name:
|
||||
cleaned_cn = self._clean_domain_name(common_name)
|
||||
if cleaned_cn and self._is_valid_domain(cleaned_cn):
|
||||
if cleaned_cn and _is_valid_domain(cleaned_cn):
|
||||
domains.add(cleaned_cn)
|
||||
|
||||
# Extract from name_value field (contains SANs)
|
||||
@@ -171,7 +487,7 @@ class CrtShProvider(BaseProvider):
|
||||
# Split by newlines and clean each domain
|
||||
for line in name_value.split('\n'):
|
||||
cleaned_domain = self._clean_domain_name(line.strip())
|
||||
if cleaned_domain and self._is_valid_domain(cleaned_domain):
|
||||
if cleaned_domain and _is_valid_domain(cleaned_domain):
|
||||
domains.add(cleaned_domain)
|
||||
|
||||
return domains
|
||||
@@ -215,70 +531,4 @@ class CrtShProvider(BaseProvider):
|
||||
if domain and not domain.startswith(('.', '-')) and not domain.endswith(('.', '-')):
|
||||
return domain
|
||||
|
||||
return ""
|
||||
|
||||
def get_certificate_details(self, certificate_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get detailed information about a specific certificate.
|
||||
|
||||
Args:
|
||||
certificate_id: Certificate ID from crt.sh
|
||||
|
||||
Returns:
|
||||
Dictionary containing certificate details
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}?id={certificate_id}&output=json"
|
||||
response = self.make_request(url, target_indicator=f"cert_{certificate_id}")
|
||||
|
||||
if response and response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.error(f"Error fetching certificate details for {certificate_id}: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
def search_certificates_by_serial(self, serial_number: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for certificates by serial number.
|
||||
|
||||
Args:
|
||||
serial_number: Certificate serial number
|
||||
|
||||
Returns:
|
||||
List of matching certificates
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}?serial={quote(serial_number)}&output=json"
|
||||
response = self.make_request(url, target_indicator=f"serial_{serial_number}")
|
||||
|
||||
if response and response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.error(f"Error searching certificates by serial {serial_number}: {e}")
|
||||
|
||||
return []
|
||||
|
||||
def get_issuer_certificates(self, issuer_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get certificates issued by a specific CA.
|
||||
|
||||
Args:
|
||||
issuer_name: Certificate Authority name
|
||||
|
||||
Returns:
|
||||
List of certificates from the specified issuer
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}?issuer={quote(issuer_name)}&output=json"
|
||||
response = self.make_request(url, target_indicator=f"issuer_{issuer_name}")
|
||||
|
||||
if response and response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.logger.error(f"Error fetching certificates for issuer {issuer_name}: {e}")
|
||||
|
||||
return []
|
||||
return ""
|
||||
@@ -1,25 +1,26 @@
|
||||
# dnsrecon/providers/dns_provider.py
|
||||
|
||||
import socket
|
||||
import dns.resolver
|
||||
import dns.reversename
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from .base_provider import BaseProvider
|
||||
from core.graph_manager import RelationshipType, NodeType
|
||||
from utils.helpers import _is_valid_ip, _is_valid_domain
|
||||
from core.graph_manager import RelationshipType
|
||||
|
||||
|
||||
class DNSProvider(BaseProvider):
|
||||
"""
|
||||
Provider for standard DNS resolution and reverse DNS lookups.
|
||||
Discovers domain-to-IP and IP-to-domain relationships through DNS records.
|
||||
Now uses session-specific configuration.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize DNS provider with appropriate rate limiting."""
|
||||
def __init__(self, session_config=None):
|
||||
"""Initialize DNS provider with session-specific configuration."""
|
||||
super().__init__(
|
||||
name="dns",
|
||||
rate_limit=100, # DNS queries can be faster
|
||||
timeout=10
|
||||
rate_limit=100,
|
||||
timeout=10,
|
||||
session_config=session_config
|
||||
)
|
||||
|
||||
# Configure DNS resolver
|
||||
@@ -45,7 +46,7 @@ class DNSProvider(BaseProvider):
|
||||
Returns:
|
||||
List of relationships discovered from DNS analysis
|
||||
"""
|
||||
if not self._is_valid_domain(domain):
|
||||
if not _is_valid_domain(domain):
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
@@ -66,7 +67,7 @@ class DNSProvider(BaseProvider):
|
||||
Returns:
|
||||
List of relationships discovered from reverse DNS
|
||||
"""
|
||||
if not self._is_valid_ip(ip):
|
||||
if not _is_valid_ip(ip):
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
@@ -81,7 +82,7 @@ class DNSProvider(BaseProvider):
|
||||
for ptr_record in response:
|
||||
hostname = str(ptr_record).rstrip('.')
|
||||
|
||||
if self._is_valid_domain(hostname):
|
||||
if _is_valid_domain(hostname):
|
||||
raw_data = {
|
||||
'query_type': 'PTR',
|
||||
'ip_address': ip,
|
||||
|
||||
@@ -4,38 +4,37 @@ Discovers IP relationships and infrastructure context through Shodan API.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from urllib.parse import quote
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from .base_provider import BaseProvider
|
||||
from utils.helpers import _is_valid_ip, _is_valid_domain
|
||||
from core.graph_manager import RelationshipType
|
||||
from config import config
|
||||
|
||||
|
||||
class ShodanProvider(BaseProvider):
|
||||
"""
|
||||
Provider for querying Shodan API for IP address and hostname information.
|
||||
Requires valid API key and respects Shodan's rate limits.
|
||||
Now uses session-specific API keys.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Shodan provider with appropriate rate limiting."""
|
||||
def __init__(self, session_config=None):
|
||||
"""Initialize Shodan provider with session-specific configuration."""
|
||||
super().__init__(
|
||||
name="shodan",
|
||||
rate_limit=60, # Shodan API has various rate limits depending on plan
|
||||
timeout=30
|
||||
rate_limit=60,
|
||||
timeout=30,
|
||||
session_config=session_config
|
||||
)
|
||||
self.base_url = "https://api.shodan.io"
|
||||
self.api_key = config.get_api_key('shodan')
|
||||
self.api_key = self.config.get_api_key('shodan')
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Shodan provider is available (has valid API key in this session)."""
|
||||
return self.api_key is not None and len(self.api_key.strip()) > 0
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "shodan"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Check if Shodan provider is available (has valid API key).
|
||||
"""
|
||||
return self.api_key is not None and len(self.api_key.strip()) > 0
|
||||
|
||||
|
||||
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
@@ -48,7 +47,7 @@ class ShodanProvider(BaseProvider):
|
||||
Returns:
|
||||
List of relationships discovered from Shodan data
|
||||
"""
|
||||
if not self._is_valid_domain(domain) or not self.is_available():
|
||||
if not _is_valid_domain(domain) or not self.is_available():
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
@@ -109,7 +108,7 @@ class ShodanProvider(BaseProvider):
|
||||
|
||||
# Also create relationships to other hostnames on the same IP
|
||||
for hostname in hostnames:
|
||||
if hostname != domain and self._is_valid_domain(hostname):
|
||||
if hostname != domain and _is_valid_domain(hostname):
|
||||
hostname_raw_data = {
|
||||
'shared_ip': ip_address,
|
||||
'all_hostnames': hostnames,
|
||||
@@ -150,7 +149,7 @@ class ShodanProvider(BaseProvider):
|
||||
Returns:
|
||||
List of relationships discovered from Shodan IP data
|
||||
"""
|
||||
if not self._is_valid_ip(ip) or not self.is_available():
|
||||
if not _is_valid_ip(ip) or not self.is_available():
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
@@ -170,7 +169,7 @@ class ShodanProvider(BaseProvider):
|
||||
# Extract hostname relationships
|
||||
hostnames = data.get('hostnames', [])
|
||||
for hostname in hostnames:
|
||||
if self._is_valid_domain(hostname):
|
||||
if _is_valid_domain(hostname):
|
||||
raw_data = {
|
||||
'ip_address': ip,
|
||||
'hostname': hostname,
|
||||
@@ -280,7 +279,7 @@ class ShodanProvider(BaseProvider):
|
||||
Returns:
|
||||
List of service information dictionaries
|
||||
"""
|
||||
if not self._is_valid_ip(ip) or not self.is_available():
|
||||
if not _is_valid_ip(ip) or not self.is_available():
|
||||
return []
|
||||
|
||||
try:
|
||||
|
||||
@@ -4,38 +4,37 @@ Discovers domain relationships through passive DNS and URL analysis.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from .base_provider import BaseProvider
|
||||
from utils.helpers import _is_valid_ip, _is_valid_domain
|
||||
from core.graph_manager import RelationshipType
|
||||
from config import config
|
||||
|
||||
|
||||
class VirusTotalProvider(BaseProvider):
|
||||
"""
|
||||
Provider for querying VirusTotal API for passive DNS and domain reputation data.
|
||||
Requires valid API key and strictly respects free tier rate limits.
|
||||
Now uses session-specific API keys and rate limits.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize VirusTotal provider with strict rate limiting for free tier."""
|
||||
def __init__(self, session_config=None):
|
||||
"""Initialize VirusTotal provider with session-specific configuration."""
|
||||
super().__init__(
|
||||
name="virustotal",
|
||||
rate_limit=4, # Free tier: 4 requests per minute
|
||||
timeout=30
|
||||
timeout=30,
|
||||
session_config=session_config
|
||||
)
|
||||
self.base_url = "https://www.virustotal.com/vtapi/v2"
|
||||
self.api_key = config.get_api_key('virustotal')
|
||||
self.api_key = self.config.get_api_key('virustotal')
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if VirusTotal provider is available (has valid API key in this session)."""
|
||||
return self.api_key is not None and len(self.api_key.strip()) > 0
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "virustotal"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Check if VirusTotal provider is available (has valid API key).
|
||||
"""
|
||||
return self.api_key is not None and len(self.api_key.strip()) > 0
|
||||
|
||||
def query_domain(self, domain: str) -> List[Tuple[str, str, RelationshipType, float, Dict[str, Any]]]:
|
||||
"""
|
||||
Query VirusTotal for domain information including passive DNS.
|
||||
@@ -46,7 +45,7 @@ class VirusTotalProvider(BaseProvider):
|
||||
Returns:
|
||||
List of relationships discovered from VirusTotal data
|
||||
"""
|
||||
if not self._is_valid_domain(domain) or not self.is_available():
|
||||
if not _is_valid_domain(domain) or not self.is_available():
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
@@ -71,7 +70,7 @@ class VirusTotalProvider(BaseProvider):
|
||||
Returns:
|
||||
List of relationships discovered from VirusTotal IP data
|
||||
"""
|
||||
if not self._is_valid_ip(ip) or not self.is_available():
|
||||
if not _is_valid_ip(ip) or not self.is_available():
|
||||
return []
|
||||
|
||||
relationships = []
|
||||
@@ -114,7 +113,7 @@ class VirusTotalProvider(BaseProvider):
|
||||
ip_address = resolution.get('ip_address')
|
||||
last_resolved = resolution.get('last_resolved')
|
||||
|
||||
if ip_address and self._is_valid_ip(ip_address):
|
||||
if ip_address and _is_valid_ip(ip_address):
|
||||
raw_data = {
|
||||
'domain': domain,
|
||||
'ip_address': ip_address,
|
||||
@@ -142,7 +141,7 @@ class VirusTotalProvider(BaseProvider):
|
||||
# Extract subdomains
|
||||
subdomains = data.get('subdomains', [])
|
||||
for subdomain in subdomains:
|
||||
if subdomain != domain and self._is_valid_domain(subdomain):
|
||||
if subdomain != domain and _is_valid_domain(subdomain):
|
||||
raw_data = {
|
||||
'parent_domain': domain,
|
||||
'subdomain': subdomain,
|
||||
@@ -200,7 +199,7 @@ class VirusTotalProvider(BaseProvider):
|
||||
hostname = resolution.get('hostname')
|
||||
last_resolved = resolution.get('last_resolved')
|
||||
|
||||
if hostname and self._is_valid_domain(hostname):
|
||||
if hostname and _is_valid_domain(hostname):
|
||||
raw_data = {
|
||||
'ip_address': ip,
|
||||
'hostname': hostname,
|
||||
@@ -254,7 +253,7 @@ class VirusTotalProvider(BaseProvider):
|
||||
Returns:
|
||||
Dictionary containing reputation data
|
||||
"""
|
||||
if not self._is_valid_domain(domain) or not self.is_available():
|
||||
if not _is_valid_domain(domain) or not self.is_available():
|
||||
return {}
|
||||
|
||||
try:
|
||||
@@ -293,7 +292,7 @@ class VirusTotalProvider(BaseProvider):
|
||||
Returns:
|
||||
Dictionary containing reputation data
|
||||
"""
|
||||
if not self._is_valid_ip(ip) or not self.is_available():
|
||||
if not _is_valid_ip(ip) or not self.is_available():
|
||||
return {}
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user