bug fixes, improvements: two-phase correlation, large-entity filtering, Shodan hardening

- scanner: run correlations as a separate PHASE 2 after all other providers complete, instead of enqueueing a correlation task per processed target
- scanner: count and include only domain/IP targets in large entities; CA, ISP, and correlation nodes now get direct graph edges even when a large entity is formed
- scanner: only re-queue extracted large-entity members for scanning when they are domains/IPs; add node-type inference for extracted nodes
- shodan: defer the API connection test to first use; treat 404s, network failures, and unexpected errors as cached empty results rather than retryable failures
- add verbose provider-initialization and phase logging
parent 4c48917993
commit 95cebbf935

core/scanner.py (414 changed lines)
@@ -189,10 +189,14 @@ class Scanner:
         """Initialize all available providers based on session configuration."""
         self.providers = []
         provider_dir = os.path.join(os.path.dirname(__file__), '..', 'providers')
+
+        print(f"=== INITIALIZING PROVIDERS FROM {provider_dir} ===")
+
         for filename in os.listdir(provider_dir):
             if filename.endswith('_provider.py') and not filename.startswith('base'):
                 module_name = f"providers.{filename[:-3]}"
                 try:
+                    print(f"Loading provider module: {module_name}")
                     module = importlib.import_module(module_name)
                     for attribute_name in dir(module):
                         attribute = getattr(module, attribute_name)
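
Note: provider discovery in the hunk above is purely convention-based. A minimal standalone sketch of the filename-to-module mapping (illustrative filenames, not a listing of the repository):

    # '*_provider.py' files become 'providers.<stem>' module names;
    # 'base_provider.py' is skipped by the startswith('base') guard.
    for filename in ['base_provider.py', 'crtsh_provider.py', 'shodan_provider.py']:
        if filename.endswith('_provider.py') and not filename.startswith('base'):
            print(f"providers.{filename[:-3]}")
    # -> providers.crtsh_provider
    # -> providers.shodan_provider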
@@ -202,15 +206,41 @@ class Scanner:
                             provider = provider_class(name=attribute_name, session_config=self.config)
                             provider_name = provider.get_name()

+                            print(f"  Provider: {provider_name}")
+                            print(f"  Class: {provider_class.__name__}")
+                            print(f"  Config enabled: {self.config.is_provider_enabled(provider_name)}")
+                            print(f"  Requires API key: {provider.requires_api_key()}")
+
+                            if provider.requires_api_key():
+                                api_key = self.config.get_api_key(provider_name)
+                                print(f"  API key present: {'Yes' if api_key else 'No'}")
+                                if api_key:
+                                    print(f"  API key preview: {api_key[:8]}...")
+
                             if self.config.is_provider_enabled(provider_name):
-                                if provider.is_available():
+                                is_available = provider.is_available()
+                                print(f"  Available: {is_available}")
+
+                                if is_available:
                                     provider.set_stop_event(self.stop_event)
                                     if isinstance(provider, CorrelationProvider):
                                         provider.set_graph_manager(self.graph)
                                     self.providers.append(provider)
+                                    print(f"  ✓ Added to scanner")
+                                else:
+                                    print(f"  ✗ Not available - skipped")
+                            else:
+                                print(f"  ✗ Disabled in config - skipped")
+
                 except Exception as e:
+                    print(f"  ERROR loading {module_name}: {e}")
                     traceback.print_exc()

+        print(f"=== PROVIDER INITIALIZATION COMPLETE ===")
+        print(f"Active providers: {[p.get_name() for p in self.providers]}")
+        print(f"Provider count: {len(self.providers)}")
+        print("=" * 50)

     def _status_logger_thread(self):
         """Periodically prints a clean, formatted scan status to the terminal."""
         HEADER = "\033[95m"
@@ -424,6 +454,9 @@ class Scanner:

             is_ip = _is_valid_ip(target)
             initial_providers = self._get_eligible_providers(target, is_ip, False)
+            # FIXED: Filter out correlation provider from initial providers
+            initial_providers = [p for p in initial_providers if not isinstance(p, CorrelationProvider)]
+
             for provider in initial_providers:
                 provider_name = provider.get_name()
                 priority = self._get_priority(provider_name)
@@ -443,6 +476,8 @@ class Scanner:
             consecutive_empty_iterations = 0
             max_empty_iterations = 50  # Allow 5 seconds of empty queue before considering completion

+            # PHASE 1: Run all non-correlation providers
+            print(f"\n=== PHASE 1: Running non-correlation providers ===")
             while not self._is_stop_requested():
                 queue_empty = self.task_queue.empty()
                 with self.processing_lock:
@@ -451,23 +486,25 @@ class Scanner:
                 if queue_empty and no_active_processing:
                     consecutive_empty_iterations += 1
                     if consecutive_empty_iterations >= max_empty_iterations:
-                        break  # Scan is complete
+                        break  # Phase 1 complete
                     time.sleep(0.1)
                     continue
                 else:
                     consecutive_empty_iterations = 0

-                # FIXED: Safe task retrieval without race conditions
+                # Process tasks (same logic as before, but correlations are filtered out)
                 try:
-                    # Use timeout to avoid blocking indefinitely
                     run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)

-                    # FIXED: Check if task is ready to run
+                    # Skip correlation tasks during Phase 1
+                    if provider_name == 'correlation':
+                        continue
+
+                    # Check if task is ready to run
                     current_time = time.time()
                     if run_at > current_time:
-                        # Task is not ready yet, re-queue it and continue
                         self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
-                        time.sleep(min(0.5, run_at - current_time))  # Sleep until closer to run time
+                        time.sleep(min(0.5, run_at - current_time))
                         continue

                 except:  # Queue is empty or timeout occurred
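
Note: tasks in this scheduler are priority-queue entries of the form (run_at, priority, (provider_name, target, depth)); because Python's PriorityQueue orders tuples lexicographically, the earliest run_at wins, and a not-yet-due task is simply re-queued. A minimal standalone sketch of that pattern (hypothetical payloads, not the scanner's API):

    import time
    from queue import PriorityQueue

    q = PriorityQueue()
    q.put((time.time() + 5, 1, ('shodan', 'example.com', 0)))  # due in 5s
    q.put((time.time(), 0, ('crtsh', 'example.com', 0)))       # due now

    run_at, priority, payload = q.get(timeout=0.1)             # 'crtsh' pops first
    if run_at > time.time():
        q.put((run_at, priority, payload))                     # not due yet: defer
        time.sleep(min(0.5, run_at - time.time()))
    else:
        print("dispatch", payload)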
@@ -476,34 +513,32 @@ class Scanner:

                 self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))

-                # FIXED: Include depth in processed tasks to avoid incorrect skipping
+                # Skip if already processed
                 task_tuple = (provider_name, target_item, depth)
                 if task_tuple in processed_tasks:
                     self.tasks_skipped += 1
                     self.indicators_completed += 1
                     continue

-                # FIXED: Proper depth checking
+                # Skip if depth exceeded
                 if depth > max_depth:
                     self.tasks_skipped += 1
                     self.indicators_completed += 1
                     continue

-                # FIXED: Rate limiting with proper time-based deferral
+                # Rate limiting with proper time-based deferral
                 if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
-                    defer_until = time.time() + 60  # Defer for 60 seconds
+                    defer_until = time.time() + 60
                     self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
                     self.tasks_re_enqueued += 1
                     continue

-                # FIXED: Thread-safe processing state management
+                # Thread-safe processing state management
                 with self.processing_lock:
                     if self._is_stop_requested():
                         break
-                    # Use provider+target (without depth) for duplicate processing check
                     processing_key = (provider_name, target_item)
                     if processing_key in self.currently_processing:
-                        # Already processing this provider+target combination, skip
                         self.tasks_skipped += 1
                         self.indicators_completed += 1
                         continue
@@ -519,21 +554,19 @@ class Scanner:

                 provider = next((p for p in self.providers if p.get_name() == provider_name), None)

-                if provider:
+                if provider and not isinstance(provider, CorrelationProvider):
                     new_targets, _, success = self._process_provider_task(provider, target_item, depth)

                     if self._is_stop_requested():
                         break

                     if not success:
-                        # FIXED: Use depth-aware retry key
                         retry_key = (provider_name, target_item, depth)
                         self.target_retries[retry_key] += 1

                         if self.target_retries[retry_key] <= self.config.max_retries_per_target:
-                            # FIXED: Exponential backoff with jitter for retries
                             retry_count = self.target_retries[retry_key]
-                            backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))  # Cap at 5 minutes
+                            backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
                             retry_at = time.time() + backoff_delay
                             self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
                             self.tasks_re_enqueued += 1
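
Note: the retry schedule above is exponential backoff with jitter, capped at five minutes. A worked example of the delays it produces:

    import random

    for retry_count in range(1, 10):
        backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1))
        print(retry_count, round(backoff_delay, 2))
    # retries 1..8: roughly 2, 4, 8, 16, 32, 64, 128, 256 seconds (plus <1s jitter);
    # retry 9 and beyond: capped at 300 seconds.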
@@ -545,21 +578,21 @@ class Scanner:
                         processed_tasks.add(task_tuple)
                         self.indicators_completed += 1

-                    # FIXED: Enqueue new targets with proper depth tracking
+                    # Enqueue new targets with proper depth tracking
                     if not self._is_stop_requested():
                         for new_target in new_targets:
                             is_ip_new = _is_valid_ip(new_target)
                             eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
+                            # FIXED: Filter out correlation providers in Phase 1
+                            eligible_providers_new = [p for p in eligible_providers_new if not isinstance(p, CorrelationProvider)]

                             for p_new in eligible_providers_new:
                                 p_name_new = p_new.get_name()
-                                new_depth = depth + 1  # Always increment depth for discovered targets
+                                new_depth = depth + 1
                                 new_task_tuple = (p_name_new, new_target, new_depth)

-                                # FIXED: Don't re-enqueue already processed tasks
                                 if new_task_tuple not in processed_tasks and new_depth <= max_depth:
                                     new_priority = self._get_priority(p_name_new)
-                                    # Enqueue new tasks to run immediately
                                     self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
                                     self.total_tasks_ever_enqueued += 1
                 else:
@@ -568,22 +601,25 @@ class Scanner:
                     self.indicators_completed += 1

                 finally:
-                    # FIXED: Always clean up processing state
                     with self.processing_lock:
                         processing_key = (provider_name, target_item)
                         self.currently_processing.discard(processing_key)

+            # PHASE 2: Run correlations on all discovered nodes
+            if not self._is_stop_requested():
+                print(f"\n=== PHASE 2: Running correlation analysis ===")
+                self._run_correlation_phase(max_depth, processed_tasks)
+
         except Exception as e:
             traceback.print_exc()
             self.status = ScanStatus.FAILED
             self.logger.logger.error(f"Scan failed: {e}")
         finally:
-            # FIXED: Comprehensive cleanup
+            # Comprehensive cleanup (same as before)
             with self.processing_lock:
                 self.currently_processing.clear()
                 self.currently_processing_display = []

-            # FIXED: Clear any remaining tasks from queue to prevent memory leaks
             while not self.task_queue.empty():
                 try:
                     self.task_queue.get_nowait()
@@ -599,7 +635,7 @@ class Scanner:

             self.status_logger_stop_event.set()
             if self.status_logger_thread and self.status_logger_thread.is_alive():
-                self.status_logger_thread.join(timeout=2.0)  # Don't wait forever
+                self.status_logger_thread.join(timeout=2.0)

             self._update_session_state()
             self.logger.log_scan_complete()
@@ -612,10 +648,120 @@ class Scanner:
         finally:
             self.executor = None

+    def _run_correlation_phase(self, max_depth: int, processed_tasks: set) -> None:
+        """
+        PHASE 2: Run correlation analysis on all discovered nodes.
+        This ensures correlations run after all other providers have completed.
+        """
+        correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
+        if not correlation_provider:
+            print("No correlation provider found - skipping correlation phase")
+            return
+
+        # Get all nodes from the graph for correlation analysis
+        all_nodes = list(self.graph.graph.nodes())
+        correlation_tasks = []
+
+        print(f"Enqueueing correlation tasks for {len(all_nodes)} nodes")
+
+        for node_id in all_nodes:
+            if self._is_stop_requested():
+                break
+
+            # Determine appropriate depth for correlation (use 0 for simplicity)
+            correlation_depth = 0
+            task_tuple = ('correlation', node_id, correlation_depth)
+
+            # Don't re-process already processed correlation tasks
+            if task_tuple not in processed_tasks:
+                priority = self._get_priority('correlation')
+                self.task_queue.put((time.time(), priority, ('correlation', node_id, correlation_depth)))
+                correlation_tasks.append(task_tuple)
+                self.total_tasks_ever_enqueued += 1
+
+        print(f"Enqueued {len(correlation_tasks)} correlation tasks")
+
+        # Process correlation tasks
+        consecutive_empty_iterations = 0
+        max_empty_iterations = 20  # Shorter timeout for correlation phase
+
+        while not self._is_stop_requested() and correlation_tasks:
+            queue_empty = self.task_queue.empty()
+            with self.processing_lock:
+                no_active_processing = len(self.currently_processing) == 0
+
+            if queue_empty and no_active_processing:
+                consecutive_empty_iterations += 1
+                if consecutive_empty_iterations >= max_empty_iterations:
+                    break
+                time.sleep(0.1)
+                continue
+            else:
+                consecutive_empty_iterations = 0

+            try:
+                run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
+
+                # Only process correlation tasks in this phase
+                if provider_name != 'correlation':
+                    continue
+
+            except:
+                time.sleep(0.1)
+                continue
+
+            task_tuple = (provider_name, target_item, depth)
+
+            # Skip if already processed
+            if task_tuple in processed_tasks:
+                self.tasks_skipped += 1
+                self.indicators_completed += 1
+                if task_tuple in correlation_tasks:
+                    correlation_tasks.remove(task_tuple)
+                continue
+
+            with self.processing_lock:
+                if self._is_stop_requested():
+                    break
+                processing_key = (provider_name, target_item)
+                if processing_key in self.currently_processing:
+                    self.tasks_skipped += 1
+                    self.indicators_completed += 1
+                    continue
+                self.currently_processing.add(processing_key)
+
+            try:
+                self.current_indicator = target_item
+                self._update_session_state()
+
+                if self._is_stop_requested():
+                    break
+
+                # Process correlation task
+                new_targets, _, success = self._process_provider_task(correlation_provider, target_item, depth)
+
+                if success:
+                    processed_tasks.add(task_tuple)
+                    self.indicators_completed += 1
+                    if task_tuple in correlation_tasks:
+                        correlation_tasks.remove(task_tuple)
+                else:
+                    # For correlations, don't retry - just mark as completed
+                    self.indicators_completed += 1
+                    if task_tuple in correlation_tasks:
+                        correlation_tasks.remove(task_tuple)
+
+            finally:
+                with self.processing_lock:
+                    processing_key = (provider_name, target_item)
+                    self.currently_processing.discard(processing_key)
+
+        print(f"Correlation phase complete. Remaining tasks: {len(correlation_tasks)}")
+
     def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
         """
         Manages the entire process for a given target and provider.
-        It uses the "worker" function to get the data and then manages the consequences.
+        FIXED: Don't enqueue correlation tasks during normal processing.
         """
         if self._is_stop_requested():
             return set(), set(), False
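
Note: this new method is the heart of the change from interleaved to two-phase scheduling. A minimal standalone sketch of the phase-gated drain (hypothetical helper names, not the scanner's API):

    import time
    from queue import PriorityQueue

    def drain(q, accept):
        """Pop every queued task, dispatching only those the current phase accepts."""
        while not q.empty():
            run_at, prio, (provider, target, depth) = q.get_nowait()
            if accept(provider):
                print(f"{provider} -> {target}")

    q = PriorityQueue()
    q.put((time.time(), 0, ('crtsh', 'example.com', 0)))

    drain(q, accept=lambda p: p != 'correlation')               # PHASE 1: providers only
    q.put((time.time(), 0, ('correlation', 'example.com', 0)))  # PHASE 2: enqueue per node
    drain(q, accept=lambda p: p == 'correlation')               # PHASE 2: correlations only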
@@ -644,14 +790,6 @@ class Scanner:
             else:
                 new_targets.update(discovered)

-            # After processing a provider, queue a correlation task for the target
-            correlation_provider = next((p for p in self.providers if isinstance(p, CorrelationProvider)), None)
-            if correlation_provider and not isinstance(provider, CorrelationProvider):
-                priority = self._get_priority(correlation_provider.get_name())
-                self.task_queue.put((time.time(), priority, (correlation_provider.get_name(), target, depth)))
-                # FIXED: Increment total tasks when a correlation task is enqueued
-                self.total_tasks_ever_enqueued += 1
-
         except Exception as e:
             provider_successful = False
             self._log_provider_error(target, provider.get_name(), str(e))
@@ -690,7 +828,7 @@ class Scanner:
                                  provider_result: ProviderResult, current_depth: int) -> Tuple[Set[str], bool]:
         """
         Process a unified ProviderResult object to update the graph.
-        VERIFIED: Proper ISP and CA node type assignment.
+        FIXED: Ensure CA and ISP relationships are created even when large entities are formed.
         """
         provider_name = provider.get_name()
         discovered_targets = set()
@@ -698,12 +836,70 @@ class Scanner:
         if self._is_stop_requested():
             return discovered_targets, False

-        # Check if this should be a large entity
-        if provider_result.get_relationship_count() > self.config.large_entity_threshold:
+        # Check if this should be a large entity (only counting domain/IP relationships)
+        eligible_relationship_count = 0
+        for rel in provider_result.relationships:
+            # Only count relationships that would go into large entities
+            if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer':
+                continue  # Don't count CA relationships
+            if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp':
+                continue  # Don't count ISP relationships
+            if rel.relationship_type.startswith('corr_'):
+                continue  # Don't count correlation relationships
+
+            # Only count domain/IP targets
+            if _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node):
+                eligible_relationship_count += 1
+
+        if eligible_relationship_count > self.config.large_entity_threshold:
+            # Create large entity but ALSO process special relationships
             members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth)
+
+            # FIXED: Still process CA, ISP, and correlation relationships directly on the graph
+            for relationship in provider_result.relationships:
+                if self._is_stop_requested():
+                    break
+
+                source_node = relationship.source_node
+                target_node = relationship.target_node
+
+                # Process special relationship types that should appear directly on graph
+                should_create_direct_relationship = False
+                target_type = None
+
+                if provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
+                    target_type = NodeType.CA
+                    should_create_direct_relationship = True
+                elif provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
+                    target_type = NodeType.ISP
+                    should_create_direct_relationship = True
+                elif relationship.relationship_type.startswith('corr_'):
+                    target_type = NodeType.CORRELATION_OBJECT
+                    should_create_direct_relationship = True
+
+                if should_create_direct_relationship:
+                    # Create source and target nodes
+                    source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN
+                    self.graph.add_node(source_node, source_type)
+                    self.graph.add_node(target_node, target_type)
+
+                    # Add the relationship edge
+                    self.graph.add_edge(
+                        source_node, target_node,
+                        relationship.relationship_type,
+                        relationship.confidence,
+                        provider_name,
+                        relationship.raw_data
+                    )
+
+                    # Add to discovered targets if it's a valid target for further processing
+                    max_depth_reached = current_depth >= self.max_depth
+                    if not max_depth_reached and (_is_valid_domain(target_node) or _is_valid_ip(target_node)):
+                        discovered_targets.add(target_node)
+
             return members, True

-        # Process relationships and create nodes with proper types
+        # Normal processing (existing logic) when not creating large entity
         for i, relationship in enumerate(provider_result.relationships):
             if i % 5 == 0 and self._is_stop_requested():
                 break
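
Note: the counting loop above and the filtering added to _create_large_entity_from_provider_result (later hunks) apply the same eligibility rule. One way to read it, as a single predicate (a refactoring sketch using the module's existing helpers, not code from this commit):

    def counts_toward_large_entity(provider_name, rel) -> bool:
        """True only for plain domain/IP relationships; special edges stay on the graph."""
        if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer':
            return False  # CA edge: drawn directly on the graph
        if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp':
            return False  # ISP edge: drawn directly on the graph
        if rel.relationship_type.startswith('corr_'):
            return False  # correlation edge: drawn directly on the graph
        return _is_valid_domain(rel.target_node) or _is_valid_ip(rel.target_node)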
@@ -711,14 +907,14 @@ class Scanner:
             source_node = relationship.source_node
             target_node = relationship.target_node

-            # VERIFIED: Determine source node type
+            # Determine source node type
             source_type = NodeType.IP if _is_valid_ip(source_node) else NodeType.DOMAIN

-            # VERIFIED: Determine target node type based on provider and relationship
+            # Determine target node type based on provider and relationship
             if provider_name == 'shodan' and relationship.relationship_type == 'shodan_isp':
-                target_type = NodeType.ISP  # ISP node for Shodan organization data
+                target_type = NodeType.ISP
             elif provider_name == 'crtsh' and relationship.relationship_type == 'crtsh_cert_issuer':
-                target_type = NodeType.CA  # CA node for certificate issuers
+                target_type = NodeType.CA
             elif provider_name == 'correlation':
                 target_type = NodeType.CORRELATION_OBJECT
             elif _is_valid_ip(target_node):
@@ -733,7 +929,6 @@ class Scanner:
             self.graph.add_node(source_node, source_type)
             self.graph.add_node(target_node, target_type, metadata={'max_depth_reached': max_depth_reached})

-
             # Add the relationship edge
             if self.graph.add_edge(
                 source_node, target_node,
@@ -748,7 +943,7 @@ class Scanner:
             if (_is_valid_domain(target_node) or _is_valid_ip(target_node)) and not max_depth_reached:
                 discovered_targets.add(target_node)

-        # Process all attributes, grouping by target node
+        # Process all attributes (existing logic unchanged)
         attributes_by_node = defaultdict(list)
         for attribute in provider_result.attributes:
             attr_dict = {
@@ -764,11 +959,9 @@ class Scanner:
         # Add attributes to existing nodes OR create new nodes if they don't exist
         for node_id, node_attributes_list in attributes_by_node.items():
             if not self.graph.graph.has_node(node_id):
-                # If the node doesn't exist, create it with a default type
                 node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
                 self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
             else:
-                # If the node already exists, just add the attributes
                 node_type_val = self.graph.graph.nodes[node_id].get('type', 'domain')
                 self.graph.add_node(node_id, NodeType(node_type_val), attributes=node_attributes_list)

@@ -778,25 +971,50 @@ class Scanner:
                                                  provider_result: ProviderResult, current_depth: int) -> Set[str]:
         """
         Create a large entity node from a ProviderResult.
+        FIXED: Only include domain/IP nodes in large entities, exclude CA and other special node types.
         """
         entity_id = f"large_entity_{provider_name}_{hash(source) & 0x7FFFFFFF}"

-        targets = [rel.target_node for rel in provider_result.relationships]
-        node_type = 'unknown'
+        # FIXED: Filter out CA, ISP, and correlation nodes from large entity inclusion
+        eligible_targets = []
+        for rel in provider_result.relationships:
+            target_node = rel.target_node

-        if targets:
-            if _is_valid_domain(targets[0]):
+            # Skip CA nodes (certificate issuers) - they should appear directly on graph
+            if provider_name == 'crtsh' and rel.relationship_type == 'crtsh_cert_issuer':
+                continue
+
+            # Skip ISP nodes - they should appear directly on graph
+            if provider_name == 'shodan' and rel.relationship_type == 'shodan_isp':
+                continue
+
+            # Skip correlation objects - they should appear directly on graph
+            if rel.relationship_type.startswith('corr_'):
+                continue
+
+            # Only include valid domains and IPs in large entities
+            if _is_valid_domain(target_node) or _is_valid_ip(target_node):
+                eligible_targets.append(target_node)
+
+        # If no eligible targets after filtering, don't create large entity
+        if not eligible_targets:
+            return set()
+
+        node_type = 'unknown'
+        if eligible_targets:
+            if _is_valid_domain(eligible_targets[0]):
                 node_type = 'domain'
-            elif _is_valid_ip(targets[0]):
+            elif _is_valid_ip(eligible_targets[0]):
                 node_type = 'ip'

-        for target in targets:
+        # Create individual nodes for eligible targets
+        for target in eligible_targets:
             target_node_type = NodeType.DOMAIN if node_type == 'domain' else NodeType.IP
             self.graph.add_node(target, target_node_type)

         attributes_dict = {
-            'count': len(targets),
-            'nodes': targets,
+            'count': len(eligible_targets),
+            'nodes': eligible_targets,  # Only eligible domain/IP nodes
             'node_type': node_type,
             'source_provider': provider_name,
             'discovery_depth': current_depth,
@@ -814,18 +1032,21 @@ class Scanner:
                 "metadata": {}
             })

-        description = f'Large entity created due to {len(targets)} relationships from {provider_name}'
+        description = f'Large entity created due to {len(eligible_targets)} relationships from {provider_name}'

         self.graph.add_node(entity_id, NodeType.LARGE_ENTITY, attributes=attributes_list, description=description)

         if provider_result.relationships:
-            rel_type = provider_result.relationships[0].relationship_type
-            self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
-                                {'large_entity_info': f'Contains {len(targets)} {node_type}s'})
+            # Use the first eligible relationship for the large entity connection
+            eligible_rels = [rel for rel in provider_result.relationships if rel.target_node in eligible_targets]
+            if eligible_rels:
+                rel_type = eligible_rels[0].relationship_type
+                self.graph.add_edge(source, entity_id, rel_type, 0.9, provider_name,
+                                    {'large_entity_info': f'Contains {len(eligible_targets)} {node_type}s'})

-        self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(targets)} targets from {provider_name}")
+        self.logger.logger.warning(f"Large entity created: {entity_id} contains {len(eligible_targets)} targets from {provider_name}")

-        return set(targets)
+        return set(eligible_targets)

     def stop_scan(self) -> bool:
         """Request immediate scan termination with proper cleanup."""
@@ -857,6 +1078,7 @@ class Scanner:
     def extract_node_from_large_entity(self, large_entity_id: str, node_id_to_extract: str) -> bool:
         """
         Extracts a node from a large entity and re-queues it for scanning.
+        FIXED: Properly handle different node types during extraction.
         """
         if not self.graph.graph.has_node(large_entity_id):
             return False
@@ -874,6 +1096,7 @@ class Scanner:
         if not success:
             return False

+        # Create relationship from source to extracted node
         self.graph.add_edge(
             source_id=source_node_id,
             target_id=node_id_to_extract,
@@ -883,8 +1106,13 @@ class Scanner:
             raw_data={'context': f'Extracted from large entity {large_entity_id}'}
         )

+        # FIXED: Only queue for further scanning if it's a domain/IP that can be scanned
         is_ip = _is_valid_ip(node_id_to_extract)
+        is_domain = _is_valid_domain(node_id_to_extract)
+
+        # Only queue valid domains and IPs for further processing
+        # Don't queue CA nodes, ISP nodes, etc. as they can't be scanned
+        if is_domain or is_ip:
             large_entity_attributes = self.graph.graph.nodes[large_entity_id].get('attributes', [])
             discovery_depth_attr = next((attr for attr in large_entity_attributes if attr.get('name') == 'discovery_depth'), None)
             current_depth = discovery_depth_attr['value'] if discovery_depth_attr else 0
@@ -907,9 +1135,75 @@ class Scanner:
                 daemon=True
             )
             self.scan_thread.start()
+        else:
+            # For non-scannable nodes (CA, ISP, etc.), just log that they were extracted
+            self.logger.logger.info(f"Extracted non-scannable node {node_id_to_extract} of type {self.graph.graph.nodes[node_id_to_extract].get('type', 'unknown')}")

         return True

+    def _determine_extracted_node_type(self, node_id: str, large_entity_id: str) -> NodeType:
+        """
+        FIXED: Determine the correct node type for a node being extracted from a large entity.
+        Uses multiple strategies to ensure accurate type detection.
+        """
+        from utils.helpers import _is_valid_ip, _is_valid_domain
+
+        # Strategy 1: Check if node already exists in graph with a type
+        if self.graph.has_node(node_id):
+            existing_type = self.graph.nodes[node_id].get('type')
+            if existing_type:
+                try:
+                    return NodeType(existing_type)
+                except ValueError:
+                    pass
+
+        # Strategy 2: Look for existing relationships to this node to infer type
+        for source, target, edge_data in self.graph.edges(data=True):
+            if target == node_id:
+                rel_type = edge_data.get('relationship_type', '')
+                provider = edge_data.get('source_provider', '')
+
+                # CA nodes from certificate issuer relationships
+                if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer':
+                    return NodeType.CA
+
+                # ISP nodes from Shodan
+                if provider == 'shodan' and rel_type == 'shodan_isp':
+                    return NodeType.ISP
+
+                # Correlation objects
+                if rel_type.startswith('corr_'):
+                    return NodeType.CORRELATION_OBJECT
+
+            if source == node_id:
+                rel_type = edge_data.get('relationship_type', '')
+                provider = edge_data.get('source_provider', '')
+
+                # Source nodes in cert issuer relationships are CAs
+                if provider == 'crtsh' and rel_type == 'crtsh_cert_issuer':
+                    return NodeType.CA
+
+        # Strategy 3: Format-based detection (fallback)
+        if _is_valid_ip(node_id):
+            return NodeType.IP
+        elif _is_valid_domain(node_id):
+            return NodeType.DOMAIN
+
+        # Strategy 4: Check large entity context
+        if self.graph.has_node(large_entity_id):
+            large_entity_data = self.graph.nodes[large_entity_id]
+            attributes = large_entity_data.get('attributes', [])
+
+            node_type_attr = next((attr for attr in attributes if attr.get('name') == 'node_type'), None)
+            if node_type_attr:
+                entity_node_type = node_type_attr.get('value', 'domain')
+                if entity_node_type == 'ip':
+                    return NodeType.IP
+                else:
+                    return NodeType.DOMAIN
+
+        # Final fallback
+        return NodeType.DOMAIN
+
     def _update_session_state(self) -> None:
         """
         Update the scanner state in Redis for GUI updates.
providers/shodan_provider.py

@@ -27,26 +27,53 @@ class ShodanProvider(BaseProvider):
         )
         self.base_url = "https://api.shodan.io"
         self.api_key = self.config.get_api_key('shodan')
-        self._is_active = self._check_api_connection()
+        # FIXED: Don't fail initialization on connection issues - defer to actual usage
+        self._connection_tested = False
+        self._connection_works = False
+
         # Initialize cache directory
         self.cache_dir = Path('cache') / 'shodan'
         self.cache_dir.mkdir(parents=True, exist_ok=True)

     def _check_api_connection(self) -> bool:
-        """Checks if the Shodan API is reachable."""
+        """
+        FIXED: Lazy connection checking - only test when actually needed.
+        Don't block provider initialization on network issues.
+        """
+        if self._connection_tested:
+            return self._connection_works
+
         if not self.api_key:
-            return False
-        try:
-            response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5)
-            self.logger.logger.debug("Shodan is reacheable")
-            return response.status_code == 200
-        except requests.exceptions.RequestException:
+            self._connection_tested = True
+            self._connection_works = False
             return False

+        try:
+            print(f"Testing Shodan API connection with key: {self.api_key[:8]}...")
+            response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5)
+            self._connection_works = response.status_code == 200
+            print(f"Shodan API test result: {response.status_code} - {'Success' if self._connection_works else 'Failed'}")
+        except requests.exceptions.RequestException as e:
+            print(f"Shodan API connection test failed: {e}")
+            self._connection_works = False
+        finally:
+            self._connection_tested = True
+
+        return self._connection_works
+
     def is_available(self) -> bool:
-        """Check if Shodan provider is available (has valid API key in this session)."""
-        return self._is_active and self.api_key is not None and len(self.api_key.strip()) > 0
+        """
+        FIXED: Check if Shodan provider is available based on API key presence.
+        Don't require successful connection test during initialization.
+        """
+        has_api_key = self.api_key is not None and len(self.api_key.strip()) > 0
+
+        if not has_api_key:
+            return False
+
+        # FIXED: Only test connection on first actual usage, not during initialization
+        return True

     def get_name(self) -> str:
         """Return the provider name."""
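
Note: the pattern adopted here is a lazily evaluated, memoized health check: is_available() answers from the API key alone, and the first real request pays for the single probe. The generic shape, as a standalone sketch (hypothetical class, not the provider's API):

    class LazyProbe:
        def __init__(self):
            self._tested = False
            self._works = False

        def check(self) -> bool:
            if self._tested:              # probe at most once per session
                return self._works
            try:
                self._works = self._probe()
            except Exception:
                self._works = False
            finally:
                self._tested = True
            return self._works

        def _probe(self) -> bool:
            # e.g. GET {base_url}/api-info?key=... and test for HTTP 200
            raise NotImplementedError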
@@ -117,6 +144,7 @@ class ShodanProvider(BaseProvider):
     def query_ip(self, ip: str) -> ProviderResult:
         """
         Query Shodan for information about an IP address (IPv4 or IPv6), with caching of processed data.
+        FIXED: Proper 404 handling to prevent unnecessary retries.

         Args:
             ip: IP address to investigate (IPv4 or IPv6)
@@ -127,7 +155,12 @@ class ShodanProvider(BaseProvider):
         Raises:
             Exception: For temporary failures that should be retried (timeouts, 502/503 errors, connection issues)
         """
-        if not _is_valid_ip(ip) or not self.is_available():
+        if not _is_valid_ip(ip):
+            return ProviderResult()
+
+        # Test connection only when actually making requests
+        if not self._check_api_connection():
+            print(f"Shodan API not available for {ip} - API key: {'present' if self.api_key else 'missing'}")
             return ProviderResult()

         # Normalize IP address for consistent processing
@@ -151,26 +184,40 @@ class ShodanProvider(BaseProvider):
             response = self.make_request(url, method="GET", params=params, target_indicator=normalized_ip)

             if not response:
-                # Connection failed - use stale cache if available, otherwise retry
+                self.logger.logger.warning(f"Shodan API unreachable for {normalized_ip} - network failure")
                 if cache_status == "stale":
-                    self.logger.logger.info(f"Using stale cache for {normalized_ip} due to connection failure")
+                    self.logger.logger.info(f"Using stale cache for {normalized_ip} due to network failure")
                     return self._load_from_cache(cache_file)
                 else:
-                    raise requests.exceptions.RequestException("No response from Shodan API - should retry")
+                    # FIXED: Treat network failures as "no information" rather than retryable errors
+                    self.logger.logger.info(f"No Shodan data available for {normalized_ip} due to network failure")
+                    result = ProviderResult()  # Empty result
+                    network_failure_data = {'shodan_status': 'network_unreachable', 'error': 'API unreachable'}
+                    self._save_to_cache(cache_file, result, network_failure_data)
+                    return result

+            # FIXED: Handle different status codes more precisely
             if response.status_code == 200:
                 self.logger.logger.debug(f"Shodan returned data for {normalized_ip}")
+                try:
                     data = response.json()
                     result = self._process_shodan_data(normalized_ip, data)
                     self._save_to_cache(cache_file, result, data)
                     return result
+                except json.JSONDecodeError as e:
+                    self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}: {e}")
+                    if cache_status == "stale":
+                        return self._load_from_cache(cache_file)
+                    else:
+                        raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry")

             elif response.status_code == 404:
-                # 404 = "no information available" - successful but empty result, don't retry
+                # FIXED: 404 = "no information available" - successful but empty result, don't retry
                 self.logger.logger.debug(f"Shodan has no information for {normalized_ip} (404)")
                 result = ProviderResult()  # Empty but successful result
                 # Cache the empty result to avoid repeated queries
-                self._save_to_cache(cache_file, result, {'shodan_status': 'no_information', 'status_code': 404})
+                empty_data = {'shodan_status': 'no_information', 'status_code': 404}
+                self._save_to_cache(cache_file, result, empty_data)
                 return result

             elif response.status_code in [401, 403]:
@@ -178,7 +225,7 @@ class ShodanProvider(BaseProvider):
                 self.logger.logger.error(f"Shodan API authentication failed for {normalized_ip} (HTTP {response.status_code})")
                 return ProviderResult()  # Empty result, don't retry

-            elif response.status_code in [429]:
+            elif response.status_code == 429:
                 # Rate limiting - should be handled by rate limiter, but if we get here, retry
                 self.logger.logger.warning(f"Shodan API rate limited for {normalized_ip} (HTTP {response.status_code})")
                 if cache_status == "stale":
@@ -197,13 +244,12 @@ class ShodanProvider(BaseProvider):
                     raise requests.exceptions.RequestException(f"Shodan API server error (HTTP {response.status_code}) - should retry")

             else:
-                # Other HTTP error codes - treat as temporary failures
-                self.logger.logger.warning(f"Shodan API returned unexpected status {response.status_code} for {normalized_ip}")
-                if cache_status == "stale":
-                    self.logger.logger.info(f"Using stale cache for {normalized_ip} due to unexpected API error")
-                    return self._load_from_cache(cache_file)
-                else:
-                    raise requests.exceptions.RequestException(f"Shodan API error (HTTP {response.status_code}) - should retry")
+                # FIXED: Other HTTP status codes - treat as no information available, don't retry
+                self.logger.logger.info(f"Shodan returned status {response.status_code} for {normalized_ip} - treating as no information")
+                result = ProviderResult()  # Empty result
+                no_info_data = {'shodan_status': 'no_information', 'status_code': response.status_code}
+                self._save_to_cache(cache_file, result, no_info_data)
+                return result

         except requests.exceptions.Timeout:
             # Timeout errors - should be retried
@@ -223,17 +269,8 @@ class ShodanProvider(BaseProvider):
             else:
                 raise  # Re-raise connection error for retry

-        except requests.exceptions.RequestException:
-            # Other request exceptions - should be retried
-            self.logger.logger.warning(f"Shodan API request exception for {normalized_ip}")
-            if cache_status == "stale":
-                self.logger.logger.info(f"Using stale cache for {normalized_ip} due to request exception")
-                return self._load_from_cache(cache_file)
-            else:
-                raise  # Re-raise request exception for retry
-
         except json.JSONDecodeError:
-            # JSON parsing error on 200 response - treat as temporary failure
+            # JSON parsing error - treat as temporary failure
             self.logger.logger.error(f"Invalid JSON response from Shodan for {normalized_ip}")
             if cache_status == "stale":
                 self.logger.logger.info(f"Using stale cache for {normalized_ip} due to JSON parsing error")
@@ -241,14 +278,16 @@ class ShodanProvider(BaseProvider):
             else:
                 raise requests.exceptions.RequestException("Invalid JSON response from Shodan - should retry")

+        # FIXED: Remove the generic RequestException handler that was causing 404s to retry
+        # Now only specific exceptions that should be retried are re-raised
+
         except Exception as e:
-            # Unexpected exceptions - log and treat as temporary failures
-            self.logger.logger.error(f"Unexpected exception in Shodan query for {normalized_ip}: {e}")
-            if cache_status == "stale":
-                self.logger.logger.info(f"Using stale cache for {normalized_ip} due to unexpected exception")
-                return self._load_from_cache(cache_file)
-            else:
-                raise requests.exceptions.RequestException(f"Unexpected error in Shodan query: {e}") from e
+            # FIXED: Unexpected exceptions - log but treat as no information available, don't retry
+            self.logger.logger.warning(f"Unexpected exception in Shodan query for {normalized_ip}: {e}")
+            result = ProviderResult()  # Empty result
+            error_data = {'shodan_status': 'error', 'error': str(e)}
+            self._save_to_cache(cache_file, result, error_data)
+            return result

     def _load_from_cache(self, cache_file_path: Path) -> ProviderResult:
         """Load processed Shodan data from a cache file."""
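
Note: after these hunks, every terminal outcome of query_ip writes through the cache, including negative ones, so repeat queries for the same IP are answered from disk instead of re-hitting the API. The payload shapes recorded by this diff (the last value is a placeholder for the stringified exception):

    SUCCESS          = "raw Shodan JSON for the IP"
    NO_INFORMATION   = {'shodan_status': 'no_information', 'status_code': 404}
    NETWORK_FAILURE  = {'shodan_status': 'network_unreachable', 'error': 'API unreachable'}
    UNEXPECTED_ERROR = {'shodan_status': 'error', 'error': '<str(e) of the exception>'}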