bug fixes, improvements

This commit is contained in:
overcuriousity 2025-09-18 22:39:12 +02:00
parent 4c48917993
commit 95cebbf935
2 changed files with 462 additions and 129 deletions

View File

@ -189,10 +189,14 @@ class Scanner:
"""Initialize all available providers based on session configuration.""" """Initialize all available providers based on session configuration."""
self.providers = [] self.providers = []
provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
for filename in os.listdir(provider_dir): for filename in os.listdir(provider_dir):
if filename.endswith('_provider.py') and not filename.startswith('base'): if filename.endswith('_provider.py') and not filename.startswith('base'):
module_name = f"providers.{filename[:-3]}" module_name = f"providers.{filename[:-3]}"
try: try:
print(f"Loading provider module: {module_name}")
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
for attribute_name in dir(module): for attribute_name in dir(module):
attribute = getattr(module, attribute_name) attribute = getattr(module, attribute_name)
@ -202,15 +206,41 @@ class Scanner:
provider = provider_class(name=attribute_name, session_config=self.config) provider = provider_class(name=attribute_name, session_config=self.config)
provider_name = provider.get_name() provider_name = provider.get_name()
print(f" Provider: {provider_name}")
print(f" Class: {provider_class.__name__}")
print(f" Config enabled: {self.config.is_provider_enabled(provider_name)}")
print(f" Requires API key: {provider.requires_api_key()}")
if provider.requires_api_key():
api_key = self.config.get_api_key(provider_name)
print(f" API key present: {'Yes' if api_key else 'No'}")
if api_key:
print(f" API key preview: {api_key[:8]}...")
if self.config.is_provider_enabled(provider_name): if self.config.is_provider_enabled(provider_name):
if provider.is_available(): is_available = provider.is_available()
print(f" Available: {is_available}")
if is_available:
provider.set_stop_event(self.stop_event) provider.set_stop_event(self.stop_event)
if isinstance(provider, CorrelationProvider): if isinstance(provider, CorrelationProvider):
provider.set_graph_manager(self.graph) provider.set_graph_manager(self.graph)
self.providers.append(provider) self.providers.append(provider)
print(f" ✓ Added to scanner")
else:
print(f" ✗ Not available - skipped")
else:
print(f" ✗ Disabled in config - skipped")
except Exception as e: except Exception as e:
print(f" ERROR loading {module_name}: {e}")
traceback.print_exc() traceback.print_exc()
print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
print(f"Active providers: {[p.get_name() for p in self.providers]}")
print(f"Provider count: {len(self.providers)}")
print("=" * 50)
def _status_logger_thread(self): def _status_logger_thread(self):
"""Periodically prints a clean, formatted scan status to the terminal.""" """Periodically prints a clean, formatted scan status to the terminal."""
HEADER = "\033[95m" HEADER = "\033[95m"
@ -424,6 +454,9 @@ class Scanner:
is_ip = _is_valid_ip(target) is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False) initial_providers = self._get_eligible_providers(target, is_ip, False)
# FIXED: Filter out correlation provider from initial providers
initial_providers = [p for p in initial_providers if not isinstance(p, CorrelationProvider)]
for provider in initial_providers: for provider in initial_providers:
provider_name = provider.get_name() provider_name = provider.get_name()
priority = self._get_priority(provider_name) priority = self._get_priority(provider_name)
@ -443,6 +476,8 @@ class Scanner:
consecutive_empty_iterations = 0 consecutive_empty_iterations = 0
max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion
# PHASE 1: Run all non-correlation providers
print(f"\n=== PHASE 1: Running non-correlation providers ===")
while not self._is_stop_requested(): while not self._is_stop_requested():
queue_empty = self.task_queue.empty() queue_empty = self.task_queue.empty()
with self.processing_lock: with self.processing_lock:
@ -451,23 +486,25 @@ class Scanner:
if queue_empty and no_active_processing: if queue_empty and no_active_processing:
consecutive_empty_iterations += 1 consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations: if consecutive_empty_iterations >= max_empty_iterations:
break # Scan is complete break # Phase 1 complete
time.sleep(0.1) time.sleep(0.1)
continue continue
else: else:
consecutive_empty_iterations = 0 consecutive_empty_iterations = 0
# FIXED: Safe task retrieval without race conditions # Process tasks (same logic as before, but correlations are filtered out)
try: try:
# Use timeout to avoid blocking indefinitely
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1) run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
# FIXED: Check if task is ready to run # Skip correlation tasks during Phase 1
if provider_name == 'correlation':
continue
# Check if task is ready to run
current_time = time.time() current_time = time.time()
if run_at > current_time: if run_at > current_time:
# Task is not ready yet, re-queue it and continue
self.task_queue.put((run_at, priority, (provider_name, target_item, depth))) self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
time.sleep(min(0.5, run_at - current_time)) # Sleep until closer to run time time.sleep(min(0.5, run_at - current_time))
continue continue
except: # Queue is empty or timeout occurred except: # Queue is empty or timeout occurred
@ -476,34 +513,32 @@ class Scanner:
self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth)) self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
# FIXED: Include depth in processed tasks to avoid incorrect skipping # Skip if already processed
task_tuple = (provider_name, target_item, depth) task_tuple = (provider_name, target_item, depth)
if task_tuple in processed_tasks: if task_tuple in processed_tasks:
self.tasks_skipped += 1 self.tasks_skipped += 1
self.indicators_completed += 1 self.indicators_completed += 1
continue continue
# FIXED: Proper depth checking # Skip if depth exceeded
if depth > max_depth: if depth > max_depth:
self.tasks_skipped += 1 self.tasks_skipped += 1
self.indicators_completed += 1 self.indicators_completed += 1
continue continue
# FIXED: Rate limiting with proper time-based deferral # Rate limiting with proper time-based deferral
if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60): if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
defer_until = time.time() + 60 # Defer for 60 seconds defer_until = time.time() + 60
self.task_queue.put((defer_until, priority, (provider_name, target_item, depth))) self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1 self.tasks_re_enqueued += 1
continue continue
# FIXED: Thread-safe processing state management # Thread-safe processing state management
with self.processing_lock: with self.processing_lock:
if self._is_stop_requested(): if self._is_stop_requested():
break break
# Use provider+target (without depth) for duplicate processing check
processing_key = (provider_name, target_item) processing_key = (provider_name, target_item)
if processing_key in self.currently_processing: if processing_key in self.currently_processing:
# Already processing this provider+target combination, skip
self.tasks_skipped += 1 self.tasks_skipped += 1
self.indicators_completed += 1 self.indicators_completed += 1
continue continue
@ -519,21 +554,19 @@ class Scanner:
provider = next((p for p in self.providers if p.get_name() == provider_name), None) provider = next((p for p in self.providers if p.get_name() == provider_name), None)
if provider: if provider and not isinstance(provider, CorrelationProvider):
new_targets, _, success = self._process_provider_task(provider, target_item, depth) new_targets, _, success = self._process_provider_task(provider, target_item, depth)
if self._is_stop_requested(): if self._is_stop_requested():
break break
if not success: if not success:
# FIXED: Use depth-aware retry key
retry_key = (provider_name, target_item, depth) retry_key = (provider_name, target_item, depth)
self.target_retries[retry_key] += 1 self.target_retries[retry_key] += 1
if self.target_retries[retry_key] <= self.config.max_retries_per_target: if self.target_retries[retry_key] <= self.config.max_retries_per_target:
# FIXED: Exponential backoff with jitter for retries
retry_count = self.target_retries[retry_key] retry_count = self.target_retries[retry_key]
backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1)) # Cap at 5 minutes backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
retry_at = time.time() + backoff_delay retry_at = time.time() + backoff_delay
self.task_queue.put((retry_at, priority, (provider_name, target_item, depth))) self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1 self.tasks_re_enqueued += 1
@ -545,21 +578,21 @@ class Scanner:
processed_tasks.add(task_tuple) processed_tasks.add(task_tuple)
self.indicators_completed += 1 self.indicators_completed += 1
# FIXED: Enqueue new targets with proper depth tracking # Enqueue new targets with proper depth tracking
if not self._is_stop_requested(): if not self._is_stop_requested():
for new_target in new_targets: for new_target in new_targets:
is_ip_new = _is_valid_ip(new_target) is_ip_new = _is_valid_ip(new_target)
eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False) eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
# FIXED: Filter out correlation providers in Phase 1
eligible_providers_new = [p for p in eligible_providers_new if not isinstance(p, CorrelationProvider)]
for p_new in eligible_providers_new: for p_new in eligible_providers_new:
p_name_new = p_new.get_name() p_name_new = p_new.get_name()
new_depth = depth + 1 # Always increment depth for discovered targets new_depth = depth + 1
new_task_tuple = (p_name_new, new_target, new_depth) new_task_tuple = (p_name_new, new_target, new_depth)
# FIXED: Don't re-enqueue already processed tasks
if new_task_tuple not in processed_tasks and new_depth <= max_depth: if new_task_tuple not in processed_tasks and new_depth <= max_depth:
new_priority = self._get_priority(p_name_new) new_priority = self._get_priority(p_name_new)
# Enqueue new tasks to run immediately
self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth))) self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
self.total_tasks_ever_enqueued += 1 self.total_tasks_ever_enqueued += 1
else: else:
@ -568,22 +601,25 @@ class Scanner:
self.indicators_completed += 1 self.indicators_completed += 1
finally: finally:
# FIXED: Always clean up processing state
with self.processing_lock: with self.processing_lock:
processing_key = (provider_name, target_item) processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key) self.currently_processing.discard(processing_key)
# PHASE 2: Run correlations on all discovered nodes
if not self._is_stop_requested():
print(f"\n=== PHASE 2: Running correlation analysis ===")
self._run_correlation_phase(max_depth, processed_tasks)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}") self.logger.logger.error(f"Scan failed: {e}")
finally: finally:
# FIXED: Comprehensive cleanup # Comprehensive cleanup (same as before)
with self.processing_lock: with self.processing_lock:
self.currently_processing.clear() self.currently_processing.clear()
self.currently_processing_display = [] self.currently_processing_display = []
# FIXED: Clear any remaining tasks from queue to prevent memory leaks
while not self.task_queue.empty(): while not self.task_queue.empty():
try: try:
self.task_queue.get_nowait() self.task_queue.get_nowait()
@ -599,7 +635,7 @@ class Scanner:
self.status_logger_stop_event.set() self.status_logger_stop_event.set()
if self.status_logger_thread and self.status_logger_thread.is_alive(): if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0) # Don't wait forever self.status_logger_thread.join(timeout=2.0)
self._update_session_state() self._update_session_state()
self.logger.log_scan_complete() self.logger.log_scan_complete()
@ -612,10 +648,120 @@ class Scanner:
finally: finally:
self.executor = None self.executor = None
def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
"""
PHASE 2: Run correlation analysis on all discovered nodes.
This ensures correlations run after all other providers have completed.
"""
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
if not correlation_provider:
print("No correlation provider found - skipping correlation phase")
return
# Get all nodes from the graph for correlation analysis
all_nodes = list(self.graph.graph.nodes())
correlation_tasks = []
print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
for node_id in all_nodes:
if self._is_stop_requested():
break
# Determine appropriate depth for correlation (use 0 for simplicity)
correlation_depth = 0
task_tuple = ('correlation', node_id, correlation_depth)
# Don't re-process already processed correlation tasks
if task_tuple not in processed_tasks:
priority = self._get_priority('correlation')
self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
correlation_tasks.append(task_tuple)
self.total_tasks_ever_enqueued += 1
print(f"Enqueued {len(correlation_tasks)} correlation tasks")
# Process correlation tasks
consecutive_empty_iterations = 0
max_empty_iterations = 20 # Shorter timeout for correlation phase
while not self._is_stop_requested() and correlation_tasks:
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
break
time.sleep(0.1)
continue
else:
consecutive_empty_iterations = 0
try:
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
# Only process correlation tasks in this phase
if provider_name != 'correlation':
continue
except:
time.sleep(0.1)
continue
task_tuple = (provider_name, target_item, depth)
# Skip if already processed
if task_tuple in processed_tasks:
self.tasks_skipped += 1
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
continue
with self.processing_lock:
if self._is_stop_requested():
break
processing_key = (provider_name, target_item)
if processing_key in self.currently_processing:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
try:
self.current_indicator = target_item
self._update_session_state()
if self._is_stop_requested():
break
# Process correlation task
new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
if success:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
else:
# For correlations, don't retry - just mark as completed
self.indicators_completed += 1
if task_tuple in correlation_tasks:
correlation_tasks.remove(task_tuple)
finally:
with self.processing_lock:
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
""" """
Manages the entire process for a given target and provider. Manages the entire process for a given target and provider.
It uses the "worker" function to get the data and then manages the consequences. FIXED: Don't enqueue correlation tasks during normal processing.
""" """
if self._is_stop_requested(): if self._is_stop_requested():
return set(), set(), False return set(), set(), False
@ -644,14 +790,6 @@ class Scanner:
else: else:
new_targets.update(discovered) new_targets.update(discovered)
# After processing a provider, queue a correlation task for the target
correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
if correlation_provider and not isinstance(provider, CorrelationProvider):
priority = self._get_priority(correlation_provider.get_name())
self.task_queue.put((time.time(), priority, (correlation_provider.get_name(), target, depth)))
# FIXED: Increment total tasks when a correlation task is enqueued
self.total_tasks_ever_enqueued += 1
except Exception as e: except Exception as e:
provider_successful = False provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e)) self._log_provider_error(target, provider.get_name(), str(e))
@ -690,7 +828,7 @@ class Scanner:
provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]: provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
""" """
Process a unified ProviderResult object to update the graph. Process a unified ProviderResult object to update the graph.
VERIFIED: Proper ISP and CA node type assignment. FIXED: Ensure CA and ISP relationships are created even when large entities are formed.
""" """
provider_name = provider.get_name() provider_name = provider.get_name()
discovered_targets = set() discovered_targets = set()
@ -698,12 +836,70 @@ class Scanner:
if self._is_stop_requested(): if self._is_stop_requested():
return discovered_targets, False return discovered_targets, False
# Check if this should be a large entity # Check if this should be a large entity (only counting domain/IP relationships)
if provider_result.get_relationship_count() > self.config.large_entity_threshold: eligible_relationship_count = 0
for rel in provider_result.relationships:
# Only count relationships that would go into large entities
if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer':
continue # Don't count CA relationships
if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp':
continue # Don't count ISP relationships
if rel.relationship_type.startswith('corr_'):
continue # Don't count correlation relationships
# Only count domain/IP targets
if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node):
eligible_relationship_count += 1
if eligible_relationship_count > self.config.large_entity_threshold:
# Create large entity but ALSO process special relationships
members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth) members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth)
# FIXED: Still process CA, ISP, and correlation relationships directly on the graph
for relationship in provider_result.relationships:
if self._is_stop_requested():
break
source_node = relationship.source_node
target_node = relationship.target_node
# Process special relationship types that should appear directly on graph
should_create_direct_relationship = False
target_type = None
if provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
target_type = NodeType.CA
should_create_direct_relationship = True
elif provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
target_type = NodeType.ISP
should_create_direct_relationship = True
elif relationship.relationship_type.startswith('corr_'):
target_type = NodeType.CORRELATION_OBJECT
should_create_direct_relationship = True
if should_create_direct_relationship:
# Create source and target nodes
source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN
self.graph.add_node(source_node, source_type)
self.graph.add_node(target_node, target_type)
# Add the relationship edge
self.graph.add_edge(
source_node, target_node,
relationship.relationship_type,
relationship.confidence,
provider_name,
relationship.raw_data
)
# Add to discovered targets if it's a valid target for further processing
max_depth_reached = current_depth >= self.max_depth
if not max_depth_reached and (_is_valid_domain(target_node) or _is_valid_ip(target_node)):
discovered_targets.add(target_node)
return members, True return members, True
# Process relationships and create nodes with proper types # Normal processing (existing logic) when not creating large entity
for i, relationship in enumerate(provider_result.relationships): for i, relationship in enumerate(provider_result.relationships):
if i % 5 == 0 and self._is_stop_requested(): if i % 5 == 0 and self._is_stop_requested():
break break
@ -711,14 +907,14 @@ class Scanner:
source_node = relationship.source_node source_node = relationship.source_node
target_node = relationship.target_node target_node = relationship.target_node
# VERIFIED: Determine source node type # Determine source node type
source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN
# VERIFIED: Determine target node type based on provider and relationship # Determine target node type based on provider and relationship
if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp': if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
target_type = NodeType.ISP # ISP node for Shodan organization data target_type = NodeType.ISP
elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer': elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
target_type = NodeType.CA # CA node for certificate issuers target_type = NodeType.CA
elif provider_name == 'correlation': elif provider_name == 'correlation':
target_type = NodeType.CORRELATION_OBJECT target_type = NodeType.CORRELATION_OBJECT
elif _is_valid_ip(target_node): elif _is_valid_ip(target_node):
@ -733,7 +929,6 @@ class Scanner:
self.graph.add_node(source_node, source_type) self.graph.add_node(source_node, source_type)
self.graph.add_node(target_node, target_type, metadata={'max_depth_reached': max_depth_reached}) self.graph.add_node(target_node, target_type, metadata={'max_depth_reached': max_depth_reached})
# Add the relationship edge # Add the relationship edge
if self.graph.add_edge( if self.graph.add_edge(
source_node, target_node, source_node, target_node,
@ -748,7 +943,7 @@ class Scanner:
if (_is_valid_domain(target_node) or _is_valid_ip(target_node)) and not max_depth_reached: if (_is_valid_domain(target_node) or _is_valid_ip(target_node)) and not max_depth_reached:
discovered_targets.add(target_node) discovered_targets.add(target_node)
# Process all attributes, grouping by target node # Process all attributes (existing logic unchanged)
attributes_by_node = defaultdict(list) attributes_by_node = defaultdict(list)
for attribute in provider_result.attributes: for attribute in provider_result.attributes:
attr_dict = { attr_dict = {
@ -764,11 +959,9 @@ class Scanner:
# Add attributes to existing nodes OR create new nodes if they don't exist # Add attributes to existing nodes OR create new nodes if they don't exist
for node_id, node_attributes_list in attributes_by_node.items(): for node_id, node_attributes_list in attributes_by_node.items():
if not self.graph.graph.has_node(node_id): if not self.graph.graph.has_node(node_id):
# If the node doesn't exist, create it with a default type
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list) self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
else: else:
# If the node already exists, just add the attributes
node_type_val = self.graph.graph.nodes[node_id].get('type', 'domain') node_type_val = self.graph.graph.nodes[node_id].get('type', 'domain')
self.graph.add_node(node_id, NodeType(node_type_val), attributes=node_attributes_list) self.graph.add_node(node_id, NodeType(node_type_val), attributes=node_attributes_list)
@ -778,25 +971,50 @@ class Scanner:
provider_result: ProviderResult, current_depth: int) -> Set[str]: provider_result: ProviderResult, current_depth: int) -> Set[str]:
""" """
Create a large entity node from a ProviderResult. Create a large entity node from a ProviderResult.
FIXED: Only include domain/IP nodes in large entities, exclude CA and other special node types.
""" """
entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}" entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}"
targets = [rel.target_node for rel in provider_result.relationships] # FIXED: Filter out CA, ISP, and correlation nodes from large entity inclusion
node_type = 'unknown' eligible_targets = []
for rel in provider_result.relationships:
target_node = rel.target_node
if targets: # Skip CA nodes (certificate issuers) - they should appear directly on graph
if _is_valid_domain(targets[0]): if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer':
continue
# Skip ISP nodes - they should appear directly on graph
if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp':
continue
# Skip correlation objects - they should appear directly on graph
if rel.relationship_type.startswith('corr_'):
continue
# Only include valid domains and IPs in large entities
if _is_valid_domain(target_node) or _is_valid_ip(target_node):
eligible_targets.append(target_node)
# If no eligible targets after filtering, don't create large entity
if not eligible_targets:
return set()
node_type = 'unknown'
if eligible_targets:
if _is_valid_domain(eligible_targets[0]):
node_type = 'domain' node_type = 'domain'
elif _is_valid_ip(targets[0]): elif _is_valid_ip(eligible_targets[0]):
node_type = 'ip' node_type = 'ip'
for target in targets: # Create individual nodes for eligible targets
for target in eligible_targets:
target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP
self.graph.add_node(target, target_node_type) self.graph.add_node(target, target_node_type)
attributes_dict = { attributes_dict = {
'count': len(targets), 'count': len(eligible_targets),
'nodes': targets, 'nodes': eligible_targets, # Only eligible domain/IP nodes
'node_type': node_type, 'node_type': node_type,
'source_provider': provider_name, 'source_provider': provider_name,
'discovery_depth': current_depth, 'discovery_depth': current_depth,
@ -814,18 +1032,21 @@ class Scanner:
"metadata": {} "metadata": {}
}) })
description = f'Large entity created due to {len(targets)} relationships from {provider_name}' description = f'Large entity created due to {len(eligible_targets)} relationships from {provider_name}'
self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description) self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description)
if provider_result.relationships: if provider_result.relationships:
rel_type = provider_result.relationships[0].relationship_type # Use the first eligible relationship for the large entity connection
self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, eligible_rels = [rel for rel in provider_result.relationships if rel.target_node in eligible_targets]
{'large_entity_info': f'Contains {len(targets)} {node_type}s'}) if eligible_rels:
rel_type = eligible_rels[0].relationship_type
self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
{'large_entity_info': f'Contains {len(eligible_targets)} {node_type}s'})
self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}") self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(eligible_targets)} targets from {provider_name}")
return set(targets) return set(eligible_targets)
def stop_scan(self) -> bool: def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup.""" """Request immediate scan termination with proper cleanup."""
@ -857,6 +1078,7 @@ class Scanner:
def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool: def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
""" """
Extracts a node from a large entity and re-queues it for scanning. Extracts a node from a large entity and re-queues it for scanning.
FIXED: Properly handle different node types during extraction.
""" """
if not self.graph.graph.has_node(large_entity_id): if not self.graph.graph.has_node(large_entity_id):
return False return False
@ -868,12 +1090,13 @@ class Scanner:
original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id) original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id)
if not original_edge_data: if not original_edge_data:
return False return False
success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract) success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract)
if not success: if not success:
return False return False
# Create relationship from source to extracted node
self.graph.add_edge( self.graph.add_edge(
source_id=source_node_id, source_id=source_node_id,
target_id=node_id_to_extract, target_id=node_id_to_extract,
@ -883,33 +1106,104 @@ class Scanner:
raw_data={'context': f'Extracted from large entity {large_entity_id}'} raw_data={'context': f'Extracted from large entity {large_entity_id}'}
) )
# FIXED: Only queue for further scanning if it's a domain/IP that can be scanned
is_ip = _is_valid_ip(node_id_to_extract) is_ip = _is_valid_ip(node_id_to_extract)
is_domain = _is_valid_domain(node_id_to_extract)
large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', []) # Only queue valid domains and IPs for further processing
discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None) # Don't queue CA nodes, ISP nodes, etc. as they can't be scanned
current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0 if is_domain or is_ip:
large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', [])
discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None)
current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0
eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False) eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False)
for provider in eligible_providers: for provider in eligible_providers:
provider_name = provider.get_name() provider_name = provider.get_name()
priority = self._get_priority(provider_name) priority = self._get_priority(provider_name)
self.task_queue.put((time.time(), priority, (provider_name, node_id_to_extract, current_depth))) self.task_queue.put((time.time(), priority, (provider_name, node_id_to_extract, current_depth)))
self.total_tasks_ever_enqueued += 1 self.total_tasks_ever_enqueued += 1
if self.status != ScanStatus.RUNNING: if self.status != ScanStatus.RUNNING:
self.status = ScanStatus.RUNNING self.status = ScanStatus.RUNNING
self._update_session_state() self._update_session_state()
if not self.scan_thread or not self.scan_thread.is_alive(): if not self.scan_thread or not self.scan_thread.is_alive():
self.scan_thread = threading.Thread( self.scan_thread = threading.Thread(
target=self._execute_scan, target=self._execute_scan,
args=(self.current_target, self.max_depth), args=(self.current_target, self.max_depth),
daemon=True daemon=True
) )
self.scan_thread.start() self.scan_thread.start()
else:
# For non-scannable nodes (CA, ISP, etc.), just log that they were extracted
self.logger.logger.info(f"Extracted non-scannable node {node_id_to_extract} of type {self.graph.graph.nodes[node_id_to_extract].get('type', 'unknown')}")
return True return True
def _determine_extracted_node_type(self, node_id: str, large_entity_id: str) -> NodeType:
    """
    Determine the correct node type for a node being extracted from a large entity.

    Strategies are tried in order of reliability:
      1. An explicit 'type' already stored on the node in the graph.
      2. Inference from existing relationships (cert-issuer edges -> CA,
         Shodan ISP edges -> ISP, 'corr_*' edges -> correlation object).
      3. Format-based detection (IP vs. domain syntax).
      4. The 'node_type' attribute recorded on the large entity itself.
      5. Final fallback: DOMAIN.

    Args:
        node_id: Identifier of the node being extracted.
        large_entity_id: Identifier of the large entity it is extracted from.

    Returns:
        The inferred NodeType for the extracted node.
    """
    # Local import to avoid a circular dependency at module load time.
    from utils.helpers import _is_valid_ip, _is_valid_domain

    # Strategy 1: the node may already exist in the graph with a stored type.
    # NOTE: raw node data lives on the underlying networkx graph
    # (self.graph.graph), matching how the rest of this class accesses it.
    if self.graph.has_node(node_id):
        existing_type = self.graph.graph.nodes[node_id].get('type')
        if existing_type:
            try:
                return NodeType(existing_type)
            except ValueError:
                # Stored type string is not a valid NodeType member; fall through.
                pass

    # Strategy 2: infer the type from relationships touching this node.
    for source, target, edge_data in self.graph.graph.edges(data=True):
        rel_type = edge_data.get('relationship_type', '')
        provider = edge_data.get('source_provider', '')
        if target == node_id:
            # CA nodes are targets of certificate-issuer relationships.
            if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer':
                return NodeType.CA
            # ISP nodes are targets of Shodan ISP relationships.
            if provider == 'shodan' and rel_type == 'shodan_isp':
                return NodeType.ISP
            # Correlation objects are targets of 'corr_*' relationships.
            if rel_type.startswith('corr_'):
                return NodeType.CORRELATION_OBJECT
        if source == node_id:
            # Source nodes in cert-issuer relationships are CAs as well.
            if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer':
                return NodeType.CA

    # Strategy 3: fall back to syntactic detection of the identifier itself.
    if _is_valid_ip(node_id):
        return NodeType.IP
    elif _is_valid_domain(node_id):
        return NodeType.DOMAIN

    # Strategy 4: use the 'node_type' attribute the large entity recorded
    # for its members when it was created.
    if self.graph.has_node(large_entity_id):
        large_entity_data = self.graph.graph.nodes[large_entity_id]
        attributes = large_entity_data.get('attributes', [])
        node_type_attr = next((attr for attr in attributes if attr.get('name') == 'node_type'), None)
        if node_type_attr:
            entity_node_type = node_type_attr.get('value', 'domain')
            if entity_node_type == 'ip':
                return NodeType.IP
            else:
                return NodeType.DOMAIN

    # Final fallback: treat unknowns as domains.
    return NodeType.DOMAIN
def _update_session_state(self) -> None: def _update_session_state(self) -> None:
""" """
Update the scanner state in Redis for GUI updates. Update the scanner state in Redis for GUI updates.

View File

@ -27,26 +27,53 @@ class ShodanProvider(BaseProvider):
) )
self.base_url = "https://api.shodan.io" self.base_url = "https://api.shodan.io"
self.api_key = self.config.get_api_key('shodan') self.api_key = self.config.get_api_key('shodan')
self._is_active = self._check_api_connection()
# FIXED: Don't fail initialization on connection issues - defer to actual usage
self._connection_tested = False
self._connection_works = False
# Initialize cache directory # Initialize cache directory
self.cache_dir = Path('cache') / 'shodan' self.cache_dir = Path('cache') / 'shodan'
self.cache_dir.mkdir(parents=True, exist_ok=True) self.cache_dir.mkdir(parents=True, exist_ok=True)
def _check_api_connection(self) -> bool: def _check_api_connection(self) -> bool:
"""Checks if the Shodan API is reachable.""" """
FIXED: Lazy connection checking - only test when actually needed.
Don't block provider initialization on network issues.
"""
if self._connection_tested:
return self._connection_works
if not self.api_key: if not self.api_key:
return False self._connection_tested = True
try: self._connection_works = False
response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5)
self.logger.logger.debug("Shodan is reacheable")
return response.status_code == 200
except requests.exceptions.RequestException:
return False return False
try:
print(f"Testing Shodan API connection with key: {self.api_key[:8]}...")
response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5)
self._connection_works = response.status_code == 200
print(f"Shodan API test result: {response.status_code} - {'Success' if self._connection_works else 'Failed'}")
except requests.exceptions.RequestException as e:
print(f"Shodan API connection test failed: {e}")
self._connection_works = False
finally:
self._connection_tested = True
return self._connection_works
def is_available(self) -> bool: def is_available(self) -> bool:
"""Check if Shodan provider is available (has valid API key in this session).""" """
return self._is_active and self.api_key is not None and len(self.api_key.strip()) > 0 FIXED: Check if Shodan provider is available based on API key presence.
Don't require successful connection test during initialization.
"""
has_api_key = self.api_key is not None and len(self.api_key.strip()) > 0
if not has_api_key:
return False
# FIXED: Only test connection on first actual usage, not during initialization
return True
def get_name(self) -> str: def get_name(self) -> str:
"""Return the provider name.""" """Return the provider name."""
@ -117,6 +144,7 @@ class ShodanProvider(BaseProvider):
def query_ip(self, ip: str) -> ProviderResult: def query_ip(self, ip: str) -> ProviderResult:
""" """
Query Shodan for information about an IP address (IPv4 or IPv6), with caching of processed data. Query Shodan for information about an IP address (IPv4 or IPv6), with caching of processed data.
FIXED: Proper 404 handling to prevent unnecessary retries.
Args: Args:
ip: IP address to investigate (IPv4 or IPv6) ip: IP address to investigate (IPv4 or IPv6)
@ -127,7 +155,12 @@ class ShodanProvider(BaseProvider):
Raises: Raises:
Exception: For temporary failures that should be retried (timeouts, 502/503 errors, connection issues) Exception: For temporary failures that should be retried (timeouts, 502/503 errors, connection issues)
""" """
if not _is_valid_ip(ip) or not self.is_available(): if not _is_valid_ip(ip):
return ProviderResult()
# Test connection only when actually making requests
if not self._check_api_connection():
print(f"Shodan API not available for {ip} - API key: {'present' if self.api_key else 'missing'}")
return ProviderResult() return ProviderResult()
# Normalize IP address for consistent processing # Normalize IP address for consistent processing
@ -151,26 +184,40 @@ class ShodanProvider(BaseProvider):
response = self.make_request(url, method="GET", params=params, target_indicator=normalized_ip) response = self.make_request(url, method="GET", params=params, target_indicator=normalized_ip)
if not response: if not response:
# Connection failed - use stale cache if available, otherwise retry self.logger.logger.warning(f"Shodan API unreachable for {normalized_ip} - network failure")
if cache_status == "stale": if cache_status == "stale":
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to connection failure") self.logger.logger.info(f"Using stale cache for {normalized_ip} due to network failure")
return self._load_from_cache(cache_file) return self._load_from_cache(cache_file)
else: else:
raise requests.exceptions.RequestException("No response from Shodan API - should retry") # FIXED: Treat network failures as "no information" rather than retryable errors
self.logger.logger.info(f"No Shodan data available for {normalized_ip} due to network failure")
result = ProviderResult() # Empty result
network_failure_data = {'shodan_status': 'network_unreachable', 'error': 'API unreachable'}
self._save_to_cache(cache_file, result, network_failure_data)
return result
# FIXED: Handle different status codes more precisely
if response.status_code == 200: if response.status_code == 200:
self.logger.logger.debug(f"Shodan returned data for {normalized_ip}") self.logger.logger.debug(f"Shodan returned data for {normalized_ip}")
data = response.json() try:
result = self._process_shodan_data(normalized_ip, data) data = response.json()
self._save_to_cache(cache_file, result, data) result = self._process_shodan_data(normalized_ip, data)
return result self._save_to_cache(cache_file, result, data)
return result
except json.JSONDecodeError as e:
self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}: {e}")
if cache_status == "stale":
return self._load_from_cache(cache_file)
else:
raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry")
elif response.status_code == 404: elif response.status_code == 404:
# 404 = "no information available" - successful but empty result, don't retry # FIXED: 404 = "no information available" - successful but empty result, don't retry
self.logger.logger.debug(f"Shodan has no information for {normalized_ip} (404)") self.logger.logger.debug(f"Shodan has no information for {normalized_ip} (404)")
result = ProviderResult() # Empty but successful result result = ProviderResult() # Empty but successful result
# Cache the empty result to avoid repeated queries # Cache the empty result to avoid repeated queries
self._save_to_cache(cache_file, result, {'shodan_status': 'no_information', 'status_code': 404}) empty_data = {'shodan_status': 'no_information', 'status_code': 404}
self._save_to_cache(cache_file, result, empty_data)
return result return result
elif response.status_code in [401, 403]: elif response.status_code in [401, 403]:
@ -178,7 +225,7 @@ class ShodanProvider(BaseProvider):
self.logger.logger.error(f"Shodan API authentication failed for {normalized_ip} (HTTP {response.status_code})") self.logger.logger.error(f"Shodan API authentication failed for {normalized_ip} (HTTP {response.status_code})")
return ProviderResult() # Empty result, don't retry return ProviderResult() # Empty result, don't retry
elif response.status_code in [429]: elif response.status_code == 429:
# Rate limiting - should be handled by rate limiter, but if we get here, retry # Rate limiting - should be handled by rate limiter, but if we get here, retry
self.logger.logger.warning(f"Shodan API rate limited for {normalized_ip} (HTTP {response.status_code})") self.logger.logger.warning(f"Shodan API rate limited for {normalized_ip} (HTTP {response.status_code})")
if cache_status == "stale": if cache_status == "stale":
@ -197,13 +244,12 @@ class ShodanProvider(BaseProvider):
raise requests.exceptions.RequestException(f"Shodan API server error (HTTP {response.status_code}) - should retry") raise requests.exceptions.RequestException(f"Shodan API server error (HTTP {response.status_code}) - should retry")
else: else:
# Other HTTP error codes - treat as temporary failures # FIXED: Other HTTP status codes - treat as no information available, don't retry
self.logger.logger.warning(f"Shodan API returned unexpected status {response.status_code} for {normalized_ip}") self.logger.logger.info(f"Shodan returned status {response.status_code} for {normalized_ip} - treating as no information")
if cache_status == "stale": result = ProviderResult() # Empty result
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to unexpected API error") no_info_data = {'shodan_status': 'no_information', 'status_code': response.status_code}
return self._load_from_cache(cache_file) self._save_to_cache(cache_file, result, no_info_data)
else: return result
raise requests.exceptions.RequestException(f"Shodan API error (HTTP {response.status_code}) - should retry")
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
# Timeout errors - should be retried # Timeout errors - should be retried
@ -223,17 +269,8 @@ class ShodanProvider(BaseProvider):
else: else:
raise # Re-raise connection error for retry raise # Re-raise connection error for retry
except requests.exceptions.RequestException:
# Other request exceptions - should be retried
self.logger.logger.warning(f"Shodan API request exception for {normalized_ip}")
if cache_status == "stale":
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to request exception")
return self._load_from_cache(cache_file)
else:
raise # Re-raise request exception for retry
except json.JSONDecodeError: except json.JSONDecodeError:
# JSON parsing error on 200 response - treat as temporary failure # JSON parsing error - treat as temporary failure
self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}") self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}")
if cache_status == "stale": if cache_status == "stale":
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to JSON parsing error") self.logger.logger.info(f"Using stale cache for {normalized_ip} due to JSON parsing error")
@ -241,14 +278,16 @@ class ShodanProvider(BaseProvider):
else: else:
raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry") raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry")
# FIXED: Remove the generic RequestException handler that was causing 404s to retry
# Now only specific exceptions that should be retried are re-raised
except Exception as e: except Exception as e:
# Unexpected exceptions - log and treat as temporary failures # FIXED: Unexpected exceptions - log but treat as no information available, don't retry
self.logger.logger.error(f"Unexpected exception in Shodan query for {normalized_ip}: {e}") self.logger.logger.warning(f"Unexpected exception in Shodan query for {normalized_ip}: {e}")
if cache_status == "stale": result = ProviderResult() # Empty result
self.logger.logger.info(f"Using stale cache for {normalized_ip} due to unexpected exception") error_data = {'shodan_status': 'error', 'error': str(e)}
return self._load_from_cache(cache_file) self._save_to_cache(cache_file, result, error_data)
else: return result
raise requests.exceptions.RequestException(f"Unexpected error in Shodan query: {e}") from e
def _load_from_cache(self, cache_file_path: Path) -> ProviderResult: def _load_from_cache(self, cache_file_path: Path) -> ProviderResult:
"""Load processed Shodan data from a cache file.""" """Load processed Shodan data from a cache file."""