new highest-priority-first scheduler
This commit is contained in:
@@ -50,7 +50,7 @@ class ForensicLogger:
|
||||
session_id: Unique identifier for this reconnaissance session
|
||||
"""
|
||||
self.session_id = session_id or self._generate_session_id()
|
||||
#self.lock = threading.Lock()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
# Initialize audit trail storage
|
||||
self.api_requests: List[APIRequest] = []
|
||||
@@ -86,6 +86,8 @@ class ForensicLogger:
|
||||
# Remove the unpickleable 'logger' attribute
|
||||
if 'logger' in state:
|
||||
del state['logger']
|
||||
if 'lock' in state:
|
||||
del state['lock']
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
@@ -101,6 +103,7 @@ class ForensicLogger:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(console_handler)
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def _generate_session_id(self) -> str:
|
||||
"""Generate unique session identifier."""
|
||||
|
||||
29
core/rate_limiter.py
Normal file
29
core/rate_limiter.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# dnsrecon-reduced/core/rate_limiter.py
|
||||
|
||||
import time
|
||||
import redis
|
||||
|
||||
class GlobalRateLimiter:
    """Sliding-window rate limiter backed by a Redis sorted set.

    Each tracked key keeps a sorted set of request timestamps under
    ``rate_limit:<key>``; entries older than the window are pruned on
    every check, so the count always reflects the last ``period`` seconds.
    """

    def __init__(self, redis_client):
        """
        Args:
            redis_client: Redis client exposing zremrangebyscore, zcard,
                zadd and expire (e.g. redis.StrictRedis).
        """
        self.redis = redis_client

    def is_rate_limited(self, key, limit, period):
        """Check (and record) one request against a sliding-window limit.

        Args:
            key: Logical identifier being limited (e.g. a provider name).
            limit: Maximum number of requests allowed per window.
            period: Window length in seconds.

        Returns:
            True if the caller is over the limit (request NOT recorded),
            False if the request was recorded and may proceed.
        """
        now = time.time()
        # Don't shadow the caller's key; build the namespaced redis key.
        redis_key = f"rate_limit:{key}"

        # Drop timestamps that have fallen out of the sliding window.
        self.redis.zremrangebyscore(redis_key, 0, now - period)

        # Already at the limit: reject without recording this attempt.
        if self.redis.zcard(redis_key) >= limit:
            return True

        # Record this request. Use a nanosecond-resolution member so two
        # requests arriving on the same time.time() tick don't collapse
        # into a single sorted-set entry (which would undercount and let
        # traffic through above the configured limit).
        self.redis.zadd(redis_key, {str(time.time_ns()): now})
        self.redis.expire(redis_key, period)

        return False
|
||||
212
core/scanner.py
212
core/scanner.py
@@ -5,16 +5,18 @@ import traceback
|
||||
import time
|
||||
import os
|
||||
import importlib
|
||||
import redis
|
||||
from typing import List, Set, Dict, Any, Tuple, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, CancelledError, Future
|
||||
from collections import defaultdict, deque
|
||||
from collections import defaultdict
|
||||
from queue import PriorityQueue
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from core.graph_manager import GraphManager, NodeType
|
||||
from core.logger import get_forensic_logger, new_session
|
||||
from utils.helpers import _is_valid_ip, _is_valid_domain
|
||||
from providers.base_provider import BaseProvider
|
||||
|
||||
from core.rate_limiter import GlobalRateLimiter
|
||||
|
||||
class ScanStatus:
|
||||
"""Enumeration of scan statuses."""
|
||||
@@ -50,7 +52,7 @@ class Scanner:
|
||||
self.stop_event = threading.Event()
|
||||
self.scan_thread = None
|
||||
self.session_id: Optional[str] = None # Will be set by session manager
|
||||
self.task_queue = deque([])
|
||||
self.task_queue = PriorityQueue()
|
||||
self.target_retries = defaultdict(int)
|
||||
self.scan_failed_due_to_retries = False
|
||||
|
||||
@@ -76,6 +78,9 @@ class Scanner:
|
||||
# Initialize logger
|
||||
print("Initializing forensic logger...")
|
||||
self.logger = get_forensic_logger()
|
||||
|
||||
# Initialize global rate limiter
|
||||
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
|
||||
|
||||
print("Scanner initialization complete")
|
||||
|
||||
@@ -129,7 +134,10 @@ class Scanner:
|
||||
'stop_event',
|
||||
'scan_thread',
|
||||
'executor',
|
||||
'processing_lock' # **NEW**: Exclude the processing lock
|
||||
'processing_lock', # **NEW**: Exclude the processing lock
|
||||
'task_queue', # PriorityQueue is not picklable
|
||||
'rate_limiter',
|
||||
'logger'
|
||||
]
|
||||
|
||||
for attr in unpicklable_attrs:
|
||||
@@ -154,6 +162,9 @@ class Scanner:
|
||||
self.scan_thread = None
|
||||
self.executor = None
|
||||
self.processing_lock = threading.Lock() # **NEW**: Recreate processing lock
|
||||
self.task_queue = PriorityQueue()
|
||||
self.rate_limiter = GlobalRateLimiter(redis.StrictRedis(db=0))
|
||||
self.logger = get_forensic_logger()
|
||||
|
||||
# **NEW**: Reset processing tracking
|
||||
if not hasattr(self, 'currently_processing'):
|
||||
@@ -221,7 +232,7 @@ class Scanner:
|
||||
# Clear all processing state
|
||||
with self.processing_lock:
|
||||
self.currently_processing.clear()
|
||||
self.task_queue.clear()
|
||||
self.task_queue = PriorityQueue()
|
||||
|
||||
# Shutdown executor aggressively
|
||||
if self.executor:
|
||||
@@ -251,7 +262,7 @@ class Scanner:
|
||||
with self.processing_lock:
|
||||
self.currently_processing.clear()
|
||||
|
||||
self.task_queue.clear()
|
||||
self.task_queue = PriorityQueue()
|
||||
self.target_retries.clear()
|
||||
self.scan_failed_due_to_retries = False
|
||||
|
||||
@@ -311,42 +322,60 @@ class Scanner:
|
||||
self._update_session_state()
|
||||
return False
|
||||
|
||||
def _get_priority(self, provider_name):
|
||||
rate_limit = self.config.get_rate_limit(provider_name)
|
||||
if rate_limit > 90:
|
||||
return 1 # Highest priority
|
||||
elif rate_limit > 50:
|
||||
return 2
|
||||
else:
|
||||
return 3 # Lowest priority
|
||||
|
||||
def _execute_scan(self, target: str, max_depth: int) -> None:
|
||||
"""Execute the reconnaissance scan with proper termination handling."""
|
||||
print(f"_execute_scan started for {target} with depth {max_depth}")
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||
processed_targets = set()
|
||||
|
||||
self.task_queue.append((target, 0, False))
|
||||
|
||||
processed_tasks = set()
|
||||
|
||||
# Initial task population for the main target
|
||||
is_ip = _is_valid_ip(target)
|
||||
initial_providers = self._get_eligible_providers(target, is_ip, False)
|
||||
for provider in initial_providers:
|
||||
provider_name = provider.get_name()
|
||||
self.task_queue.put((self._get_priority(provider_name), (provider_name, target, 0)))
|
||||
|
||||
try:
|
||||
self.status = ScanStatus.RUNNING
|
||||
self._update_session_state()
|
||||
|
||||
|
||||
enabled_providers = [provider.get_name() for provider in self.providers]
|
||||
self.logger.log_scan_start(target, max_depth, enabled_providers)
|
||||
|
||||
|
||||
# Determine initial node type
|
||||
node_type = NodeType.IP if _is_valid_ip(target) else NodeType.DOMAIN
|
||||
node_type = NodeType.IP if is_ip else NodeType.DOMAIN
|
||||
self.graph.add_node(target, node_type)
|
||||
|
||||
|
||||
self._initialize_provider_states(target)
|
||||
|
||||
# Better termination checking in main loop
|
||||
while self.task_queue and not self._is_stop_requested():
|
||||
while not self.task_queue.empty() and not self._is_stop_requested():
|
||||
try:
|
||||
target_item, depth, is_large_entity_member = self.task_queue.popleft()
|
||||
priority, (provider_name, target_item, depth) = self.task_queue.get()
|
||||
except IndexError:
|
||||
# Queue became empty during processing
|
||||
break
|
||||
|
||||
if target_item in processed_targets:
|
||||
task_tuple = (provider_name, target_item)
|
||||
if task_tuple in processed_tasks:
|
||||
continue
|
||||
|
||||
if depth > max_depth:
|
||||
continue
|
||||
|
||||
# Track this target as currently processing
|
||||
if self.rate_limiter.is_rate_limited(provider_name, self.config.get_rate_limit(provider_name), 60):
|
||||
self.task_queue.put((priority + 1, (provider_name, target_item, depth))) # Postpone
|
||||
continue
|
||||
|
||||
with self.processing_lock:
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested before processing {target_item}")
|
||||
@@ -357,53 +386,52 @@ class Scanner:
|
||||
self.current_depth = depth
|
||||
self.current_indicator = target_item
|
||||
self._update_session_state()
|
||||
|
||||
# More frequent stop checking during processing
|
||||
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested during processing setup for {target_item}")
|
||||
break
|
||||
|
||||
new_targets, large_entity_members, success = self._query_providers_for_target(target_item, depth, is_large_entity_member)
|
||||
|
||||
# Check stop signal after provider queries
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested after querying providers for {target_item}")
|
||||
break
|
||||
|
||||
if not success:
|
||||
self.target_retries[target_item] += 1
|
||||
if self.target_retries[target_item] <= self.config.max_retries_per_target:
|
||||
print(f"Re-queueing target {target_item} (attempt {self.target_retries[target_item]})")
|
||||
self.task_queue.append((target_item, depth, is_large_entity_member))
|
||||
self.tasks_re_enqueued += 1
|
||||
provider = next((p for p in self.providers if p.get_name() == provider_name), None)
|
||||
|
||||
if provider:
|
||||
new_targets, large_entity_members, success = self._query_single_provider_for_target(provider, target_item, depth)
|
||||
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested after querying providers for {target_item}")
|
||||
break
|
||||
|
||||
if not success:
|
||||
self.target_retries[task_tuple] += 1
|
||||
if self.target_retries[task_tuple] <= self.config.max_retries_per_target:
|
||||
print(f"Re-queueing task {task_tuple} (attempt {self.target_retries[task_tuple]})")
|
||||
self.task_queue.put((priority, (provider_name, target_item, depth)))
|
||||
self.tasks_re_enqueued += 1
|
||||
else:
|
||||
print(f"ERROR: Max retries exceeded for task {task_tuple}")
|
||||
self.scan_failed_due_to_retries = True
|
||||
self._log_target_processing_error(str(task_tuple), "Max retries exceeded")
|
||||
else:
|
||||
print(f"ERROR: Max retries exceeded for target {target_item}")
|
||||
self.scan_failed_due_to_retries = True
|
||||
self._log_target_processing_error(target_item, "Max retries exceeded")
|
||||
else:
|
||||
processed_targets.add(target_item)
|
||||
self.indicators_completed += 1
|
||||
|
||||
# Only add new targets if not stopped
|
||||
if not self._is_stop_requested():
|
||||
for new_target in new_targets:
|
||||
if new_target not in processed_targets:
|
||||
self.task_queue.append((new_target, depth + 1, False))
|
||||
|
||||
for member in large_entity_members:
|
||||
if member not in processed_targets:
|
||||
self.task_queue.append((member, depth, True))
|
||||
|
||||
processed_tasks.add(task_tuple)
|
||||
self.indicators_completed += 1
|
||||
|
||||
if not self._is_stop_requested():
|
||||
all_new_targets = new_targets.union(large_entity_members)
|
||||
for new_target in all_new_targets:
|
||||
is_ip_new = _is_valid_ip(new_target)
|
||||
eligible_providers_new = self._get_eligible_providers(new_target, is_ip_new, False)
|
||||
for p_new in eligible_providers_new:
|
||||
p_name_new = p_new.get_name()
|
||||
if (p_name_new, new_target) not in processed_tasks:
|
||||
new_depth = depth + 1 if new_target in new_targets else depth
|
||||
self.task_queue.put((self._get_priority(p_name_new), (p_name_new, new_target, new_depth)))
|
||||
finally:
|
||||
# Always remove from processing set
|
||||
with self.processing_lock:
|
||||
self.currently_processing.discard(target_item)
|
||||
|
||||
# Log termination reason
|
||||
if self._is_stop_requested():
|
||||
print("Scan terminated due to stop request")
|
||||
self.logger.logger.info("Scan terminated by user request")
|
||||
elif not self.task_queue:
|
||||
elif self.task_queue.empty():
|
||||
print("Scan completed - no more targets to process")
|
||||
self.logger.logger.info("Scan completed - all targets processed")
|
||||
|
||||
@@ -413,17 +441,16 @@ class Scanner:
|
||||
self.status = ScanStatus.FAILED
|
||||
self.logger.logger.error(f"Scan failed: {e}")
|
||||
finally:
|
||||
# Clear processing state on exit
|
||||
with self.processing_lock:
|
||||
self.currently_processing.clear()
|
||||
|
||||
|
||||
if self._is_stop_requested():
|
||||
self.status = ScanStatus.STOPPED
|
||||
elif self.scan_failed_due_to_retries:
|
||||
self.status = ScanStatus.FAILED
|
||||
else:
|
||||
self.status = ScanStatus.COMPLETED
|
||||
|
||||
|
||||
self._update_session_state()
|
||||
self.logger.log_scan_complete()
|
||||
if self.executor:
|
||||
@@ -433,60 +460,43 @@ class Scanner:
|
||||
print("Final scan statistics:")
|
||||
print(f" - Total nodes: {stats['basic_metrics']['total_nodes']}")
|
||||
print(f" - Total edges: {stats['basic_metrics']['total_edges']}")
|
||||
print(f" - Targets processed: {len(processed_targets)}")
|
||||
print(f" - Tasks processed: {len(processed_tasks)}")
|
||||
|
||||
def _query_providers_for_target(self, target: str, depth: int, dns_only: bool = False) -> Tuple[Set[str], Set[str], bool]:
|
||||
"""Query providers for a single target with enhanced stop checking."""
|
||||
# **NEW**: Early termination check
|
||||
def _query_single_provider_for_target(self, provider: BaseProvider, target: str, depth: int) -> Tuple[Set[str], Set[str], bool]:
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested before querying providers for {target}")
|
||||
print(f"Stop requested before querying {provider.get_name()} for {target}")
|
||||
return set(), set(), False
|
||||
|
||||
is_ip = _is_valid_ip(target)
|
||||
target_type = NodeType.IP if is_ip else NodeType.DOMAIN
|
||||
print(f"Querying providers for {target_type.value}: {target} at depth {depth}")
|
||||
print(f"Querying {provider.get_name()} for {target_type.value}: {target} at depth {depth}")
|
||||
|
||||
self.graph.add_node(target, target_type)
|
||||
self._initialize_provider_states(target)
|
||||
|
||||
|
||||
new_targets = set()
|
||||
large_entity_members = set()
|
||||
node_attributes = defaultdict(lambda: defaultdict(list))
|
||||
all_providers_successful = True
|
||||
provider_successful = True
|
||||
|
||||
eligible_providers = self._get_eligible_providers(target, is_ip, dns_only)
|
||||
|
||||
if not eligible_providers:
|
||||
self._log_no_eligible_providers(target, is_ip)
|
||||
return new_targets, large_entity_members, True
|
||||
|
||||
# **IMPROVED**: Check stop signal before each provider
|
||||
for i, provider in enumerate(eligible_providers):
|
||||
if self._is_stop_requested():
|
||||
print(f"Stop requested while querying provider {i+1}/{len(eligible_providers)} for {target}")
|
||||
all_providers_successful = False
|
||||
break
|
||||
|
||||
try:
|
||||
provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth)
|
||||
if provider_results is None:
|
||||
all_providers_successful = False
|
||||
elif not self._is_stop_requested():
|
||||
discovered, is_large_entity = self._process_provider_results_forensic(
|
||||
target, provider, provider_results, node_attributes, depth
|
||||
)
|
||||
if is_large_entity:
|
||||
large_entity_members.update(discovered)
|
||||
else:
|
||||
new_targets.update(discovered)
|
||||
try:
|
||||
provider_results = self._query_single_provider_forensic(provider, target, is_ip, depth)
|
||||
if provider_results is None:
|
||||
provider_successful = False
|
||||
elif not self._is_stop_requested():
|
||||
discovered, is_large_entity = self._process_provider_results_forensic(
|
||||
target, provider, provider_results, node_attributes, depth
|
||||
)
|
||||
if is_large_entity:
|
||||
large_entity_members.update(discovered)
|
||||
else:
|
||||
print(f"Stop requested after processing results from {provider.get_name()}")
|
||||
break
|
||||
except Exception as e:
|
||||
all_providers_successful = False
|
||||
self._log_provider_error(target, provider.get_name(), str(e))
|
||||
new_targets.update(discovered)
|
||||
else:
|
||||
print(f"Stop requested after processing results from {provider.get_name()}")
|
||||
except Exception as e:
|
||||
provider_successful = False
|
||||
self._log_provider_error(target, provider.get_name(), str(e))
|
||||
|
||||
# **NEW**: Only update node attributes if not stopped
|
||||
if not self._is_stop_requested():
|
||||
for node_id, attributes in node_attributes.items():
|
||||
if self.graph.graph.has_node(node_id):
|
||||
@@ -494,7 +504,7 @@ class Scanner:
|
||||
node_type_to_add = NodeType.IP if node_is_ip else NodeType.DOMAIN
|
||||
self.graph.add_node(node_id, node_type_to_add, attributes=attributes)
|
||||
|
||||
return new_targets, large_entity_members, all_providers_successful
|
||||
return new_targets, large_entity_members, provider_successful
|
||||
|
||||
def stop_scan(self) -> bool:
|
||||
"""Request immediate scan termination with proper cleanup."""
|
||||
@@ -513,8 +523,10 @@ class Scanner:
|
||||
print(f"Cleared {len(currently_processing_copy)} currently processing targets: {currently_processing_copy}")
|
||||
|
||||
# **IMPROVED**: Clear task queue and log what was discarded
|
||||
discarded_tasks = list(self.task_queue)
|
||||
self.task_queue.clear()
|
||||
discarded_tasks = []
|
||||
while not self.task_queue.empty():
|
||||
discarded_tasks.append(self.task_queue.get())
|
||||
self.task_queue = PriorityQueue()
|
||||
print(f"Discarded {len(discarded_tasks)} pending tasks")
|
||||
|
||||
# **IMPROVED**: Aggressively shut down executor
|
||||
@@ -572,7 +584,7 @@ class Scanner:
|
||||
'progress_percentage': self._calculate_progress(),
|
||||
'enabled_providers': [provider.get_name() for provider in self.providers],
|
||||
'graph_statistics': self.graph.get_statistics(),
|
||||
'task_queue_size': len(self.task_queue),
|
||||
'task_queue_size': self.task_queue.qsize(),
|
||||
'currently_processing_count': currently_processing_count, # **NEW**
|
||||
'currently_processing': currently_processing_list[:5] # **NEW**: Show first 5 for debugging
|
||||
}
|
||||
@@ -859,7 +871,7 @@ class Scanner:
|
||||
|
||||
def _calculate_progress(self) -> float:
|
||||
"""Calculate scan progress percentage based on task completion."""
|
||||
total_tasks = self.indicators_completed + len(self.task_queue)
|
||||
total_tasks = self.indicators_completed + self.task_queue.qsize()
|
||||
if total_tasks == 0:
|
||||
return 0.0
|
||||
return min(100.0, (self.indicators_completed / total_tasks) * 100)
|
||||
|
||||
Reference in New Issue
Block a user