bug fixes, improvements

This commit is contained in:
overcuriousity 2025-09-18 22:39:12 +02:00
parent 4c48917993
commit 95cebbf935
2 changed files with 462 additions and 129 deletions


@ -189,10 +189,14 @@ class Scanner:
"""Initialize all available providers based on session configuration."""
self.providers = []
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}"
try:
print(f"Loading provider module: {module_name}")
module = importlib.import_module(module_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
@ -202,15 +206,41 @@ class Scanner:
provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = provider.get_name()
print(f" Provider: {provider_name}")
print(f" Class: {provider_class.__name__}")
print(f" Config enabled: {self.config.is_provider_enabled(provider_name)}")
print(f" Requires API key: {provider.requires_api_key()}")
if provider.requires_api_key():
api_key = self.config.get_api_key(provider_name)
print(f" API key present: {'Yes' if api_key else 'No'}")
if api_key:
print(f" API key preview: {api_key[:8]}...")
if self.config.is_provider_enabled(provider_name):
is_available = provider.is_available()
print(f" Available: {is_available}")
if is_available:
provider.set_stop_event(self.stop_event)
if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph)
self.providers.append(provider)
print(f" ✓ Added to scanner")
else:
print(f" ✗ Not available - skipped")
else:
print(f" ✗ Disabled in config - skipped")
except Exception as e:
print(f" ERROR loading {module_name}: {e}")
traceback.print_exc()
print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
print(f"Active providers: {[p.get_name() for p in self.providers]}")
print(f"Provider count: {len(self.providers)}")
print("=" * 50)
def _status_logger_thread(self):
"""Periodically prints a clean, formatted scan status to the terminal."""
HEADER = "\033[95m"
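Editor's note: the escape string above is a standard ANSI color code (bright magenta). A one-line illustration, assuming the truncated constant block also defines the usual reset code:

    HEADER = "\033[95m"   # bright magenta
    ENDC = "\033[0m"      # reset attributes (standard ANSI)
    print(f"{HEADER}=== Scan Status ==={ENDC}")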
@ -424,6 +454,9 @@ class Scanner:
is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False)
# FIXED: Filter out correlation provider from initial providers
initial_providers = [p for p in initial_providers if not isinstance(p, CorrelationProvider)]
for provider in initial_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
@ -443,6 +476,8 @@ class Scanner:
consecutive_empty_iterations = 0
max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion
# PHASE 1: Run all non-correlation providers
print(f"\n=== PHASE 1: Running non-correlation providers ===")
while not self._is_stop_requested():
queue_empty = self.task_queue.empty()
with self.processing_lock:
@ -451,23 +486,25 @@ class Scanner:
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
break # Phase 1 complete
time.sleep(0.1)
continue
else:
consecutive_empty_iterations = 0
# Process tasks (same logic as before, but correlations are filtered out)
try:
# Use timeout to avoid blocking indefinitely
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
# Skip correlation tasks during Phase 1
if provider_name == 'correlation':
continue
# Check if task is ready to run
current_time = time.time()
if run_at > current_time:
# Task is not ready yet, re-queue it and continue
self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
time.sleep(min(0.5, run_at - current_time))
continue
except: # Queue is empty or timeout occurred
@ -476,34 +513,32 @@ class Scanner:
self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
# Skip if already processed
task_tuple = (provider_name, target_item, depth)
if task_tuple in processed_tasks:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
# Skip if depth exceeded
if depth > max_depth:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
# Rate limiting with proper time-based deferral
if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
defer_until = time.time() + 60
self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1
continue
# Thread-safe processing state management
with self.processing_lock:
if self._is_stop_requested():
break
# Use provider+target (without depth) for duplicate processing check
processing_key = (provider_name, target_item)
if processing_key in self.currently_processing:
# Already processing this provider+target combination, skip
self.tasks_skipped += 1
self.indicators_completed += 1
continue
@ -519,21 +554,19 @@ class Scanner:
provider = next((p for p in self.providers if p.get_name() == provider_name), None)
if provider and not isinstance(provider, CorrelationProvider):
new_targets, _, success = self._process_provider_task(provider, target_item, depth)
if self._is_stop_requested():
break
if not success:
retry_key = (provider_name, target_item, depth)
self.target_retries[retry_key] += 1
if self.target_retries[retry_key] <= self.config.max_retries_per_target:
retry_count = self.target_retries[retry_key]
backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
retry_at = time.time() + backoff_delay
self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1
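
Editor's note: for reference, this backoff grows as 2 ** retry_count plus up to one second of jitter and is capped at five minutes, so retry 1 waits roughly 2-3s, retry 5 roughly 32-33s, and retry 9 onward sits at the 300s cap. A minimal sketch of the formula:

    import random

    def backoff_delay(retry_count: int, cap: float = 300.0) -> float:
        # Exponential backoff with jitter: 2, 4, 8, ... seconds plus up to
        # 1s of randomness, never exceeding the 5-minute cap.
        return min(cap, (2 ** retry_count) + random.uniform(0, 1))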
@ -545,21 +578,21 @@ class Scanner:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
# Enqueue new targets with proper depth tracking
if not self._is_stop_requested():
for new_target in new_targets:
is_ip_new = _is_valid_ip(new_target)
eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
# FIXED: Filter out correlation providers in Phase 1
eligible_providers_new = [p for p in eligible_providers_new if not isinstance(p, CorrelationProvider)]
for p_new in eligible_providers_new:
p_name_new = p_new.get_name()
new_depth = depth + 1
new_task_tuple = (p_name_new, new_target, new_depth)
if new_task_tuple not in processed_tasks and new_depth <= max_depth:
new_priority = self._get_priority(p_name_new)
# Enqueue new tasks to run immediately
self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
self.total_tasks_ever_enqueued += 1
else:
@ -568,22 +601,25 @@ class Scanner:
self.indicators_completed += 1
finally:
# FIXED: Always clean up processing state
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
# PHASE 2: Run correlations on all discovered nodes
if not self._is_stop_requested():
print(f"\n=== PHASE 2: Running correlation analysis ===")
self._run_correlation_phase(max_depth, processed_tasks)
except Exception as e:
traceback.print_exc()
self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}")
finally:
# Comprehensive cleanup (same as before)
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
# FIXED: Clear any remaining tasks from queue to prevent memory leaks
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
@ -599,7 +635,7 @@ class Scanner:
self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0)
self._update_session_state()
self.logger.log_scan_complete()
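
Editor's note: the whole Phase 1 loop hinges on the shape of the queue entries. `(run_at, priority, payload)` tuples in a `PriorityQueue` sort element-wise, so the earliest `run_at` is served first and deferral, rate limiting, and retries are all expressed by re-enqueueing with a future timestamp. A minimal sketch of that pattern (names are illustrative):

    import queue
    import time

    task_queue = queue.PriorityQueue()

    def enqueue(provider: str, target: str, depth: int, priority: int, delay: float = 0.0) -> None:
        # Tuples compare element-wise, so the smallest run_at is served first;
        # priority only breaks ties between tasks due at the same moment.
        task_queue.put((time.time() + delay, priority, (provider, target, depth)))

    enqueue('dns', 'example.com', 0, priority=1)               # due immediately
    enqueue('shodan', 'example.com', 0, priority=2, delay=60)  # rate-limit deferral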
@ -612,10 +648,120 @@ class Scanner:
finally:
self.executor = None
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
"""
PHASE 2: Run correlation analysis on all discovered nodes.
This ensures correlations run after all other providers have completed.
"""
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
if not correlation_provider:
print("No correlation provider found - skipping correlation phase")
return
# Get all nodes from the graph for correlation analysis
all_nodes = list(self.graph.graph.nodes())
correlation_tasks = []
print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
for node_id in all_nodes:
if self._is_stop_requested():
break
# Determine appropriate depth for correlation (use 0 for simplicity)
correlation_depth = 0
task_tuple = ('correlation', node_id, correlation_depth)
# Don't re-process already processed correlation tasks
if task_tuple not in processed_tasks:
priority = self._get_priority('correlation')
self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
correlation_tasks.append(task_tuple)
self.total_tasks_ever_enqueued += 1
print(f"Enqueued {len(correlation_tasks)} correlation tasks")
# Process correlation tasks
consecutive_empty_iterations = 0
max_empty_iterations = 20 # Shorter timeout for correlation phase
while not self._is_stop_requested() and correlation_tasks:
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
break
time.sleep(0.1)
continue
else:
consecutive_empty_iterations = 0
try:
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
# Only process correlation tasks in this phase
if provider_name != 'correlation':
continue
except:
time.sleep(0.1)
continue
task_tuple = (provider_name, target_item, depth)
# Skip if already processed
if task_tuple in processed_tasks:
self.tasks_skipped += 1
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
continue
with self.processing_lock:
if self._is_stop_requested():
break
processing_key = (provider_name, target_item)
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
try:
self.current_indicator = target_item
self._update_session_state()
if self._is_stop_requested():
break
# Process correlation task
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
if success:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
else:
# For correlations, don't retry - just mark as completed
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
finally:
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
"""
Manages the entire process for a given target and provider: runs the
provider's worker function to fetch the data, then handles the consequences.
FIXED: Don't enqueue correlation tasks during normal processing.
"""
if self._is_stop_requested():
return set(), set(), False
@ -644,14 +790,6 @@ class Scanner:
else:
new_targets.update(discovered)
except Exception as e:
provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e))
@ -690,7 +828,7 @@ class Scanner:
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
"""
Process a unified ProviderResult object to update the graph.
FIXED: Ensure CA and ISP relationships are created even when large entities are formed.
"""
provider_name = provider.get_name()
discovered_targets = set()
@ -698,12 +836,70 @@ class Scanner:
if self._is_stop_requested():
return discovered_targets, False
# Check if this should be a large entity (only counting domain/IP relationships)
eligible_relationship_count = 0
for rel in provider_result.relationships:
# Only count relationships that would go into large entities
if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer':
continue # Don't count CA relationships
if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp':
continue # Don't count ISP relationships
if rel.relationship_type.startswith('corr_'):
continue # Don't count correlation relationships
# Only count domain/IP targets
if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node):
eligible_relationship_count += 1
if eligible_relationship_count > self.config.large_entity_threshold:
# Create large entity but ALSO process special relationships
members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth)
# FIXED: Still process CA, ISP, and correlation relationships directly on the graph
for relationship in provider_result.relationships:
if self._is_stop_requested():
break
source_node = relationship.source_node
target_node = relationship.target_node
# Process special relationship types that should appear directly on graph
should_create_direct_relationship = False
target_type = None
if provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
target_type = NodeType.CA
should_create_direct_relationship = True
elif provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
target_type = NodeType.ISP
should_create_direct_relationship = True
elif relationship.relationship_type.startswith('corr_'):
target_type = NodeType.CORRELATION_OBJECT
should_create_direct_relationship = True
if should_create_direct_relationship:
# Create source and target nodes
source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN
self.graph.add_node(source_node, source_type)
self.graph.add_node(target_node, target_type)
# Add the relationship edge
self.graph.add_edge(
source_node, target_node,
relationship.relationship_type,
relationship.confidence,
provider_name,
relationship.raw_data
)
# Add to discovered targets if it's a valid target for further processing
max_depth_reached = current_depth >= self.max_depth
if not max_depth_reached and (_is_valid_domain(target_node) or _is_valid_ip(target_node)):
discovered_targets.add(target_node)
return members, True
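
Editor's note: the eligibility rule used above for the threshold count (and again in the large-entity builder below) can be read as a single predicate: CA, ISP, and correlation edges always land directly on the graph, and only plain domain/IP targets can be folded into a large entity. A standalone sketch, with the validators passed in to stand for the module's `_is_valid_domain`/`_is_valid_ip` helpers:

    def counts_toward_large_entity(provider_name, relationship_type, target_node,
                                   is_valid_domain, is_valid_ip) -> bool:
        # CA, ISP, and correlation edges always render directly on the graph.
        if provider_name == 'crtsh' and relationship_type == 'crtsh_cert_issuer':
            return False
        if provider_name == 'shodan' and relationship_type == 'shodan_isp':
            return False
        if relationship_type.startswith('corr_'):
            return False
        # Only plain domain/IP targets can be folded into a large entity.
        return is_valid_domain(target_node) or is_valid_ip(target_node)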
# Normal processing (existing logic) when not creating large entity
for i, relationship in enumerate(provider_result.relationships):
if i % 5 == 0 and self._is_stop_requested():
break
@ -711,14 +907,14 @@ class Scanner:
source_node = relationship.source_node
target_node = relationship.target_node
# Determine source node type
source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN
# Determine target node type based on provider and relationship
if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
target_type = NodeType.ISP
elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
target_type = NodeType.CA
elif provider_name == 'correlation':
target_type = NodeType.CORRELATION_OBJECT
elif _is_valid_ip(target_node):
@ -733,7 +929,6 @@ class Scanner:
self.graph.add_node(source_node, source_type)
self.graph.add_node(target_node, target_type, metadata={'max_depth_reached': max_depth_reached})
if self.graph.add_edge(
source_node, target_node,
@ -748,7 +943,7 @@ class Scanner:
if (_is_valid_domain(target_node) or _is_valid_ip(target_node)) and not max_depth_reached:
discovered_targets.add(target_node)
# Process all attributes (existing logic unchanged)
attributes_by_node = defaultdict(list)
for attribute in provider_result.attributes:
attr_dict = {
@ -764,11 +959,9 @@ class Scanner:
# Add attributes to existing nodes OR create new nodes if they don't exist
for node_id, node_attributes_list in attributes_by_node.items():
if not self.graph.graph.has_node(node_id):
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
else:
node_type_val = self.graph.graph.nodes[node_id].get('type', 'domain')
self.graph.add_node(node_id, NodeType(node_type_val), attributes=node_attributes_list)
@ -778,25 +971,50 @@ class Scanner:
provider_result: ProviderResult, current_depth: int) -> Set[str]:
"""
Create a large entity node from a ProviderResult.
FIXED: Only include domain/IP nodes in large entities, exclude CA and other special node types.
"""
entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}"
# FIXED: Filter out CA, ISP, and correlation nodes from large entity inclusion
eligible_targets = []
for rel in provider_result.relationships:
target_node = rel.target_node
# Skip CA nodes (certificate issuers) - they should appear directly on graph
if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer':
continue
# Skip ISP nodes - they should appear directly on graph
if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp':
continue
# Skip correlation objects - they should appear directly on graph
if rel.relationship_type.startswith('corr_'):
continue
# Only include valid domains and IPs in large entities
if _is_valid_domain(target_node) or _is_valid_ip(target_node):
eligible_targets.append(target_node)
# If no eligible targets after filtering, don't create large entity
if not eligible_targets:
return set()
node_type = 'unknown'
if eligible_targets:
if _is_valid_domain(eligible_targets[0]):
node_type = 'domain'
elif _is_valid_ip(eligible_targets[0]):
node_type = 'ip'
# Create individual nodes for eligible targets
for target in eligible_targets:
target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP
self.graph.add_node(target, target_node_type)
attributes_dict = {
'count': len(eligible_targets),
'nodes': eligible_targets, # Only eligible domain/IP nodes
'node_type': node_type,
'source_provider': provider_name,
'discovery_depth': current_depth,
@ -814,18 +1032,21 @@ class Scanner:
"metadata": {}
})
description = f'Large entity created due to {len(eligible_targets)} relationships from {provider_name}'
self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description)
# Use the first eligible relationship for the large entity connection
eligible_rels = [rel for rel in provider_result.relationships if rel.target_node in eligible_targets]
if eligible_rels:
rel_type = eligible_rels[0].relationship_type
self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
{'large_entity_info': f'Contains {len(eligible_targets)} {node_type}s'})
self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")
self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(eligible_targets)} targets from {provider_name}")
return set(eligible_targets)
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
@ -857,6 +1078,7 @@ class Scanner:
def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
"""
Extracts a node from a large entity and re-queues it for scanning.
FIXED: Properly handle different node types during extraction.
"""
if not self.graph.graph.has_node(large_entity_id):
return False
@ -874,6 +1096,7 @@ class Scanner:
if not success:
return False
# Create relationship from source to extracted node
self.graph.add_edge(
source_id=source_node_id,
target_id=node_id_to_extract,
@ -883,8 +1106,13 @@ class Scanner:
raw_data={'context': f'Extracted from large entity {large_entity_id}'}
)
# FIXED: Only queue for further scanning if it's a domain/IP that can be scanned
is_ip = _is_valid_ip(node_id_to_extract)
is_domain = _is_valid_domain(node_id_to_extract)
# Only queue valid domains and IPs for further processing
# Don't queue CA nodes, ISP nodes, etc. as they can't be scanned
if is_domain or is_ip:
large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', [])
discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None)
current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0
@ -907,9 +1135,75 @@ class Scanner:
daemon=True
)
self.scan_thread.start()
else:
# For non-scannable nodes (CA, ISP, etc.), just log that they were extracted
self.logger.logger.info(f"Extracted non-scannable node {node_id_to_extract} of type {self.graph.graph.nodes[node_id_to_extract].get('type', 'unknown')}")
return True
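
Editor's note: the extraction above relies on the attribute model used throughout this file - node attributes are stored as a list of {'name': ..., 'value': ...} dicts, so reading one back (as with `discovery_depth`) is a linear scan with `next()`. A small helper sketch of that lookup:

    def get_attribute(attributes: list, name: str, default=None):
        # Fetch a value from the list-of-dicts attribute format used above.
        attr = next((a for a in attributes if a.get('name') == name), None)
        return attr['value'] if attr else default

    # e.g. get_attribute(large_entity_attributes, 'discovery_depth', 0)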
def _determine_extracted_node_type(self, node_id: str, large_entity_id: str) -> NodeType:
"""
FIXED: Determine the correct node type for a node being extracted from a large entity.
Uses multiple strategies to ensure accurate type detection.
"""
from utils.helpers import _is_valid_ip, _is_valid_domain
# Strategy 1: Check if node already exists in graph with a type
if self.graph.graph.has_node(node_id):
existing_type = self.graph.graph.nodes[node_id].get('type')
if existing_type:
try:
return NodeType(existing_type)
except ValueError:
pass
# Strategy 2: Look for existing relationships to this node to infer type
for source, target, edge_data in self.graph.graph.edges(data=True):
if target == node_id:
rel_type = edge_data.get('relationship_type', '')
provider = edge_data.get('source_provider', '')
# CA nodes from certificate issuer relationships
if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer':
return NodeType.CA
# ISP nodes from Shodan
if provider == 'shodan' and rel_type == 'shodan_isp':
return NodeType.ISP
# Correlation objects
if rel_type.startswith('corr_'):
return NodeType.CORRELATION_OBJECT
if source == node_id:
rel_type = edge_data.get('relationship_type', '')
provider = edge_data.get('source_provider', '')
# Source nodes in cert issuer relationships are CAs
if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer':
return NodeType.CA
# Strategy 3: Format-based detection (fallback)
if _is_valid_ip(node_id):
return NodeType.IP
elif _is_valid_domain(node_id):
return NodeType.DOMAIN
# Strategy 4: Check large entity context
if self.graph.graph.has_node(large_entity_id):
large_entity_data = self.graph.graph.nodes[large_entity_id]
attributes = large_entity_data.get('attributes', [])
node_type_attr = next((attr for attr in attributes if attr.get('name') == 'node_type'), None)
if node_type_attr:
entity_node_type = node_type_attr.get('value', 'domain')
if entity_node_type == 'ip':
return NodeType.IP
else:
return NodeType.DOMAIN
# Final fallback
return NodeType.DOMAIN
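
Editor's note: the strategies above assume `NodeType` is a string-valued enum, which is why `NodeType(existing_type)` can raise `ValueError` for an unknown stored string. A plausible shape, reconstructed only from the members this file references (the actual values are assumptions):

    from enum import Enum

    class NodeType(Enum):
        DOMAIN = 'domain'
        IP = 'ip'
        CA = 'ca'
        ISP = 'isp'
        CORRELATION_OBJECT = 'correlation_object'
        LARGE_ENTITY = 'large_entity'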
def _update_session_state(self) -> None:
"""
Update the scanner state in Redis for GUI updates.


@ -27,26 +27,53 @@ class ShodanProvider(BaseProvider):
)
self.base_url = "https://api.shodan.io"
self.api_key = self.config.get_api_key('shodan')
# FIXED: Don't fail initialization on connection issues - defer to actual usage
self._connection_tested = False
self._connection_works = False
# Initialize cache directory
self.cache_dir = Path('cache') / 'shodan'
self.cache_dir.mkdir(parents=True, exist_ok=True)
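
Editor's note: the per-IP cache file helper is not shown in this diff; a typical shape would key files on the normalized IP, e.g. (hypothetical name and colon handling for IPv6):

    from pathlib import Path

    def _cache_path_for(cache_dir: Path, ip: str) -> Path:
        # Hypothetical sketch: colons in IPv6 addresses are unsafe in file
        # names on some platforms, so replace them before building the path.
        return cache_dir / f"{ip.replace(':', '_')}.json"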
def _check_api_connection(self) -> bool:
"""
FIXED: Lazy connection checking - only test when actually needed.
Don't block provider initialization on network issues.
"""
if self._connection_tested:
return self._connection_works
if not self.api_key:
self._connection_tested = True
self._connection_works = False
return False
try:
print(f"Testing Shodan API connection with key: {self.api_key[:8]}...")
response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5)
self._connection_works = response.status_code == 200
print(f"Shodan API test result: {response.status_code} - {'Success' if self._connection_works else 'Failed'}")
except requests.exceptions.RequestException as e:
print(f"Shodan API connection test failed: {e}")
self._connection_works = False
finally:
self._connection_tested = True
return self._connection_works
def is_available(self) -> bool:
"""Check if Shodan provider is available (has valid API key in this session)."""
return self._is_active and self.api_key is not None and len(self.api_key.strip()) > 0
"""
FIXED: Check if Shodan provider is available based on API key presence.
Don't require successful connection test during initialization.
"""
has_api_key = self.api_key is not None and len(self.api_key.strip()) > 0
if not has_api_key:
return False
# FIXED: Only test connection on first actual usage, not during initialization
return True
def get_name(self) -> str:
"""Return the provider name."""
@ -117,6 +144,7 @@ class ShodanProvider(BaseProvider):
def query_ip(self, ip: str) -> ProviderResult:
"""
Query Shodan for information about an IP address (IPv4 or IPv6), with caching of processed data.
FIXED: Proper 404 handling to prevent unnecessary retries.
Args:
ip: IP address to investigate (IPv4 or IPv6)
@ -127,7 +155,12 @@ class ShodanProvider(BaseProvider):
Raises:
Exception: For temporary failures that should be retried (timeouts, 502/503 errors, connection issues)
"""
if not _is_valid_ip(ip):
return ProviderResult()
# Test connection only when actually making requests
if not self._check_api_connection():
print(f"Shodan API not available for {ip} - API key: {'present' if self.api_key else 'missing'}")
return ProviderResult()
# Normalize IP address for consistent processing
@ -151,26 +184,40 @@ class ShodanProvider(BaseProvider):
response = self.make_request(url, method="GET", params=params, target_indicator=normalized_ip)
if not response:
self.logger.logger.warning(f"Shodan API unreachable for {normalized_ip} - network failure")
if cache_status == "stale":
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to connection failure")
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to network failure")
return self._load_from_cache(cache_file)
else:
raise requests.exceptions.RequestException("No response from Shodan API - should retry")
# FIXED: Treat network failures as "no information" rather than retryable errors
self.logger.logger.info(f"No Shodan data available for {normalized_ip} due to network failure")
result = ProviderResult() # Empty result
network_failure_data = {'shodan_status': 'network_unreachable', 'error': 'API unreachable'}
self._save_to_cache(cache_file, result, network_failure_data)
return result
# FIXED: Handle different status codes more precisely
if response.status_code == 200:
self.logger.logger.debug(f"Shodan returned data for {normalized_ip}")
try:
data = response.json()
result = self._process_shodan_data(normalized_ip, data)
self._save_to_cache(cache_file, result, data)
return result
except json.JSONDecodeError as e:
self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}: {e}")
if cache_status == "stale":
return self._load_from_cache(cache_file)
else:
raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry")
elif response.status_code == 404:
# FIXED: 404 = "no information available" - successful but empty result, don't retry
self.logger.logger.debug(f"Shodan has no information for {normalized_ip} (404)")
result = ProviderResult() # Empty but successful result
# Cache the empty result to avoid repeated queries
empty_data = {'shodan_status': 'no_information', 'status_code': 404}
self._save_to_cache(cache_file, result, empty_data)
return result
elif response.status_code in [401, 403]:
@ -178,7 +225,7 @@ class ShodanProvider(BaseProvider):
self.logger.logger.error(f"Shodan API authentication failed for {normalized_ip} (HTTP {response.status_code})")
return ProviderResult() # Empty result, don't retry
elif response.status_code == 429:
# Rate limiting - should be handled by rate limiter, but if we get here, retry
self.logger.logger.warning(f"Shodan API rate limited for {normalized_ip} (HTTP {response.status_code})")
if cache_status == "stale":
@ -197,13 +244,12 @@ class ShodanProvider(BaseProvider):
raise requests.exceptions.RequestException(f"Shodan API server error (HTTP {response.status_code}) - should retry")
else:
# FIXED: Other HTTP status codes - treat as no information available, don't retry
self.logger.logger.info(f"Shodan returned status {response.status_code} for {normalized_ip} - treating as no information")
result = ProviderResult() # Empty result
no_info_data = {'shodan_status': 'no_information', 'status_code': response.status_code}
self._save_to_cache(cache_file, result, no_info_data)
return result
except requests.exceptions.Timeout:
# Timeout errors - should be retried
@ -223,17 +269,8 @@ class ShodanProvider(BaseProvider):
else:
raise # Re-raise connection error for retry
except json.JSONDecodeError:
# JSON parsing error - treat as temporary failure
self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}")
if cache_status == "stale":
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to JSON parsing error")
@ -241,14 +278,16 @@ class ShodanProvider(BaseProvider):
else:
raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry")
# FIXED: Remove the generic RequestException handler that was causing 404s to retry
# Now only specific exceptions that should be retried are re-raised
except Exception as e:
if cache_status == "stale":
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to unexpected exception")
return self._load_from_cache(cache_file)
# FIXED: Unexpected exceptions - log but treat as no information available, don't retry
self.logger.logger.warning(f"Unexpected exception in Shodan query for {normalized_ip}: {e}")
result = ProviderResult() # Empty result
error_data = {'shodan_status': 'error', 'error': str(e)}
self._save_to_cache(cache_file, result, error_data)
return result
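
Editor's note: the retry ladder above leans on a three-state cache check - fresh results short-circuit the request entirely, while stale ones are served only when the API fails in a retryable way. The status helper itself is not in this hunk; a sketch of the usual mtime-based version (the name and the 12-hour TTL are assumptions):

    import time
    from pathlib import Path

    def cache_status(cache_file: Path, ttl_hours: float = 12.0) -> str:
        # Classify a cache file as 'fresh', 'stale', or 'not_found' by mtime.
        if not cache_file.exists():
            return "not_found"
        age_hours = (time.time() - cache_file.stat().st_mtime) / 3600
        return "fresh" if age_hours < ttl_hours else "stale"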
def _load_from_cache(self, cache_file_path: Path) -> ProviderResult:
"""Load processed Shodan data from a cache file."""