From 95cebbf935739cf82bc589902fd93f64895ca236 Mon Sep 17 00:00:00 2001 From: overcuriousity Date: Thu, 18 Sep 2025 22:39:12 +0200 Subject: [PATCH] bug fixes, improvements --- core/scanner.py | 466 ++++++++++++++++++++++++++++------- providers/shodan_provider.py | 125 ++++++---- 2 files changed, 462 insertions(+), 129 deletions(-) diff --git a/core/scanner.py b/core/scanner.py index 0c8f474..2d87af8 100644 --- a/core/scanner.py +++ b/core/scanner.py @@ -189,10 +189,14 @@ class Scanner: """Initialize all available providers based on session configuration.""" self.providers = [] provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers') + + print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===") + for filename in os.listdir(provider_dir): if filename.endswith('_provider.py') and not filename.startswith('base'): module_name = f"providers.{filename[:-3]}" try: + print(f"Loading provider module: {module_name}") module = importlib.import_module(module_name) for attribute_name in dir(module): attribute = getattr(module, attribute_name) @@ -202,15 +206,41 @@ class Scanner: provider = provider_class(name=attribute_name, session_config=self.config) provider_name = provider.get_name() + print(f" Provider: {provider_name}") + print(f" Class: {provider_class.__name__}") + print(f" Config enabled: {self.config.is_provider_enabled(provider_name)}") + print(f" Requires API key: {provider.requires_api_key()}") + + if provider.requires_api_key(): + api_key = self.config.get_api_key(provider_name) + print(f" API key present: {'Yes' if api_key else 'No'}") + if api_key: + print(f" API key preview: {api_key[:8]}...") + if self.config.is_provider_enabled(provider_name): - if provider.is_available(): + is_available = provider.is_available() + print(f" Available: {is_available}") + + if is_available: provider.set_stop_event(self.stop_event) if isinstance(provider, CorrelationProvider): provider.set_graph_manager(self.graph) self.providers.append(provider) + print(f" ✓ Added to scanner") + else: + print(f" ✗ Not available - skipped") + else: + print(f" ✗ Disabled in config - skipped") + except Exception as e: + print(f" ERROR loading {module_name}: {e}") traceback.print_exc() + print(f"=== PROVIDER INITIALIZATION COMPLETE ===") + print(f"Active providers: {[p.get_name() for p in self.providers]}") + print(f"Provider count: {len(self.providers)}") + print("=" * 50) + def _status_logger_thread(self): """Periodically prints a clean, formatted scan status to the terminal.""" HEADER = "\033[95m" @@ -424,6 +454,9 @@ class Scanner: is_ip = _is_valid_ip(target) initial_providers = self._get_eligible_providers(target, is_ip, False) + # FIXED: Filter out correlation provider from initial providers + initial_providers = [p for p in initial_providers if not isinstance(p, CorrelationProvider)] + for provider in initial_providers: provider_name = provider.get_name() priority = self._get_priority(provider_name) @@ -443,6 +476,8 @@ class Scanner: consecutive_empty_iterations = 0 max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion + # PHASE 1: Run all non-correlation providers + print(f"\n=== PHASE 1: Running non-correlation providers ===") while not self._is_stop_requested(): queue_empty = self.task_queue.empty() with self.processing_lock: @@ -451,59 +486,59 @@ class Scanner: if queue_empty and no_active_processing: consecutive_empty_iterations += 1 if consecutive_empty_iterations >= max_empty_iterations: - break # Scan is complete + break # Phase 1 complete time.sleep(0.1) continue else: consecutive_empty_iterations = 0 - # FIXED: Safe task retrieval without race conditions + # Process tasks (same logic as before, but correlations are filtered out) try: - # Use timeout to avoid blocking indefinitely run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1) - # FIXED: Check if task is ready to run + # Skip correlation tasks during Phase 1 + if provider_name == 'correlation': + continue + + # Check if task is ready to run current_time = time.time() if run_at > current_time: - # Task is not ready yet, re-queue it and continue self.task_queue.put((run_at, priority, (provider_name, target_item, depth))) - time.sleep(min(0.5, run_at - current_time)) # Sleep until closer to run time + time.sleep(min(0.5, run_at - current_time)) continue - + except: # Queue is empty or timeout occurred time.sleep(0.1) continue self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth)) - # FIXED: Include depth in processed tasks to avoid incorrect skipping + # Skip if already processed task_tuple = (provider_name, target_item, depth) if task_tuple in processed_tasks: self.tasks_skipped += 1 self.indicators_completed += 1 continue - # FIXED: Proper depth checking + # Skip if depth exceeded if depth > max_depth: self.tasks_skipped += 1 self.indicators_completed += 1 continue - # FIXED: Rate limiting with proper time-based deferral + # Rate limiting with proper time-based deferral if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60): - defer_until = time.time() + 60 # Defer for 60 seconds + defer_until = time.time() + 60 self.task_queue.put((defer_until, priority, (provider_name, target_item, depth))) self.tasks_re_enqueued += 1 continue - # FIXED: Thread-safe processing state management + # Thread-safe processing state management with self.processing_lock: if self._is_stop_requested(): break - # Use provider+target (without depth) for duplicate processing check processing_key = (provider_name, target_item) if processing_key in self.currently_processing: - # Already processing this provider+target combination, skip self.tasks_skipped += 1 self.indicators_completed += 1 continue @@ -519,21 +554,19 @@ class Scanner: provider = next((p for p in self.providers if p.get_name() == provider_name), None) - if provider: + if provider and not isinstance(provider, CorrelationProvider): new_targets, _, success = self._process_provider_task(provider, target_item, depth) if self._is_stop_requested(): break if not success: - # FIXED: Use depth-aware retry key retry_key = (provider_name, target_item, depth) self.target_retries[retry_key] += 1 if self.target_retries[retry_key] <= self.config.max_retries_per_target: - # FIXED: Exponential backoff with jitter for retries retry_count = self.target_retries[retry_key] - backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1)) # Cap at 5 minutes + backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1)) retry_at = time.time() + backoff_delay self.task_queue.put((retry_at, priority, (provider_name, target_item, depth))) self.tasks_re_enqueued += 1 @@ -545,21 +578,21 @@ class Scanner: processed_tasks.add(task_tuple) self.indicators_completed += 1 - # FIXED: Enqueue new targets with proper depth tracking + # Enqueue new targets with proper depth tracking if not self._is_stop_requested(): for new_target in new_targets: is_ip_new = _is_valid_ip(new_target) eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False) + # FIXED: Filter out correlation providers in Phase 1 + eligible_providers_new = [p for p in eligible_providers_new if not isinstance(p, CorrelationProvider)] for p_new in eligible_providers_new: p_name_new = p_new.get_name() - new_depth = depth + 1 # Always increment depth for discovered targets + new_depth = depth + 1 new_task_tuple = (p_name_new, new_target, new_depth) - # FIXED: Don't re-enqueue already processed tasks if new_task_tuple not in processed_tasks and new_depth <= max_depth: new_priority = self._get_priority(p_name_new) - # Enqueue new tasks to run immediately self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth))) self.total_tasks_ever_enqueued += 1 else: @@ -568,22 +601,25 @@ class Scanner: self.indicators_completed += 1 finally: - # FIXED: Always clean up processing state with self.processing_lock: processing_key = (provider_name, target_item) self.currently_processing.discard(processing_key) + # PHASE 2: Run correlations on all discovered nodes + if not self._is_stop_requested(): + print(f"\n=== PHASE 2: Running correlation analysis ===") + self._run_correlation_phase(max_depth, processed_tasks) + except Exception as e: traceback.print_exc() self.status = ScanStatus.FAILED self.logger.logger.error(f"Scan failed: {e}") finally: - # FIXED: Comprehensive cleanup + # Comprehensive cleanup (same as before) with self.processing_lock: self.currently_processing.clear() self.currently_processing_display = [] - # FIXED: Clear any remaining tasks from queue to prevent memory leaks while not self.task_queue.empty(): try: self.task_queue.get_nowait() @@ -599,7 +635,7 @@ class Scanner: self.status_logger_stop_event.set() if self.status_logger_thread and self.status_logger_thread.is_alive(): - self.status_logger_thread.join(timeout=2.0) # Don't wait forever + self.status_logger_thread.join(timeout=2.0) self._update_session_state() self.logger.log_scan_complete() @@ -612,10 +648,120 @@ class Scanner: finally: self.executor = None + def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None: + """ + PHASE 2: Run correlation analysis on all discovered nodes. + This ensures correlations run after all other providers have completed. + """ + correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None) + if not correlation_provider: + print("No correlation provider found - skipping correlation phase") + return + + # Get all nodes from the graph for correlation analysis + all_nodes = list(self.graph.graph.nodes()) + correlation_tasks = [] + + print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes") + + for node_id in all_nodes: + if self._is_stop_requested(): + break + + # Determine appropriate depth for correlation (use 0 for simplicity) + correlation_depth = 0 + task_tuple = ('correlation', node_id, correlation_depth) + + # Don't re-process already processed correlation tasks + if task_tuple not in processed_tasks: + priority = self._get_priority('correlation') + self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth))) + correlation_tasks.append(task_tuple) + self.total_tasks_ever_enqueued += 1 + + print(f"Enqueued {len(correlation_tasks)} correlation tasks") + + # Process correlation tasks + consecutive_empty_iterations = 0 + max_empty_iterations = 20 # Shorter timeout for correlation phase + + while not self._is_stop_requested() and correlation_tasks: + queue_empty = self.task_queue.empty() + with self.processing_lock: + no_active_processing = len(self.currently_processing) == 0 + + if queue_empty and no_active_processing: + consecutive_empty_iterations += 1 + if consecutive_empty_iterations >= max_empty_iterations: + break + time.sleep(0.1) + continue + else: + consecutive_empty_iterations = 0 + + try: + run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1) + + # Only process correlation tasks in this phase + if provider_name != 'correlation': + continue + + except: + time.sleep(0.1) + continue + + task_tuple = (provider_name, target_item, depth) + + # Skip if already processed + if task_tuple in processed_tasks: + self.tasks_skipped += 1 + self.indicators_completed += 1 + if task_tuple in correlation_tasks: + correlation_tasks.remove(task_tuple) + continue + + with self.processing_lock: + if self._is_stop_requested(): + break + processing_key = (provider_name, target_item) + if processing_key in self.currently_processing: + self.tasks_skipped += 1 + self.indicators_completed += 1 + continue + self.currently_processing.add(processing_key) + + try: + self.current_indicator = target_item + self._update_session_state() + + if self._is_stop_requested(): + break + + # Process correlation task + new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth) + + if success: + processed_tasks.add(task_tuple) + self.indicators_completed += 1 + if task_tuple in correlation_tasks: + correlation_tasks.remove(task_tuple) + else: + # For correlations, don't retry - just mark as completed + self.indicators_completed += 1 + if task_tuple in correlation_tasks: + correlation_tasks.remove(task_tuple) + + finally: + with self.processing_lock: + processing_key = (provider_name, target_item) + self.currently_processing.discard(processing_key) + + print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}") + def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]: """ Manages the entire process for a given target and provider. - It uses the "worker" function to get the data and then manages the consequences. + FIXED: Don't enqueue correlation tasks during normal processing. """ if self._is_stop_requested(): return set(), set(), False @@ -643,14 +789,6 @@ class Scanner: large_entity_members.update(discovered) else: new_targets.update(discovered) - - # After processing a provider, queue a correlation task for the target - correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None) - if correlation_provider and not isinstance(provider, CorrelationProvider): - priority = self._get_priority(correlation_provider.get_name()) - self.task_queue.put((time.time(), priority, (correlation_provider.get_name(), target, depth))) - # FIXED: Increment total tasks when a correlation task is enqueued - self.total_tasks_ever_enqueued += 1 except Exception as e: provider_successful = False @@ -690,7 +828,7 @@ class Scanner: provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]: """ Process a unified ProviderResult object to update the graph. - VERIFIED: Proper ISP and CA node type assignment. + FIXED: Ensure CA and ISP relationships are created even when large entities are formed. """ provider_name = provider.get_name() discovered_targets = set() @@ -698,12 +836,70 @@ class Scanner: if self._is_stop_requested(): return discovered_targets, False - # Check if this should be a large entity - if provider_result.get_relationship_count() > self.config.large_entity_threshold: + # Check if this should be a large entity (only counting domain/IP relationships) + eligible_relationship_count = 0 + for rel in provider_result.relationships: + # Only count relationships that would go into large entities + if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer': + continue # Don't count CA relationships + if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp': + continue # Don't count ISP relationships + if rel.relationship_type.startswith('corr_'): + continue # Don't count correlation relationships + + # Only count domain/IP targets + if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node): + eligible_relationship_count += 1 + + if eligible_relationship_count > self.config.large_entity_threshold: + # Create large entity but ALSO process special relationships members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth) + + # FIXED: Still process CA, ISP, and correlation relationships directly on the graph + for relationship in provider_result.relationships: + if self._is_stop_requested(): + break + + source_node = relationship.source_node + target_node = relationship.target_node + + # Process special relationship types that should appear directly on graph + should_create_direct_relationship = False + target_type = None + + if provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer': + target_type = NodeType.CA + should_create_direct_relationship = True + elif provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp': + target_type = NodeType.ISP + should_create_direct_relationship = True + elif relationship.relationship_type.startswith('corr_'): + target_type = NodeType.CORRELATION_OBJECT + should_create_direct_relationship = True + + if should_create_direct_relationship: + # Create source and target nodes + source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN + self.graph.add_node(source_node, source_type) + self.graph.add_node(target_node, target_type) + + # Add the relationship edge + self.graph.add_edge( + source_node, target_node, + relationship.relationship_type, + relationship.confidence, + provider_name, + relationship.raw_data + ) + + # Add to discovered targets if it's a valid target for further processing + max_depth_reached = current_depth >= self.max_depth + if not max_depth_reached and (_is_valid_domain(target_node) or _is_valid_ip(target_node)): + discovered_targets.add(target_node) + return members, True - - # Process relationships and create nodes with proper types + + # Normal processing (existing logic) when not creating large entity for i, relationship in enumerate(provider_result.relationships): if i % 5 == 0 and self._is_stop_requested(): break @@ -711,14 +907,14 @@ class Scanner: source_node = relationship.source_node target_node = relationship.target_node - # VERIFIED: Determine source node type + # Determine source node type source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN - # VERIFIED: Determine target node type based on provider and relationship + # Determine target node type based on provider and relationship if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp': - target_type = NodeType.ISP # ISP node for Shodan organization data + target_type = NodeType.ISP elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer': - target_type = NodeType.CA # CA node for certificate issuers + target_type = NodeType.CA elif provider_name == 'correlation': target_type = NodeType.CORRELATION_OBJECT elif _is_valid_ip(target_node): @@ -733,7 +929,6 @@ class Scanner: self.graph.add_node(source_node, source_type) self.graph.add_node(target_node, target_type, metadata={'max_depth_reached': max_depth_reached}) - # Add the relationship edge if self.graph.add_edge( source_node, target_node, @@ -748,7 +943,7 @@ class Scanner: if (_is_valid_domain(target_node) or _is_valid_ip(target_node)) and not max_depth_reached: discovered_targets.add(target_node) - # Process all attributes, grouping by target node + # Process all attributes (existing logic unchanged) attributes_by_node = defaultdict(list) for attribute in provider_result.attributes: attr_dict = { @@ -764,11 +959,9 @@ class Scanner: # Add attributes to existing nodes OR create new nodes if they don't exist for node_id, node_attributes_list in attributes_by_node.items(): if not self.graph.graph.has_node(node_id): - # If the node doesn't exist, create it with a default type node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN self.graph.add_node(node_id, node_type, attributes=node_attributes_list) else: - # If the node already exists, just add the attributes node_type_val = self.graph.graph.nodes[node_id].get('type', 'domain') self.graph.add_node(node_id, NodeType(node_type_val), attributes=node_attributes_list) @@ -778,25 +971,50 @@ class Scanner: provider_result: ProviderResult, current_depth: int) -> Set[str]: """ Create a large entity node from a ProviderResult. + FIXED: Only include domain/IP nodes in large entities, exclude CA and other special node types. """ entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}" - targets = [rel.target_node for rel in provider_result.relationships] - node_type = 'unknown' + # FIXED: Filter out CA, ISP, and correlation nodes from large entity inclusion + eligible_targets = [] + for rel in provider_result.relationships: + target_node = rel.target_node + + # Skip CA nodes (certificate issuers) - they should appear directly on graph + if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer': + continue + + # Skip ISP nodes - they should appear directly on graph + if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp': + continue + + # Skip correlation objects - they should appear directly on graph + if rel.relationship_type.startswith('corr_'): + continue + + # Only include valid domains and IPs in large entities + if _is_valid_domain(target_node) or _is_valid_ip(target_node): + eligible_targets.append(target_node) - if targets: - if _is_valid_domain(targets[0]): + # If no eligible targets after filtering, don't create large entity + if not eligible_targets: + return set() + + node_type = 'unknown' + if eligible_targets: + if _is_valid_domain(eligible_targets[0]): node_type = 'domain' - elif _is_valid_ip(targets[0]): + elif _is_valid_ip(eligible_targets[0]): node_type = 'ip' - for target in targets: + # Create individual nodes for eligible targets + for target in eligible_targets: target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP self.graph.add_node(target, target_node_type) attributes_dict = { - 'count': len(targets), - 'nodes': targets, + 'count': len(eligible_targets), + 'nodes': eligible_targets, # Only eligible domain/IP nodes 'node_type': node_type, 'source_provider': provider_name, 'discovery_depth': current_depth, @@ -814,18 +1032,21 @@ class Scanner: "metadata": {} }) - description = f'Large entity created due to {len(targets)} relationships from {provider_name}' + description = f'Large entity created due to {len(eligible_targets)} relationships from {provider_name}' self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description) if provider_result.relationships: - rel_type = provider_result.relationships[0].relationship_type - self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, - {'large_entity_info': f'Contains {len(targets)} {node_type}s'}) + # Use the first eligible relationship for the large entity connection + eligible_rels = [rel for rel in provider_result.relationships if rel.target_node in eligible_targets] + if eligible_rels: + rel_type = eligible_rels[0].relationship_type + self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name, + {'large_entity_info': f'Contains {len(eligible_targets)} {node_type}s'}) - self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}") + self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(eligible_targets)} targets from {provider_name}") - return set(targets) + return set(eligible_targets) def stop_scan(self) -> bool: """Request immediate scan termination with proper cleanup.""" @@ -857,6 +1078,7 @@ class Scanner: def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool: """ Extracts a node from a large entity and re-queues it for scanning. + FIXED: Properly handle different node types during extraction. """ if not self.graph.graph.has_node(large_entity_id): return False @@ -868,12 +1090,13 @@ class Scanner: original_edge_data = self.graph.graph.get_edge_data(source_node_id, large_entity_id) if not original_edge_data: - return False + return False success = self.graph.extract_node_from_large_entity(large_entity_id, node_id_to_extract) if not success: return False + # Create relationship from source to extracted node self.graph.add_edge( source_id=source_node_id, target_id=node_id_to_extract, @@ -883,33 +1106,104 @@ class Scanner: raw_data={'context': f'Extracted from large entity {large_entity_id}'} ) + # FIXED: Only queue for further scanning if it's a domain/IP that can be scanned is_ip = _is_valid_ip(node_id_to_extract) + is_domain = _is_valid_domain(node_id_to_extract) - large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', []) - discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None) - current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0 - - eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False) - for provider in eligible_providers: - provider_name = provider.get_name() - priority = self._get_priority(provider_name) - self.task_queue.put((time.time(), priority, (provider_name, node_id_to_extract, current_depth))) - self.total_tasks_ever_enqueued += 1 - - if self.status != ScanStatus.RUNNING: - self.status = ScanStatus.RUNNING - self._update_session_state() + # Only queue valid domains and IPs for further processing + # Don't queue CA nodes, ISP nodes, etc. as they can't be scanned + if is_domain or is_ip: + large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', []) + discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None) + current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0 - if not self.scan_thread or not self.scan_thread.is_alive(): - self.scan_thread = threading.Thread( - target=self._execute_scan, - args=(self.current_target, self.max_depth), - daemon=True - ) - self.scan_thread.start() + eligible_providers = self._get_eligible_providers(node_id_to_extract, is_ip, False) + for provider in eligible_providers: + provider_name = provider.get_name() + priority = self._get_priority(provider_name) + self.task_queue.put((time.time(), priority, (provider_name, node_id_to_extract, current_depth))) + self.total_tasks_ever_enqueued += 1 + + if self.status != ScanStatus.RUNNING: + self.status = ScanStatus.RUNNING + self._update_session_state() + + if not self.scan_thread or not self.scan_thread.is_alive(): + self.scan_thread = threading.Thread( + target=self._execute_scan, + args=(self.current_target, self.max_depth), + daemon=True + ) + self.scan_thread.start() + else: + # For non-scannable nodes (CA, ISP, etc.), just log that they were extracted + self.logger.logger.info(f"Extracted non-scannable node {node_id_to_extract} of type {self.graph.graph.nodes[node_id_to_extract].get('type', 'unknown')}") return True - + + def _determine_extracted_node_type(self, node_id: str, large_entity_id: str) -> NodeType: + """ + FIXED: Determine the correct node type for a node being extracted from a large entity. + Uses multiple strategies to ensure accurate type detection. + """ + from utils.helpers import _is_valid_ip, _is_valid_domain + + # Strategy 1: Check if node already exists in graph with a type + if self.graph.has_node(node_id): + existing_type = self.graph.nodes[node_id].get('type') + if existing_type: + try: + return NodeType(existing_type) + except ValueError: + pass + + # Strategy 2: Look for existing relationships to this node to infer type + for source, target, edge_data in self.graph.edges(data=True): + if target == node_id: + rel_type = edge_data.get('relationship_type', '') + provider = edge_data.get('source_provider', '') + + # CA nodes from certificate issuer relationships + if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer': + return NodeType.CA + + # ISP nodes from Shodan + if provider == 'shodan' and rel_type == 'shodan_isp': + return NodeType.ISP + + # Correlation objects + if rel_type.startswith('corr_'): + return NodeType.CORRELATION_OBJECT + + if source == node_id: + rel_type = edge_data.get('relationship_type', '') + provider = edge_data.get('source_provider', '') + + # Source nodes in cert issuer relationships are CAs + if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer': + return NodeType.CA + + # Strategy 3: Format-based detection (fallback) + if _is_valid_ip(node_id): + return NodeType.IP + elif _is_valid_domain(node_id): + return NodeType.DOMAIN + + # Strategy 4: Check large entity context + if self.graph.has_node(large_entity_id): + large_entity_data = self.graph.nodes[large_entity_id] + attributes = large_entity_data.get('attributes', []) + + node_type_attr = next((attr for attr in attributes if attr.get('name') == 'node_type'), None) + if node_type_attr: + entity_node_type = node_type_attr.get('value', 'domain') + if entity_node_type == 'ip': + return NodeType.IP + else: + return NodeType.DOMAIN + + # Final fallback + return NodeType.DOMAIN def _update_session_state(self) -> None: """ Update the scanner state in Redis for GUI updates. diff --git a/providers/shodan_provider.py b/providers/shodan_provider.py index 6d16009..05dfc6c 100644 --- a/providers/shodan_provider.py +++ b/providers/shodan_provider.py @@ -27,26 +27,53 @@ class ShodanProvider(BaseProvider): ) self.base_url = "https://api.shodan.io" self.api_key = self.config.get_api_key('shodan') - self._is_active = self._check_api_connection() + + # FIXED: Don't fail initialization on connection issues - defer to actual usage + self._connection_tested = False + self._connection_works = False # Initialize cache directory self.cache_dir = Path('cache') / 'shodan' self.cache_dir.mkdir(parents=True, exist_ok=True) def _check_api_connection(self) -> bool: - """Checks if the Shodan API is reachable.""" + """ + FIXED: Lazy connection checking - only test when actually needed. + Don't block provider initialization on network issues. + """ + if self._connection_tested: + return self._connection_works + if not self.api_key: + self._connection_tested = True + self._connection_works = False return False + try: + print(f"Testing Shodan API connection with key: {self.api_key[:8]}...") response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5) - self.logger.logger.debug("Shodan is reacheable") - return response.status_code == 200 - except requests.exceptions.RequestException: - return False + self._connection_works = response.status_code == 200 + print(f"Shodan API test result: {response.status_code} - {'Success' if self._connection_works else 'Failed'}") + except requests.exceptions.RequestException as e: + print(f"Shodan API connection test failed: {e}") + self._connection_works = False + finally: + self._connection_tested = True + + return self._connection_works def is_available(self) -> bool: - """Check if Shodan provider is available (has valid API key in this session).""" - return self._is_active and self.api_key is not None and len(self.api_key.strip()) > 0 + """ + FIXED: Check if Shodan provider is available based on API key presence. + Don't require successful connection test during initialization. + """ + has_api_key = self.api_key is not None and len(self.api_key.strip()) > 0 + + if not has_api_key: + return False + + # FIXED: Only test connection on first actual usage, not during initialization + return True def get_name(self) -> str: """Return the provider name.""" @@ -117,6 +144,7 @@ class ShodanProvider(BaseProvider): def query_ip(self, ip: str) -> ProviderResult: """ Query Shodan for information about an IP address (IPv4 or IPv6), with caching of processed data. + FIXED: Proper 404 handling to prevent unnecessary retries. Args: ip: IP address to investigate (IPv4 or IPv6) @@ -127,7 +155,12 @@ class ShodanProvider(BaseProvider): Raises: Exception: For temporary failures that should be retried (timeouts, 502/503 errors, connection issues) """ - if not _is_valid_ip(ip) or not self.is_available(): + if not _is_valid_ip(ip): + return ProviderResult() + + # Test connection only when actually making requests + if not self._check_api_connection(): + print(f"Shodan API not available for {ip} - API key: {'present' if self.api_key else 'missing'}") return ProviderResult() # Normalize IP address for consistent processing @@ -151,26 +184,40 @@ class ShodanProvider(BaseProvider): response = self.make_request(url, method="GET", params=params, target_indicator=normalized_ip) if not response: - # Connection failed - use stale cache if available, otherwise retry + self.logger.logger.warning(f"Shodan API unreachable for {normalized_ip} - network failure") if cache_status == "stale": - self.logger.logger.info(f"Using stale cache for {normalized_ip} due to connection failure") + self.logger.logger.info(f"Using stale cache for {normalized_ip} due to network failure") return self._load_from_cache(cache_file) else: - raise requests.exceptions.RequestException("No response from Shodan API - should retry") + # FIXED: Treat network failures as "no information" rather than retryable errors + self.logger.logger.info(f"No Shodan data available for {normalized_ip} due to network failure") + result = ProviderResult() # Empty result + network_failure_data = {'shodan_status': 'network_unreachable', 'error': 'API unreachable'} + self._save_to_cache(cache_file, result, network_failure_data) + return result + # FIXED: Handle different status codes more precisely if response.status_code == 200: self.logger.logger.debug(f"Shodan returned data for {normalized_ip}") - data = response.json() - result = self._process_shodan_data(normalized_ip, data) - self._save_to_cache(cache_file, result, data) - return result + try: + data = response.json() + result = self._process_shodan_data(normalized_ip, data) + self._save_to_cache(cache_file, result, data) + return result + except json.JSONDecodeError as e: + self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}: {e}") + if cache_status == "stale": + return self._load_from_cache(cache_file) + else: + raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry") elif response.status_code == 404: - # 404 = "no information available" - successful but empty result, don't retry + # FIXED: 404 = "no information available" - successful but empty result, don't retry self.logger.logger.debug(f"Shodan has no information for {normalized_ip} (404)") result = ProviderResult() # Empty but successful result # Cache the empty result to avoid repeated queries - self._save_to_cache(cache_file, result, {'shodan_status': 'no_information', 'status_code': 404}) + empty_data = {'shodan_status': 'no_information', 'status_code': 404} + self._save_to_cache(cache_file, result, empty_data) return result elif response.status_code in [401, 403]: @@ -178,7 +225,7 @@ class ShodanProvider(BaseProvider): self.logger.logger.error(f"Shodan API authentication failed for {normalized_ip} (HTTP {response.status_code})") return ProviderResult() # Empty result, don't retry - elif response.status_code in [429]: + elif response.status_code == 429: # Rate limiting - should be handled by rate limiter, but if we get here, retry self.logger.logger.warning(f"Shodan API rate limited for {normalized_ip} (HTTP {response.status_code})") if cache_status == "stale": @@ -197,13 +244,12 @@ class ShodanProvider(BaseProvider): raise requests.exceptions.RequestException(f"Shodan API server error (HTTP {response.status_code}) - should retry") else: - # Other HTTP error codes - treat as temporary failures - self.logger.logger.warning(f"Shodan API returned unexpected status {response.status_code} for {normalized_ip}") - if cache_status == "stale": - self.logger.logger.info(f"Using stale cache for {normalized_ip} due to unexpected API error") - return self._load_from_cache(cache_file) - else: - raise requests.exceptions.RequestException(f"Shodan API error (HTTP {response.status_code}) - should retry") + # FIXED: Other HTTP status codes - treat as no information available, don't retry + self.logger.logger.info(f"Shodan returned status {response.status_code} for {normalized_ip} - treating as no information") + result = ProviderResult() # Empty result + no_info_data = {'shodan_status': 'no_information', 'status_code': response.status_code} + self._save_to_cache(cache_file, result, no_info_data) + return result except requests.exceptions.Timeout: # Timeout errors - should be retried @@ -223,17 +269,8 @@ class ShodanProvider(BaseProvider): else: raise # Re-raise connection error for retry - except requests.exceptions.RequestException: - # Other request exceptions - should be retried - self.logger.logger.warning(f"Shodan API request exception for {normalized_ip}") - if cache_status == "stale": - self.logger.logger.info(f"Using stale cache for {normalized_ip} due to request exception") - return self._load_from_cache(cache_file) - else: - raise # Re-raise request exception for retry - except json.JSONDecodeError: - # JSON parsing error on 200 response - treat as temporary failure + # JSON parsing error - treat as temporary failure self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}") if cache_status == "stale": self.logger.logger.info(f"Using stale cache for {normalized_ip} due to JSON parsing error") @@ -241,14 +278,16 @@ class ShodanProvider(BaseProvider): else: raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry") + # FIXED: Remove the generic RequestException handler that was causing 404s to retry + # Now only specific exceptions that should be retried are re-raised + except Exception as e: - # Unexpected exceptions - log and treat as temporary failures - self.logger.logger.error(f"Unexpected exception in Shodan query for {normalized_ip}: {e}") - if cache_status == "stale": - self.logger.logger.info(f"Using stale cache for {normalized_ip} due to unexpected exception") - return self._load_from_cache(cache_file) - else: - raise requests.exceptions.RequestException(f"Unexpected error in Shodan query: {e}") from e + # FIXED: Unexpected exceptions - log but treat as no information available, don't retry + self.logger.logger.warning(f"Unexpected exception in Shodan query for {normalized_ip}: {e}") + result = ProviderResult() # Empty result + error_data = {'shodan_status': 'error', 'error': str(e)} + self._save_to_cache(cache_file, result, error_data) + return result def _load_from_cache(self, cache_file_path: Path) -> ProviderResult: """Load processed Shodan data from a cache file."""