# dnsrecon/core/task_manager.py
import threading
import time
import uuid
from enum import Enum
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Set
from datetime import datetime, timezone, timedelta
from collections import deque

from utils.helpers import _is_valid_ip, _is_valid_domain


class TaskStatus(Enum):
    """Enumeration of task execution statuses."""
    PENDING = "pending"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED_RETRYING = "failed_retrying"
    FAILED_PERMANENT = "failed_permanent"
    CANCELLED = "cancelled"


class TaskType(Enum):
    """Enumeration of task types for provider queries."""
    DOMAIN_QUERY = "domain_query"
    IP_QUERY = "ip_query"
    GRAPH_UPDATE = "graph_update"


@dataclass
class TaskResult:
    """Result of a task execution."""
    success: bool
    data: Optional[Any] = None
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ReconTask:
    """Represents a single reconnaissance task with retry logic."""
    task_id: str
    task_type: TaskType
    target: str
    provider_name: str
    depth: int
    status: TaskStatus = TaskStatus.PENDING
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    # Retry configuration
    max_retries: int = 3
    current_retry: int = 0
    base_backoff_seconds: float = 1.0
    max_backoff_seconds: float = 60.0

    # Execution tracking
    last_attempt_at: Optional[datetime] = None
    next_retry_at: Optional[datetime] = None
    execution_history: List[Dict[str, Any]] = field(default_factory=list)

    # Results
    result: Optional[TaskResult] = None

    def __post_init__(self):
        """Initialize additional fields after creation."""
        if not self.task_id:
            self.task_id = str(uuid.uuid4())[:8]

    def calculate_next_retry_time(self) -> Optional[datetime]:
        """Calculate next retry time with exponential backoff and jitter."""
        if self.current_retry >= self.max_retries:
            return None

        # Exponential backoff, capped at max_backoff_seconds
        backoff_time = min(
            self.max_backoff_seconds,
            self.base_backoff_seconds * (2 ** self.current_retry)
        )

        # Add deterministic jitter derived from the task ID (roughly +/-12.5% of the backoff)
        jitter = backoff_time * 0.25 * (0.5 - hash(self.task_id) % 1000 / 1000.0)
        final_backoff = max(self.base_backoff_seconds, backoff_time + jitter)

        return datetime.now(timezone.utc) + timedelta(seconds=final_backoff)

    def should_retry(self) -> bool:
        """Determine if task should be retried based on status and retry count."""
        if self.status != TaskStatus.FAILED_RETRYING:
            return False
        if self.current_retry >= self.max_retries:
            return False
        if self.next_retry_at and datetime.now(timezone.utc) < self.next_retry_at:
            return False
        return True

    def mark_failed(self, error: str, metadata: Optional[Dict[str, Any]] = None):
        """Mark task as failed and prepare for retry or permanent failure."""
        self.current_retry += 1
        self.last_attempt_at = datetime.now(timezone.utc)

        # Record execution history
        execution_record = {
            'attempt': self.current_retry,
            'timestamp': self.last_attempt_at.isoformat(),
            'error': error,
            'metadata': metadata or {}
        }
        self.execution_history.append(execution_record)

        if self.current_retry >= self.max_retries:
            self.status = TaskStatus.FAILED_PERMANENT
            self.result = TaskResult(
                success=False,
                error=f"Permanent failure after {self.max_retries} attempts: {error}"
            )
        else:
            self.status = TaskStatus.FAILED_RETRYING
            self.next_retry_at = self.calculate_next_retry_time()

    def mark_succeeded(self, data: Any = None, metadata: Optional[Dict[str, Any]] = None):
        """Mark task as successfully completed."""
        self.status = TaskStatus.SUCCEEDED
        self.last_attempt_at = datetime.now(timezone.utc)
        self.result = TaskResult(success=True, data=data, metadata=metadata or {})
        # Record successful execution
        execution_record = {
            'attempt': self.current_retry + 1,
            'timestamp': self.last_attempt_at.isoformat(),
            'success': True,
            'metadata': metadata or {}
        }
        self.execution_history.append(execution_record)

    def get_summary(self) -> Dict[str, Any]:
        """Get task summary for progress reporting."""
        return {
            'task_id': self.task_id,
            'task_type': self.task_type.value,
            'target': self.target,
            'provider': self.provider_name,
            'status': self.status.value,
            'current_retry': self.current_retry,
            'max_retries': self.max_retries,
            'created_at': self.created_at.isoformat(),
            'last_attempt_at': self.last_attempt_at.isoformat() if self.last_attempt_at else None,
            'next_retry_at': self.next_retry_at.isoformat() if self.next_retry_at else None,
            'total_attempts': len(self.execution_history),
            'has_result': self.result is not None
        }


class TaskQueue:
    """Thread-safe task queue with retry logic and priority handling."""

    def __init__(self, max_concurrent_tasks: int = 5):
        """Initialize task queue."""
        self.max_concurrent_tasks = max_concurrent_tasks
        self.tasks: Dict[str, ReconTask] = {}
        self.pending_queue = deque()
        self.retry_queue = deque()
        self.running_tasks: Set[str] = set()
        self._lock = threading.Lock()
        self._stop_event = threading.Event()

    def __getstate__(self):
        """Prepare TaskQueue for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpicklable '_lock' and '_stop_event' attributes
        state.pop('_lock', None)
        state.pop('_stop_event', None)
        return state

    def __setstate__(self, state):
        """Restore TaskQueue after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_lock' and '_stop_event' attributes
        self._lock = threading.Lock()
        self._stop_event = threading.Event()

    def add_task(self, task: ReconTask) -> str:
        """Add task to queue."""
        with self._lock:
            self.tasks[task.task_id] = task
            self.pending_queue.append(task.task_id)
            print(f"Added task {task.task_id}: {task.provider_name} query for {task.target}")
            return task.task_id

    def get_next_ready_task(self) -> Optional[ReconTask]:
        """Get next task ready for execution."""
        with self._lock:
            # Check if we have room for more concurrent tasks
            if len(self.running_tasks) >= self.max_concurrent_tasks:
                return None

            # First priority: retry queue (tasks ready for retry).
            # Examine each queued retry once per call; tasks whose backoff has not
            # elapsed yet are re-queued instead of being dropped.
            for _ in range(len(self.retry_queue)):
                task_id = self.retry_queue.popleft()
                if task_id not in self.tasks:
                    continue
                task = self.tasks[task_id]
                if task.should_retry():
                    task.status = TaskStatus.RUNNING
                    self.running_tasks.add(task_id)
                    print(f"Retrying task {task_id} (attempt {task.current_retry + 1})")
                    return task
                if task.status == TaskStatus.FAILED_RETRYING:
                    # Backoff window still open; keep the task queued for a later pass
                    self.retry_queue.append(task_id)

            # Second priority: pending queue (new tasks)
            while self.pending_queue:
                task_id = self.pending_queue.popleft()
                if task_id in self.tasks:
                    task = self.tasks[task_id]
                    if task.status == TaskStatus.PENDING:
                        task.status = TaskStatus.RUNNING
                        self.running_tasks.add(task_id)
                        print(f"Starting task {task_id}")
                        return task

            return None

    def complete_task(self, task_id: str, success: bool, data: Any = None,
                      error: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None):
        """Mark task as completed (success or failure)."""
        with self._lock:
            if task_id not in self.tasks:
                return

            task = self.tasks[task_id]
            self.running_tasks.discard(task_id)

            if success:
                task.mark_succeeded(data=data, metadata=metadata)
                print(f"Task {task_id} succeeded")
            else:
                task.mark_failed(error or "Unknown error", metadata=metadata)
                if task.status == TaskStatus.FAILED_RETRYING:
                    self.retry_queue.append(task_id)
                    print(f"Task {task_id} failed, scheduled for retry at {task.next_retry_at}")
{task.next_retry_at}") else: print(f"Task {task_id} permanently failed after {task.current_retry} attempts") def cancel_all_tasks(self): """Cancel all pending and running tasks.""" with self._lock: self._stop_event.set() for task in self.tasks.values(): if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.FAILED_RETRYING]: task.status = TaskStatus.CANCELLED self.pending_queue.clear() self.retry_queue.clear() self.running_tasks.clear() print("All tasks cancelled") def is_complete(self) -> bool: """Check if all tasks are complete (succeeded, permanently failed, or cancelled).""" with self._lock: for task in self.tasks.values(): if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.FAILED_RETRYING]: return False return True def get_statistics(self) -> Dict[str, Any]: """Get queue statistics.""" with self._lock: stats = { 'total_tasks': len(self.tasks), 'pending': len(self.pending_queue), 'running': len(self.running_tasks), 'retry_queue': len(self.retry_queue), 'succeeded': 0, 'failed_permanent': 0, 'cancelled': 0, 'failed_retrying': 0 } for task in self.tasks.values(): if task.status == TaskStatus.SUCCEEDED: stats['succeeded'] += 1 elif task.status == TaskStatus.FAILED_PERMANENT: stats['failed_permanent'] += 1 elif task.status == TaskStatus.CANCELLED: stats['cancelled'] += 1 elif task.status == TaskStatus.FAILED_RETRYING: stats['failed_retrying'] += 1 stats['completion_rate'] = (stats['succeeded'] / stats['total_tasks'] * 100) if stats['total_tasks'] > 0 else 0 stats['is_complete'] = self.is_complete() return stats def get_task_summaries(self) -> List[Dict[str, Any]]: """Get summaries of all tasks for detailed progress reporting.""" with self._lock: return [task.get_summary() for task in self.tasks.values()] def get_failed_tasks(self) -> List[ReconTask]: """Get all permanently failed tasks for analysis.""" with self._lock: return [task for task in self.tasks.values() if task.status == TaskStatus.FAILED_PERMANENT] class TaskExecutor: """Executes reconnaissance tasks using providers.""" def __init__(self, providers: List, graph_manager, logger): """Initialize task executor.""" self.providers = {provider.get_name(): provider for provider in providers} self.graph = graph_manager self.logger = logger def execute_task(self, task: ReconTask) -> TaskResult: """ Execute a single reconnaissance task. 

        Args:
            task: Task to execute

        Returns:
            TaskResult with success/failure information
        """
        try:
            print(f"Executing task {task.task_id}: {task.provider_name} query for {task.target}")

            provider = self.providers.get(task.provider_name)
            if not provider:
                return TaskResult(
                    success=False,
                    error=f"Provider {task.provider_name} not available"
                )

            if not provider.is_available():
                return TaskResult(
                    success=False,
                    error=f"Provider {task.provider_name} is not available (missing API key or configuration)"
                )

            # Execute provider query based on task type
            if task.task_type == TaskType.DOMAIN_QUERY:
                if not _is_valid_domain(task.target):
                    return TaskResult(success=False, error=f"Invalid domain: {task.target}")
                relationships = provider.query_domain(task.target)
            elif task.task_type == TaskType.IP_QUERY:
                if not _is_valid_ip(task.target):
                    return TaskResult(success=False, error=f"Invalid IP: {task.target}")
                relationships = provider.query_ip(task.target)
            else:
                return TaskResult(success=False, error=f"Unsupported task type: {task.task_type}")

            # Process results and update graph
            from core.graph_manager import NodeType  # deferred import of graph node types

            new_targets = set()
            relationships_added = 0

            for source, target, rel_type, confidence, raw_data in relationships:
                # Add nodes to graph
                if _is_valid_ip(target):
                    self.graph.add_node(target, NodeType.IP)
                    new_targets.add(target)
                elif target.startswith('AS') and target[2:].isdigit():
                    self.graph.add_node(target, NodeType.ASN)
                elif _is_valid_domain(target):
                    self.graph.add_node(target, NodeType.DOMAIN)
                    new_targets.add(target)

                # Add edge to graph
                if self.graph.add_edge(source, target, rel_type, confidence, task.provider_name, raw_data):
                    relationships_added += 1

            # Log forensic information
            self.logger.logger.info(
                f"Task {task.task_id} completed: {len(relationships)} relationships found, "
                f"{relationships_added} added to graph, {len(new_targets)} new targets"
            )

            return TaskResult(
                success=True,
                data={
                    'relationships': relationships,
                    'new_targets': list(new_targets),
                    'relationships_added': relationships_added
                },
                metadata={
                    'provider': task.provider_name,
                    'target': task.target,
                    'depth': task.depth,
                    'execution_time': datetime.now(timezone.utc).isoformat()
                }
            )

        except Exception as e:
            error_msg = f"Task execution failed: {str(e)}"
            print(f"ERROR: {error_msg} for task {task.task_id}")
            self.logger.logger.error(error_msg)

            return TaskResult(
                success=False,
                error=error_msg,
                metadata={
                    'provider': task.provider_name,
                    'target': task.target,
                    'exception_type': type(e).__name__
                }
            )


class TaskManager:
    """High-level task management for reconnaissance scans."""

    def __init__(self, providers: List, graph_manager, logger, max_concurrent_tasks: int = 5):
        """Initialize task manager."""
        self.task_queue = TaskQueue(max_concurrent_tasks)
        self.task_executor = TaskExecutor(providers, graph_manager, logger)
        self.logger = logger

        # Execution control
        self._stop_event = threading.Event()
        self._execution_threads: List[threading.Thread] = []
        self._is_running = False

    def create_provider_tasks(self, target: str, depth: int, providers: List) -> List[str]:
        """
        Create tasks for querying all eligible providers for a target.

        Args:
            target: Domain or IP to query
            depth: Current recursion depth
            providers: List of available providers

        Returns:
            List of created task IDs
        """
        task_ids = []
        is_ip = _is_valid_ip(target)
        target_key = 'ips' if is_ip else 'domains'
        task_type = TaskType.IP_QUERY if is_ip else TaskType.DOMAIN_QUERY

        for provider in providers:
            if provider.get_eligibility().get(target_key) and provider.is_available():
                task = ReconTask(
                    task_id=str(uuid.uuid4())[:8],
                    task_type=task_type,
                    target=target,
                    provider_name=provider.get_name(),
                    depth=depth,
                    max_retries=3  # Configure retries per task type/provider
                )
                task_id = self.task_queue.add_task(task)
                task_ids.append(task_id)

        return task_ids

    def start_execution(self, max_workers: int = 3):
        """Start task execution with specified number of worker threads."""
        if self._is_running:
            print("Task execution already running")
            return

        self._is_running = True
        self._stop_event.clear()

        print(f"Starting task execution with {max_workers} workers")

        for i in range(max_workers):
            worker_thread = threading.Thread(
                target=self._worker_loop,
                name=f"TaskWorker-{i+1}",
                daemon=True
            )
            worker_thread.start()
            self._execution_threads.append(worker_thread)

    def stop_execution(self):
        """Stop task execution and cancel all tasks."""
        print("Stopping task execution")

        self._stop_event.set()
        self.task_queue.cancel_all_tasks()
        self._is_running = False

        # Wait for worker threads to finish
        for thread in self._execution_threads:
            thread.join(timeout=5.0)

        self._execution_threads.clear()
        print("Task execution stopped")

    def _worker_loop(self):
        """Worker thread loop for executing tasks."""
        thread_name = threading.current_thread().name
        print(f"{thread_name} started")

        while not self._stop_event.is_set():
            try:
                # Get next task to execute
                task = self.task_queue.get_next_ready_task()

                if task is None:
                    # No tasks ready, check if we should exit
                    if self.task_queue.is_complete() or self._stop_event.is_set():
                        break
                    time.sleep(0.1)  # Brief sleep before checking again
                    continue

                # Execute the task
                result = self.task_executor.execute_task(task)

                # Complete the task in queue
                self.task_queue.complete_task(
                    task.task_id,
                    success=result.success,
                    data=result.data,
                    error=result.error,
                    metadata=result.metadata
                )

            except Exception as e:
                print(f"ERROR: Worker {thread_name} encountered error: {e}")
                # Continue running even if individual task fails
                continue

        print(f"{thread_name} finished")

    def wait_for_completion(self, timeout_seconds: int = 300) -> bool:
        """
        Wait for all tasks to complete.

        Args:
            timeout_seconds: Maximum time to wait

        Returns:
            True if all tasks completed, False if timeout
        """
        start_time = time.time()

        while time.time() - start_time < timeout_seconds:
            if self.task_queue.is_complete():
                return True
            if self._stop_event.is_set():
                return False
            time.sleep(1.0)  # Check every second

        print(f"Timeout waiting for task completion after {timeout_seconds} seconds")
        return False

    def get_progress_report(self) -> Dict[str, Any]:
        """Get detailed progress report for UI updates."""
        stats = self.task_queue.get_statistics()
        failed_tasks = self.task_queue.get_failed_tasks()

        return {
            'statistics': stats,
            'failed_tasks': [task.get_summary() for task in failed_tasks],
            'is_running': self._is_running,
            'worker_count': len(self._execution_threads),
            # Limit detail for performance
            'detailed_tasks': self.task_queue.get_task_summaries() if stats['total_tasks'] < 50 else []
        }
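

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The stub provider, graph, and logger below
# are hypothetical stand-ins, not the real dnsrecon providers, GraphManager, or
# forensic logger; they implement just enough of the interface used above to
# show the TaskManager lifecycle: create_provider_tasks -> start_execution ->
# wait_for_completion -> get_progress_report.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging

    class _StubProvider:
        """Minimal provider stand-in implementing the methods TaskExecutor calls."""

        def get_name(self) -> str:
            return "stub_dns"

        def is_available(self) -> bool:
            return True

        def get_eligibility(self) -> Dict[str, bool]:
            return {'domains': True, 'ips': False}

        def query_domain(self, domain: str):
            # Returns (source, target, relationship_type, confidence, raw_data) tuples
            return [(domain, "192.0.2.1", "a_record", 0.9, {})]

        def query_ip(self, ip: str):
            return []

    class _StubGraph:
        """Accepts node/edge updates without building a real graph."""

        def add_node(self, node_id, node_type) -> bool:
            return True

        def add_edge(self, source, target, rel_type, confidence, provider, raw_data) -> bool:
            return True

    class _StubLogger:
        """Exposes a .logger attribute, mirroring how the executor logs above."""

        def __init__(self):
            self.logger = logging.getLogger("dnsrecon.demo")

    providers = [_StubProvider()]
    manager = TaskManager(providers, _StubGraph(), _StubLogger(), max_concurrent_tasks=2)

    manager.create_provider_tasks("example.com", depth=0, providers=providers)
    manager.start_execution(max_workers=2)
    completed = manager.wait_for_completion(timeout_seconds=30)
    manager.stop_execution()

    print(manager.get_progress_report()['statistics'], "completed:", completed)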