Compare commits

...

2 Commits

Author SHA1 Message Date
overcuriousity
4c48917993 fixes for scheduler 2025-09-18 21:32:26 +02:00
overcuriousity
9d9afa6a08 fixes 2025-09-18 21:04:29 +02:00
3 changed files with 336 additions and 130 deletions

View File

@ -34,7 +34,7 @@ class Config:
'crtsh': 5,
'shodan': 60,
'dns': 100,
'correlation': 1000 # Set a high limit as it's a local operation
'correlation': 0 # Set to 0 to make sure correlations run last
}
# --- Provider Settings ---

View File

@ -258,20 +258,36 @@ class Scanner:
time.sleep(2)
def start_scan(self, target: str, max_depth: int = 2, clear_graph: bool = True, force_rescan_target: Optional[str] = None) -> bool:
"""
Starts a new reconnaissance scan.
"""
if self.scan_thread and self.scan_thread.is_alive():
self.logger.logger.info("Stopping existing scan before starting new one")
self._set_stop_signal()
self.status = ScanStatus.STOPPED
# Clean up processing state
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
self.task_queue = PriorityQueue()
# Clear task queue
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
except:
break
# Shutdown executor
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except:
pass
finally:
self.executor = None
self.scan_thread.join(5.0)
# Wait for scan thread to finish (with timeout)
self.scan_thread.join(timeout=5.0)
if self.scan_thread.is_alive():
self.logger.logger.warning("Previous scan thread did not terminate cleanly")
self.status = ScanStatus.IDLE
self.stop_event.clear()
@ -294,6 +310,12 @@ class Scanner:
try:
if not hasattr(self, 'providers') or not self.providers:
self.logger.logger.error("No providers available for scanning")
return False
available_providers = [p for p in self.providers if p.is_available()]
if not available_providers:
self.logger.logger.error("No providers are currently available/configured")
return False
if clear_graph:
@ -301,13 +323,27 @@ class Scanner:
self.initial_targets.clear()
if force_rescan_target and self.graph.graph.has_node(force_rescan_target):
try:
node_data = self.graph.graph.nodes[force_rescan_target]
if 'metadata' in node_data and 'provider_states' in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
self.logger.logger.info(f"Cleared provider states for forced rescan of {force_rescan_target}")
except Exception as e:
self.logger.logger.warning(f"Error clearing provider states for {force_rescan_target}: {e}")
self.current_target = target.lower().strip()
target = target.lower().strip()
if not target:
self.logger.logger.error("Empty target provided")
return False
from utils.helpers import is_valid_target
if not is_valid_target(target):
self.logger.logger.error(f"Invalid target format: {target}")
return False
self.current_target = target
self.initial_targets.add(self.current_target)
self.max_depth = max_depth
self.max_depth = max(1, min(5, max_depth)) # Clamp depth between 1-5
self.current_depth = 0
self.total_indicators_found = 0
@ -320,56 +356,77 @@ class Scanner:
self._update_session_state()
self.logger = new_session()
try:
self.scan_thread = threading.Thread(
target=self._execute_scan,
args=(self.current_target, max_depth),
daemon=True
args=(self.current_target, self.max_depth),
daemon=True,
name=f"ScanThread-{self.session_id or 'default'}"
)
self.scan_thread.start()
self.status_logger_stop_event.clear()
self.status_logger_thread = threading.Thread(target=self._status_logger_thread, daemon=True)
self.status_logger_thread = threading.Thread(
target=self._status_logger_thread,
daemon=True,
name=f"StatusLogger-{self.session_id or 'default'}"
)
self.status_logger_thread.start()
self.logger.logger.info(f"Scan started successfully for {target} with depth {self.max_depth}")
return True
except Exception as e:
self.logger.logger.error(f"Error starting scan threads: {e}")
self.status = ScanStatus.FAILED
self._update_session_state()
return False
except Exception as e:
self.logger.logger.error(f"Error in scan startup: {e}")
traceback.print_exc()
self.status = ScanStatus.FAILED
self._update_session_state()
return False
def _get_priority(self, provider_name):
if provider_name == 'correlation':
return 100 # Highest priority number = lowest priority (runs last)
rate_limit = self.config.get_rate_limit(provider_name)
# Define the logarithmic scale
if rate_limit < 10:
return 10 # Highest priority number (lowest priority) for very low rate limits
# Handle edge cases
if rate_limit <= 0:
return 90 # Very low priority for invalid/disabled providers
# Calculate logarithmic value and map to priority levels
# Lower rate limits get higher priority numbers (lower priority)
log_value = math.log10(rate_limit)
priority = 10 - int(log_value * 2) # Scale factor to get more granular levels
# Ensure priority is within a reasonable range (1-10)
priority = max(1, min(10, priority))
return priority
if provider_name == 'dns':
return 1 # DNS is fastest, should run first
elif provider_name == 'shodan':
return 3 # Shodan is medium speed, good priority
elif provider_name == 'crtsh':
return 5 # crt.sh is slower, lower priority
else:
# For any other providers, use rate limit as a guide
if rate_limit >= 100:
return 2 # High rate limit = high priority
elif rate_limit >= 50:
return 4 # Medium-high rate limit = medium-high priority
elif rate_limit >= 20:
return 6 # Medium rate limit = medium priority
elif rate_limit >= 5:
return 8 # Low rate limit = low priority
else:
return 10 # Very low rate limit = very low priority
def _execute_scan(self, target: str, max_depth: int) -> None:
"""
Execute the reconnaissance scan with a time-based, robust scheduler.
Handles rate-limiting via deferral and failures via exponential backoff.
"""
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_tasks = set()
processed_tasks = set() # FIXED: Now includes depth to avoid incorrect skipping
is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False)
for provider in initial_providers:
provider_name = provider.get_name()
priority = self._get_priority(provider_name)
# OVERHAUL: Enqueue with current timestamp to run immediately
self.task_queue.put((time.time(), priority, (provider_name, target, 0)))
self.total_tasks_ever_enqueued += 1
@ -383,101 +440,156 @@ class Scanner:
node_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, node_type)
self._initialize_provider_states(target)
consecutive_empty_iterations = 0
max_empty_iterations = 50 # Allow 5 seconds of empty queue before considering completion
while not self._is_stop_requested():
if self.task_queue.empty() and not self.currently_processing:
break # Scan is complete
queue_empty = self.task_queue.empty()
with self.processing_lock:
no_active_processing = len(self.currently_processing) == 0
if queue_empty and no_active_processing:
consecutive_empty_iterations += 1
if consecutive_empty_iterations >= max_empty_iterations:
break # Scan is complete
time.sleep(0.1)
continue
else:
consecutive_empty_iterations = 0
# FIXED: Safe task retrieval without race conditions
try:
# OVERHAUL: Peek at the next task to see if it's ready to run
next_run_at, _, _ = self.task_queue.queue[0]
if next_run_at > time.time():
time.sleep(0.1) # Sleep to prevent busy-waiting for future tasks
# Use timeout to avoid blocking indefinitely
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get(timeout=0.1)
# FIXED: Check if task is ready to run
current_time = time.time()
if run_at > current_time:
# Task is not ready yet, re-queue it and continue
self.task_queue.put((run_at, priority, (provider_name, target_item, depth)))
time.sleep(min(0.5, run_at - current_time)) # Sleep until closer to run time
continue
except: # Queue is empty or timeout occurred
time.sleep(0.1)
continue
# Task is ready, so get it from the queue
run_at, priority, (provider_name, target_item, depth) = self.task_queue.get()
self.last_task_from_queue = (run_at, priority, (provider_name, target_item, depth))
except IndexError:
time.sleep(0.1) # Queue is empty, but tasks might still be processing
continue
task_tuple = (provider_name, target_item)
# FIXED: Include depth in processed tasks to avoid incorrect skipping
task_tuple = (provider_name, target_item, depth)
if task_tuple in processed_tasks:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
# FIXED: Proper depth checking
if depth > max_depth:
self.tasks_skipped += 1
self.indicators_completed += 1
continue
# OVERHAUL: Handle rate limiting with time-based deferral
# FIXED: Rate limiting with proper time-based deferral
if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
defer_until = time.time() + 60 # Defer for 60 seconds
self.task_queue.put((defer_until, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1
continue
# FIXED: Thread-safe processing state management
with self.processing_lock:
if self._is_stop_requested(): break
self.currently_processing.add(task_tuple)
if self._is_stop_requested():
break
# Use provider+target (without depth) for duplicate processing check
processing_key = (provider_name, target_item)
if processing_key in self.currently_processing:
# Already processing this provider+target combination, skip
self.tasks_skipped += 1
self.indicators_completed += 1
continue
self.currently_processing.add(processing_key)
try:
self.current_depth = depth
self.current_indicator = target_item
self._update_session_state()
if self._is_stop_requested(): break
if self._is_stop_requested():
break
provider = next((p for p in self.providers if p.get_name() == provider_name), None)
if provider:
new_targets, _, success = self._process_provider_task(provider, target_item, depth)
if self._is_stop_requested(): break
if self._is_stop_requested():
break
if not success:
self.target_retries[task_tuple] += 1
if self.target_retries[task_tuple] <= self.config.max_retries_per_target:
# OVERHAUL: Exponential backoff for retries
retry_count = self.target_retries[task_tuple]
backoff_delay = (2 ** retry_count) + random.uniform(0, 1) # Add jitter
# FIXED: Use depth-aware retry key
retry_key = (provider_name, target_item, depth)
self.target_retries[retry_key] += 1
if self.target_retries[retry_key] <= self.config.max_retries_per_target:
# FIXED: Exponential backoff with jitter for retries
retry_count = self.target_retries[retry_key]
backoff_delay = min(300, (2 ** retry_count) + random.uniform(0, 1)) # Cap at 5 minutes
retry_at = time.time() + backoff_delay
self.task_queue.put((retry_at, priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1
self.logger.logger.debug(f"Retrying {provider_name}:{target_item} in {backoff_delay:.1f}s (attempt {retry_count})")
else:
self.scan_failed_due_to_retries = True
self._log_target_processing_error(str(task_tuple), "Max retries exceeded")
self._log_target_processing_error(str(task_tuple), f"Max retries ({self.config.max_retries_per_target}) exceeded")
else:
processed_tasks.add(task_tuple)
self.indicators_completed += 1
# FIXED: Enqueue new targets with proper depth tracking
if not self._is_stop_requested():
for new_target in new_targets:
is_ip_new = _is_valid_ip(new_target)
eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
for p_new in eligible_providers_new:
p_name_new = p_new.get_name()
if (p_name_new, new_target) not in processed_tasks:
new_depth = depth + 1 if new_target in new_targets else depth
new_depth = depth + 1 # Always increment depth for discovered targets
new_task_tuple = (p_name_new, new_target, new_depth)
# FIXED: Don't re-enqueue already processed tasks
if new_task_tuple not in processed_tasks and new_depth <= max_depth:
new_priority = self._get_priority(p_name_new)
# OVERHAUL: Enqueue new tasks to run immediately
# Enqueue new tasks to run immediately
self.task_queue.put((time.time(), new_priority, (p_name_new, new_target, new_depth)))
self.total_tasks_ever_enqueued += 1
else:
self.logger.logger.warning(f"Provider {provider_name} not found in active providers")
self.tasks_skipped += 1
self.indicators_completed += 1
finally:
# FIXED: Always clean up processing state
with self.processing_lock:
self.currently_processing.discard(task_tuple)
processing_key = (provider_name, target_item)
self.currently_processing.discard(processing_key)
except Exception as e:
traceback.print_exc()
self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}")
finally:
# FIXED: Comprehensive cleanup
with self.processing_lock:
self.currently_processing.clear()
self.currently_processing_display = []
# FIXED: Clear any remaining tasks from queue to prevent memory leaks
while not self.task_queue.empty():
try:
self.task_queue.get_nowait()
except:
break
if self._is_stop_requested():
self.status = ScanStatus.STOPPED
elif self.scan_failed_due_to_retries:
@ -486,13 +598,18 @@ class Scanner:
self.status = ScanStatus.COMPLETED
self.status_logger_stop_event.set()
if self.status_logger_thread:
self.status_logger_thread.join()
if self.status_logger_thread and self.status_logger_thread.is_alive():
self.status_logger_thread.join(timeout=2.0) # Don't wait forever
self._update_session_state()
self.logger.log_scan_complete()
if self.executor:
try:
self.executor.shutdown(wait=False, cancel_futures=True)
except Exception as e:
self.logger.logger.warning(f"Error shutting down executor: {e}")
finally:
self.executor = None
def _process_provider_task(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
@ -581,29 +698,6 @@ class Scanner:
if self._is_stop_requested():
return discovered_targets, False
# Process all attributes first, grouping by target node
attributes_by_node = defaultdict(list)
for attribute in provider_result.attributes:
attr_dict = {
"name": attribute.name,
"value": attribute.value,
"type": attribute.type,
"provider": attribute.provider,
"confidence": attribute.confidence,
"metadata": attribute.metadata
}
attributes_by_node[attribute.target_node].append(attr_dict)
# FIXED: Add attributes to existing nodes AND create new nodes (like correlation nodes)
for node_id, node_attributes_list in attributes_by_node.items():
if provider_name == 'correlation' and not self.graph.graph.has_node(node_id):
node_type = NodeType.CORRELATION_OBJECT
elif _is_valid_ip(node_id):
node_type = NodeType.IP
else:
node_type = NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
# Check if this should be a large entity
if provider_result.get_relationship_count() > self.config.large_entity_threshold:
members = self._create_large_entity_from_provider_result(target, provider_name, provider_result, current_depth)
@ -654,6 +748,30 @@ class Scanner:
if (_is_valid_domain(target_node) or _is_valid_ip(target_node)) and not max_depth_reached:
discovered_targets.add(target_node)
# Process all attributes, grouping by target node
attributes_by_node = defaultdict(list)
for attribute in provider_result.attributes:
attr_dict = {
"name": attribute.name,
"value": attribute.value,
"type": attribute.type,
"provider": attribute.provider,
"confidence": attribute.confidence,
"metadata": attribute.metadata
}
attributes_by_node[attribute.target_node].append(attr_dict)
# Add attributes to existing nodes OR create new nodes if they don't exist
for node_id, node_attributes_list in attributes_by_node.items():
if not self.graph.graph.has_node(node_id):
# If the node doesn't exist, create it with a default type
node_type = NodeType.IP if _is_valid_ip(node_id) else NodeType.DOMAIN
self.graph.add_node(node_id, node_type, attributes=node_attributes_list)
else:
# If the node already exists, just add the attributes
node_type_val = self.graph.graph.nodes[node_id].get('type', 'domain')
self.graph.add_node(node_id, NodeType(node_type_val), attributes=node_attributes_list)
return discovered_targets, False
def _create_large_entity_from_provider_result(self, source: str, provider_name: str,
@ -836,47 +954,108 @@ class Scanner:
return { 'status': 'error', 'message': 'Failed to get status' }
def _initialize_provider_states(self, target: str) -> None:
"""Initialize provider states for forensic tracking."""
if not self.graph.graph.has_node(target): return
"""
FIXED: Safer provider state initialization with error handling.
"""
try:
if not self.graph.graph.has_node(target):
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data: node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {}
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
except Exception as e:
self.logger.logger.warning(f"Error initializing provider states for {target}: {e}")
def _get_eligible_providers(self, target: str, is_ip: bool, dns_only: bool) -> List:
"""Get providers eligible for querying this target."""
"""
FIXED: Improved provider eligibility checking with better filtering.
"""
if dns_only:
return [p for p in self.providers if p.get_name() == 'dns']
eligible = []
target_key = 'ips' if is_ip else 'domains'
for provider in self.providers:
if provider.get_eligibility().get(target_key):
try:
# Check if provider supports this target type
if not provider.get_eligibility().get(target_key, False):
continue
# Check if provider is available/configured
if not provider.is_available():
continue
# Check if we already successfully queried this provider
if not self._already_queried_provider(target, provider.get_name()):
eligible.append(provider)
except Exception as e:
self.logger.logger.warning(f"Error checking provider eligibility {provider.get_name()}: {e}")
continue
return eligible
def _already_queried_provider(self, target: str, provider_name: str) -> bool:
"""Check if we already successfully queried a provider for a target."""
if not self.graph.graph.has_node(target): return False
"""
FIXED: More robust check for already queried providers with proper error handling.
"""
try:
if not self.graph.graph.has_node(target):
return False
node_data = self.graph.graph.nodes[target]
provider_states = node_data.get('metadata', {}).get('provider_states', {})
provider_state = provider_states.get(provider_name)
return provider_state is not None and provider_state.get('status') == 'success'
# Only consider it already queried if it was successful
return (provider_state is not None and
provider_state.get('status') == 'success' and
provider_state.get('results_count', 0) > 0)
except Exception as e:
self.logger.logger.warning(f"Error checking provider state for {target}:{provider_name}: {e}")
return False
def _update_provider_state(self, target: str, provider_name: str, status: str,
results_count: int, error: Optional[str], start_time: datetime) -> None:
"""Update provider state in node metadata for forensic tracking."""
if not self.graph.graph.has_node(target): return
"""
FIXED: More robust provider state updates with validation.
"""
try:
if not self.graph.graph.has_node(target):
self.logger.logger.warning(f"Cannot update provider state: node {target} not found")
return
node_data = self.graph.graph.nodes[target]
if 'metadata' not in node_data: node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']: node_data['metadata']['provider_states'] = {}
if 'metadata' not in node_data:
node_data['metadata'] = {}
if 'provider_states' not in node_data['metadata']:
node_data['metadata']['provider_states'] = {}
# Calculate duration safely
try:
duration_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
except Exception:
duration_ms = 0
node_data['metadata']['provider_states'][provider_name] = {
'status': status,
'timestamp': start_time.isoformat(),
'results_count': results_count,
'error': error,
'duration_ms': (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
'results_count': max(0, results_count), # Ensure non-negative
'error': str(error) if error else None,
'duration_ms': duration_ms
}
# Update last modified time for forensic integrity
self.last_modified = datetime.now(timezone.utc).isoformat()
except Exception as e:
self.logger.logger.error(f"Error updating provider state for {target}:{provider_name}: {e}")
def _log_target_processing_error(self, target: str, error: str) -> None:
self.logger.logger.error(f"Target processing failed for {target}: {error}")
@ -884,8 +1063,28 @@ class Scanner:
self.logger.logger.error(f"Provider {provider_name} failed for {target}: {error}")
def _calculate_progress(self) -> float:
if self.total_tasks_ever_enqueued == 0: return 0.0
return min(100.0, (self.indicators_completed / self.total_tasks_ever_enqueued) * 100)
try:
if self.total_tasks_ever_enqueued == 0:
return 0.0
# Add small buffer for tasks still in queue to avoid showing 100% too early
queue_size = max(0, self.task_queue.qsize())
with self.processing_lock:
active_tasks = len(self.currently_processing)
# Adjust total to account for remaining work
adjusted_total = max(self.total_tasks_ever_enqueued,
self.indicators_completed + queue_size + active_tasks)
if adjusted_total == 0:
return 100.0
progress = (self.indicators_completed / adjusted_total) * 100
return max(0.0, min(100.0, progress)) # Clamp between 0 and 100
except Exception as e:
self.logger.logger.warning(f"Error calculating progress: {e}")
return 0.0
def get_graph_data(self) -> Dict[str, Any]:
graph_data = self.graph.get_graph_data()

View File

@ -39,6 +39,7 @@ class ShodanProvider(BaseProvider):
return False
try:
response = self.session.get(f"{self.base_url}/api-info?key={self.api_key}", timeout=5)
self.logger.logger.debug("Shodan is reacheable")
return response.status_code == 200
except requests.exceptions.RequestException:
return False
@ -107,6 +108,12 @@ class ShodanProvider(BaseProvider):
except (json.JSONDecodeError, ValueError, KeyError):
return "stale"
def query_domain(self, domain: str) -> ProviderResult:
"""
Shodan does not support domain queries. This method returns an empty result.
"""
return ProviderResult()
def query_ip(self, ip: str) -> ProviderResult:
"""
Query Shodan for information about an IP address (IPv4 or IPv6), with caching of processed data.