New highest-priority-first scheduler

This commit is contained in:
overcuriousity
2025-09-15 22:21:17 +02:00
parent 71b2855d01
commit 4a5ecf7a37
46 changed files with 194 additions and 166 deletions

View File

@@ -50,7 +50,7 @@ class ForensicLogger:
session_id: Unique identifier for this reconnaissance session
"""
self.session_id = session_id or self._generate_session_id()
#self.lock = threading.Lock()
self.lock = threading.Lock()
# Initialize audit trail storage
self.api_requests: List[APIRequest] = []
@@ -86,6 +86,8 @@ class ForensicLogger:
# Remove the unpickleable 'logger' attribute
if 'logger' in state:
del state['logger']
if 'lock' in state:
del state['lock']
return state
def __setstate__(self, state):
@@ -101,6 +103,7 @@ class ForensicLogger:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
self.lock = threading.Lock()
def _generate_session_id(self) -> str:
"""Generate unique session identifier."""

29
core/rate_limiter.py Normal file
View File

@@ -0,0 +1,29 @@
# dnsrecon-reduced/core/rate_limiter.py
import time
import uuid

import redis
class GlobalRateLimiter:
    """Sliding-window rate limiter backed by a Redis sorted set.

    Each tracked key keeps one sorted-set member per recorded request,
    scored by its UNIX timestamp, so the window slides continuously
    instead of resetting at fixed intervals.
    """

    def __init__(self, redis_client):
        """
        Args:
            redis_client: Any object exposing the redis-py sorted-set API
                used below (zremrangebyscore, zcard, zadd, expire).
        """
        self.redis = redis_client

    def is_rate_limited(self, key, limit, period):
        """Check — and, if allowed, record — a request for `key`.

        Args:
            key: Logical identifier (e.g. a provider name).
            limit: Maximum number of requests allowed within `period`.
            period: Window length in seconds.

        Returns:
            True if the request must be rejected (window already full),
            False if it was recorded and may proceed.
        """
        now = time.time()
        # Keep the caller's `key` intact; build the Redis key separately
        # (the original shadowed the parameter, which invited bugs).
        redis_key = f"rate_limit:{key}"
        # Drop timestamps that have aged out of the sliding window.
        self.redis.zremrangebyscore(redis_key, 0, now - period)
        if self.redis.zcard(redis_key) >= limit:
            return True
        # Use a unique member per request: with the bare `{now: now}`
        # mapping, two requests sharing the same float timestamp would
        # overwrite each other and silently under-count.
        member = f"{now}:{uuid.uuid4().hex}"
        self.redis.zadd(redis_key, {member: now})
        self.redis.expire(redis_key, period)
        # NOTE(review): the check-then-add sequence is not atomic across
        # processes, so under concurrency the limit can be slightly
        # exceeded; a Lua script or MULTI/EXEC would close that gap.
        return False

View File

@@ -5,16 +5,18 @@ import traceback
import time
import os
import importlib
import redis
from typing import List, Set, Dict, Any, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future
from collections import defaultdict, deque
from collections import defaultdict
from queue import PriorityQueue
from datetime import datetime, timezone
from core.graph_manager import GraphManager, NodeType
from core.logger import get_forensic_logger, new_session
from utils.helpers import _is_valid_ip, _is_valid_domain
from providers.base_provider import BaseProvider
from core.rate_limiter import GlobalRateLimiter
class ScanStatus:
"""Enumeration of scan statuses."""
@@ -50,7 +52,7 @@ class Scanner:
self.stop_event = threading.Event()
self.scan_thread = None
self.session_id: Optional[str] = None # Will be set by session manager
self.task_queue = deque([])
self.task_queue = PriorityQueue()
self.target_retries = defaultdict(int)
self.scan_failed_due_to_retries = False
@@ -76,6 +78,9 @@ class Scanner:
# Initialize logger
print("Initializing forensic logger...")
self.logger = get_forensic_logger()
# Initialize global rate limiter
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
print("Scanner initialization complete")
@@ -129,7 +134,10 @@ class Scanner:
'stop_event',
'scan_thread',
'executor',
'processing_lock' # **NEW**: Exclude the processing lock
'processing_lock', # **NEW**: Exclude the processing lock
'task_queue', # PriorityQueue is not picklable
'rate_limiter',
'logger'
]
for attr in unpicklable_attrs:
@@ -154,6 +162,9 @@ class Scanner:
self.scan_thread = None
self.executor = None
self.processing_lock = threading.Lock() # **NEW**: Recreate processing lock
self.task_queue = PriorityQueue()
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
self.logger = get_forensic_logger()
# **NEW**: Reset processing tracking
if not hasattr(self, 'currently_processing'):
@@ -221,7 +232,7 @@ class Scanner:
# Clear all processing state
with self.processing_lock:
self.currently_processing.clear()
self.task_queue.clear()
self.task_queue = PriorityQueue()
# Shutdown executor aggressively
if self.executor:
@@ -251,7 +262,7 @@ class Scanner:
with self.processing_lock:
self.currently_processing.clear()
self.task_queue.clear()
self.task_queue = PriorityQueue()
self.target_retries.clear()
self.scan_failed_due_to_retries = False
@@ -311,42 +322,60 @@ class Scanner:
self._update_session_state()
return False
def _get_priority(self, provider_name):
rate_limit = self.config.get_rate_limit(provider_name)
if rate_limit > 90:
return 1 # Highest priority
elif rate_limit > 50:
return 2
else:
return 3 # Lowest priority
def _execute_scan(self, target: str, max_depth: int) -> None:
"""Execute the reconnaissance scan with proper termination handling."""
print(f"_execute_scan started for {target} with depth {max_depth}")
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_targets = set()
self.task_queue.append((target, 0, False))
processed_tasks = set()
# Initial task population for the main target
is_ip = _is_valid_ip(target)
initial_providers = self._get_eligible_providers(target, is_ip, False)
for provider in initial_providers:
provider_name = provider.get_name()
self.task_queue.put((self._get_priority(provider_name), (provider_name, target, 0)))
try:
self.status = ScanStatus.RUNNING
self._update_session_state()
enabled_providers = [provider.get_name() for provider in self.providers]
self.logger.log_scan_start(target, max_depth, enabled_providers)
# Determine initial node type
node_type = NodeType.IP if _is_valid_ip(target) else NodeType.DOMAIN
node_type = NodeType.IP if is_ip else NodeType.DOMAIN
self.graph.add_node(target, node_type)
self._initialize_provider_states(target)
# Better termination checking in main loop
while self.task_queue and not self._is_stop_requested():
while not self.task_queue.empty() and not self._is_stop_requested():
try:
target_item, depth, is_large_entity_member = self.task_queue.popleft()
priority, (provider_name, target_item, depth) = self.task_queue.get()
except IndexError:
# Queue became empty during processing
break
if target_item in processed_targets:
task_tuple = (provider_name, target_item)
if task_tuple in processed_tasks:
continue
if depth > max_depth:
continue
# Track this target as currently processing
if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
self.task_queue.put((priority + 1, (provider_name, target_item, depth))) # Postpone
continue
with self.processing_lock:
if self._is_stop_requested():
print(f"Stop requested before processing {target_item}")
@@ -357,53 +386,52 @@ class Scanner:
self.current_depth = depth
self.current_indicator = target_item
self._update_session_state()
# More frequent stop checking during processing
if self._is_stop_requested():
print(f"Stop requested during processing setup for {target_item}")
break
new_targets, large_entity_members, success = self._query_providers_for_target(target_item, depth, is_large_entity_member)
# Check stop signal after provider queries
if self._is_stop_requested():
print(f"Stop requested after querying providers for {target_item}")
break
if not success:
self.target_retries[target_item] += 1
if self.target_retries[target_item] <= self.config.max_retries_per_target:
print(f"Re-queueing target {target_item} (attempt {self.target_retries[target_item]})")
self.task_queue.append((target_item, depth, is_large_entity_member))
self.tasks_re_enqueued += 1
provider = next((p for p in self.providers if p.get_name() == provider_name), None)
if provider:
new_targets, large_entity_members, success = self._query_single_provider_for_target(provider, target_item, depth)
if self._is_stop_requested():
print(f"Stop requested after querying providers for {target_item}")
break
if not success:
self.target_retries[task_tuple] += 1
if self.target_retries[task_tuple] <= self.config.max_retries_per_target:
print(f"Re-queueing task {task_tuple} (attempt {self.target_retries[task_tuple]})")
self.task_queue.put((priority, (provider_name, target_item, depth)))
self.tasks_re_enqueued += 1
else:
print(f"ERROR: Max retries exceeded for task {task_tuple}")
self.scan_failed_due_to_retries = True
self._log_target_processing_error(str(task_tuple), "Max retries exceeded")
else:
print(f"ERROR: Max retries exceeded for target {target_item}")
self.scan_failed_due_to_retries = True
self._log_target_processing_error(target_item, "Max retries exceeded")
else:
processed_targets.add(target_item)
self.indicators_completed += 1
# Only add new targets if not stopped
if not self._is_stop_requested():
for new_target in new_targets:
if new_target not in processed_targets:
self.task_queue.append((new_target, depth + 1, False))
for member in large_entity_members:
if member not in processed_targets:
self.task_queue.append((member, depth, True))
processed_tasks.add(task_tuple)
self.indicators_completed += 1
if not self._is_stop_requested():
all_new_targets = new_targets.union(large_entity_members)
for new_target in all_new_targets:
is_ip_new = _is_valid_ip(new_target)
eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
for p_new in eligible_providers_new:
p_name_new = p_new.get_name()
if (p_name_new, new_target) not in processed_tasks:
new_depth = depth + 1 if new_target in new_targets else depth
self.task_queue.put((self._get_priority(p_name_new), (p_name_new, new_target, new_depth)))
finally:
# Always remove from processing set
with self.processing_lock:
self.currently_processing.discard(target_item)
# Log termination reason
if self._is_stop_requested():
print("Scan terminated due to stop request")
self.logger.logger.info("Scan terminated by user request")
elif not self.task_queue:
elif self.task_queue.empty():
print("Scan completed - no more targets to process")
self.logger.logger.info("Scan completed - all targets processed")
@@ -413,17 +441,16 @@ class Scanner:
self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}")
finally:
# Clear processing state on exit
with self.processing_lock:
self.currently_processing.clear()
if self._is_stop_requested():
self.status = ScanStatus.STOPPED
elif self.scan_failed_due_to_retries:
self.status = ScanStatus.FAILED
else:
self.status = ScanStatus.COMPLETED
self._update_session_state()
self.logger.log_scan_complete()
if self.executor:
@@ -433,60 +460,43 @@ class Scanner:
print("Final scan statistics:")
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
print(f" - Targets processed: {len(processed_targets)}")
print(f" - Tasks processed: {len(processed_tasks)}")
def _query_providers_for_target(self, target: str, depth: int, dns_only: bool = False) -> Tuple[Set[str], Set[str], bool]:
"""Query providers for a single target with enhanced stop checking."""
# **NEW**: Early termination check
def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
if self._is_stop_requested():
print(f"Stop requested before querying providers for {target}")
print(f"Stop requested before querying {provider.get_name()} for {target}")
return set(), set(), False
is_ip = _is_valid_ip(target)
target_type = NodeType.IP if is_ip else NodeType.DOMAIN
print(f"Querying providers for {target_type.value}: {target} at depth {depth}")
print(f"Querying {provider.get_name()} for {target_type.value}: {target} at depth {depth}")
self.graph.add_node(target, target_type)
self._initialize_provider_states(target)
new_targets = set()
large_entity_members = set()
node_attributes = defaultdict(lambda: defaultdict(list))
all_providers_successful = True
provider_successful = True
eligible_providers = self._get_eligible_providers(target, is_ip, dns_only)
if not eligible_providers:
self._log_no_eligible_providers(target, is_ip)
return new_targets, large_entity_members, True
# **IMPROVED**: Check stop signal before each provider
for i, provider in enumerate(eligible_providers):
if self._is_stop_requested():
print(f"Stop requested while querying provider {i+1}/{len(eligible_providers)} for {target}")
all_providers_successful = False
break
try:
provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth)
if provider_results is None:
all_providers_successful = False
elif not self._is_stop_requested():
discovered, is_large_entity = self._process_provider_results_forensic(
target, provider, provider_results, node_attributes, depth
)
if is_large_entity:
large_entity_members.update(discovered)
else:
new_targets.update(discovered)
try:
provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth)
if provider_results is None:
provider_successful = False
elif not self._is_stop_requested():
discovered, is_large_entity = self._process_provider_results_forensic(
target, provider, provider_results, node_attributes, depth
)
if is_large_entity:
large_entity_members.update(discovered)
else:
print(f"Stop requested after processing results from {provider.get_name()}")
break
except Exception as e:
all_providers_successful = False
self._log_provider_error(target, provider.get_name(), str(e))
new_targets.update(discovered)
else:
print(f"Stop requested after processing results from {provider.get_name()}")
except Exception as e:
provider_successful = False
self._log_provider_error(target, provider.get_name(), str(e))
# **NEW**: Only update node attributes if not stopped
if not self._is_stop_requested():
for node_id, attributes in node_attributes.items():
if self.graph.graph.has_node(node_id):
@@ -494,7 +504,7 @@ class Scanner:
node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN
self.graph.add_node(node_id, node_type_to_add, attributes=attributes)
return new_targets, large_entity_members, all_providers_successful
return new_targets, large_entity_members, provider_successful
def stop_scan(self) -> bool:
"""Request immediate scan termination with proper cleanup."""
@@ -513,8 +523,10 @@ class Scanner:
print(f"Cleared {len(currently_processing_copy)} currently processing targets: {currently_processing_copy}")
# **IMPROVED**: Clear task queue and log what was discarded
discarded_tasks = list(self.task_queue)
self.task_queue.clear()
discarded_tasks = []
while not self.task_queue.empty():
discarded_tasks.append(self.task_queue.get())
self.task_queue = PriorityQueue()
print(f"Discarded {len(discarded_tasks)} pending tasks")
# **IMPROVED**: Aggressively shut down executor
@@ -572,7 +584,7 @@ class Scanner:
'progress_percentage': self._calculate_progress(),
'enabled_providers': [provider.get_name() for provider in self.providers],
'graph_statistics': self.graph.get_statistics(),
'task_queue_size': len(self.task_queue),
'task_queue_size': self.task_queue.qsize(),
'currently_processing_count': currently_processing_count, # **NEW**
'currently_processing': currently_processing_list[:5] # **NEW**: Show first 5 for debugging
}
@@ -859,7 +871,7 @@ class Scanner:
def _calculate_progress(self) -> float:
"""Calculate scan progress percentage based on task completion."""
total_tasks = self.indicators_completed + len(self.task_queue)
total_tasks = self.indicators_completed + self.task_queue.qsize()
if total_tasks == 0:
return 0.0
return min(100.0, (self.indicators_completed / total_tasks) * 100)