# dnsrecon/core/task_manager.py

import threading
import time
import uuid
from enum import Enum
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Set
from datetime import datetime, timezone, timedelta
from collections import deque

from utils.helpers import _is_valid_ip, _is_valid_domain


class TaskStatus(Enum):
    """Enumeration of task execution statuses."""
    PENDING = "pending"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED_RETRYING = "failed_retrying"
    FAILED_PERMANENT = "failed_permanent"
    CANCELLED = "cancelled"


class TaskType(Enum):
    """Enumeration of task types for provider queries."""
    DOMAIN_QUERY = "domain_query"
    IP_QUERY = "ip_query"
    GRAPH_UPDATE = "graph_update"


@dataclass
class TaskResult:
    """Result of a task execution."""
    success: bool
    data: Optional[Any] = None
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ReconTask:
    """Represents a single reconnaissance task with retry logic."""
    task_id: str
    task_type: TaskType
    target: str
    provider_name: str
    depth: int
    status: TaskStatus = TaskStatus.PENDING
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    # Retry configuration
    max_retries: int = 3
    current_retry: int = 0
    base_backoff_seconds: float = 1.0
    max_backoff_seconds: float = 60.0

    # Execution tracking
    last_attempt_at: Optional[datetime] = None
    next_retry_at: Optional[datetime] = None
    execution_history: List[Dict[str, Any]] = field(default_factory=list)

    # Results
    result: Optional[TaskResult] = None

    def __post_init__(self):
        """Generate a short task ID if none was provided."""
        if not self.task_id:
            self.task_id = str(uuid.uuid4())[:8]

    def calculate_next_retry_time(self) -> Optional[datetime]:
        """Calculate the next retry time using exponential backoff with jitter, or None if retries are exhausted."""
        if self.current_retry >= self.max_retries:
            return None

        # Exponential backoff, capped at max_backoff_seconds
        backoff_time = min(
            self.max_backoff_seconds,
            self.base_backoff_seconds * (2 ** self.current_retry)
        )

        # Deterministic jitter of up to roughly +/-12.5%, derived from the task ID
        jitter = backoff_time * 0.25 * (0.5 - hash(self.task_id) % 1000 / 1000.0)
        final_backoff = max(self.base_backoff_seconds, backoff_time + jitter)

        return datetime.now(timezone.utc) + timedelta(seconds=final_backoff)
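
    # Worked example of the backoff above (defaults: base 1.0 s, cap 60.0 s,
    # max_retries 3; jitter omitted for readability):
    #   attempt 1 fails -> current_retry = 1 -> wait ~2 s before retry
    #   attempt 2 fails -> current_retry = 2 -> wait ~4 s before retry
    #   attempt 3 fails -> current_retry = 3 -> no retry (FAILED_PERMANENT)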

    def should_retry(self) -> bool:
        """Determine if the task should be retried, based on status, retry count, and backoff timing."""
        if self.status != TaskStatus.FAILED_RETRYING:
            return False
        if self.current_retry >= self.max_retries:
            return False
        if self.next_retry_at and datetime.now(timezone.utc) < self.next_retry_at:
            return False
        return True

    def mark_failed(self, error: str, metadata: Optional[Dict[str, Any]] = None):
        """Mark task as failed and prepare for retry or permanent failure."""
        self.current_retry += 1
        self.last_attempt_at = datetime.now(timezone.utc)

        # Record execution history
        execution_record = {
            'attempt': self.current_retry,
            'timestamp': self.last_attempt_at.isoformat(),
            'error': error,
            'metadata': metadata or {}
        }
        self.execution_history.append(execution_record)

        if self.current_retry >= self.max_retries:
            self.status = TaskStatus.FAILED_PERMANENT
            self.result = TaskResult(success=False, error=f"Permanent failure after {self.max_retries} attempts: {error}")
        else:
            self.status = TaskStatus.FAILED_RETRYING
            self.next_retry_at = self.calculate_next_retry_time()

    def mark_succeeded(self, data: Any = None, metadata: Optional[Dict[str, Any]] = None):
        """Mark task as successfully completed."""
        self.status = TaskStatus.SUCCEEDED
        self.last_attempt_at = datetime.now(timezone.utc)
        self.result = TaskResult(success=True, data=data, metadata=metadata or {})

        # Record successful execution
        execution_record = {
            'attempt': self.current_retry + 1,
            'timestamp': self.last_attempt_at.isoformat(),
            'success': True,
            'metadata': metadata or {}
        }
        self.execution_history.append(execution_record)

    def get_summary(self) -> Dict[str, Any]:
        """Get task summary for progress reporting."""
        return {
            'task_id': self.task_id,
            'task_type': self.task_type.value,
            'target': self.target,
            'provider': self.provider_name,
            'status': self.status.value,
            'current_retry': self.current_retry,
            'max_retries': self.max_retries,
            'created_at': self.created_at.isoformat(),
            'last_attempt_at': self.last_attempt_at.isoformat() if self.last_attempt_at else None,
            'next_retry_at': self.next_retry_at.isoformat() if self.next_retry_at else None,
            'total_attempts': len(self.execution_history),
            'has_result': self.result is not None
        }
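

# Task lifecycle implemented by ReconTask above and TaskQueue below:
#   PENDING -> RUNNING -> SUCCEEDED
#   PENDING -> RUNNING -> FAILED_RETRYING -> RUNNING (once next_retry_at has passed) -> ...
#   PENDING -> RUNNING -> FAILED_PERMANENT (after max_retries failed attempts)
# cancel_all_tasks() moves any non-terminal task to CANCELLED.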


class TaskQueue:
    """Thread-safe task queue with retry logic and priority handling."""

    def __init__(self, max_concurrent_tasks: int = 5):
        """Initialize task queue."""
        self.max_concurrent_tasks = max_concurrent_tasks
        self.tasks: Dict[str, ReconTask] = {}
        self.pending_queue = deque()
        self.retry_queue = deque()
        self.running_tasks: Set[str] = set()

        self._lock = threading.Lock()
        self._stop_event = threading.Event()

    def __getstate__(self):
        """Prepare TaskQueue for pickling by excluding unpicklable objects."""
        state = self.__dict__.copy()
        # Exclude the unpicklable '_lock' and '_stop_event' attributes
        if '_lock' in state:
            del state['_lock']
        if '_stop_event' in state:
            del state['_stop_event']
        return state

    def __setstate__(self, state):
        """Restore TaskQueue after unpickling by reconstructing threading objects."""
        self.__dict__.update(state)
        # Re-initialize the '_lock' and '_stop_event' attributes
        self._lock = threading.Lock()
        self._stop_event = threading.Event()

    def add_task(self, task: ReconTask) -> str:
        """Add task to queue."""
        with self._lock:
            self.tasks[task.task_id] = task
            self.pending_queue.append(task.task_id)
            print(f"Added task {task.task_id}: {task.provider_name} query for {task.target}")
            return task.task_id

    def get_next_ready_task(self) -> Optional[ReconTask]:
        """Get next task ready for execution."""
        with self._lock:
            # Check if we have room for more concurrent tasks
            if len(self.running_tasks) >= self.max_concurrent_tasks:
                return None

            # First priority: retry queue (tasks ready for retry).
            # Make a single pass over the current queue so tasks whose backoff
            # has not yet elapsed are re-queued instead of being lost.
            for _ in range(len(self.retry_queue)):
                task_id = self.retry_queue.popleft()
                if task_id in self.tasks:
                    task = self.tasks[task_id]
                    if task.should_retry():
                        task.status = TaskStatus.RUNNING
                        self.running_tasks.add(task_id)
                        print(f"Retrying task {task_id} (attempt {task.current_retry + 1})")
                        return task
                    if task.status == TaskStatus.FAILED_RETRYING:
                        # Backoff has not elapsed yet; keep the task queued for a later pass
                        self.retry_queue.append(task_id)

            # Second priority: pending queue (new tasks)
            while self.pending_queue:
                task_id = self.pending_queue.popleft()
                if task_id in self.tasks:
                    task = self.tasks[task_id]
                    if task.status == TaskStatus.PENDING:
                        task.status = TaskStatus.RUNNING
                        self.running_tasks.add(task_id)
                        print(f"Starting task {task_id}")
                        return task

            return None

    def complete_task(self, task_id: str, success: bool, data: Any = None,
                      error: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None):
        """Mark task as completed (success or failure)."""
        with self._lock:
            if task_id not in self.tasks:
                return

            task = self.tasks[task_id]
            self.running_tasks.discard(task_id)

            if success:
                task.mark_succeeded(data=data, metadata=metadata)
                print(f"Task {task_id} succeeded")
            else:
                task.mark_failed(error or "Unknown error", metadata=metadata)
                if task.status == TaskStatus.FAILED_RETRYING:
                    self.retry_queue.append(task_id)
                    print(f"Task {task_id} failed, scheduled for retry at {task.next_retry_at}")
                else:
                    print(f"Task {task_id} permanently failed after {task.current_retry} attempts")

    def cancel_all_tasks(self):
        """Cancel all pending and running tasks."""
        with self._lock:
            self._stop_event.set()
            for task in self.tasks.values():
                if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.FAILED_RETRYING]:
                    task.status = TaskStatus.CANCELLED
            self.pending_queue.clear()
            self.retry_queue.clear()
            self.running_tasks.clear()
            print("All tasks cancelled")

    def is_complete(self) -> bool:
        """Check if all tasks are complete (succeeded, permanently failed, or cancelled)."""
        with self._lock:
            for task in self.tasks.values():
                if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.FAILED_RETRYING]:
                    return False
            return True

    def get_statistics(self) -> Dict[str, Any]:
        """Get queue statistics."""
        with self._lock:
            stats = {
                'total_tasks': len(self.tasks),
                'pending': len(self.pending_queue),
                'running': len(self.running_tasks),
                'retry_queue': len(self.retry_queue),
                'succeeded': 0,
                'failed_permanent': 0,
                'cancelled': 0,
                'failed_retrying': 0
            }

            incomplete = False
            for task in self.tasks.values():
                if task.status == TaskStatus.SUCCEEDED:
                    stats['succeeded'] += 1
                elif task.status == TaskStatus.FAILED_PERMANENT:
                    stats['failed_permanent'] += 1
                elif task.status == TaskStatus.CANCELLED:
                    stats['cancelled'] += 1
                elif task.status == TaskStatus.FAILED_RETRYING:
                    stats['failed_retrying'] += 1

                if task.status in (TaskStatus.PENDING, TaskStatus.RUNNING, TaskStatus.FAILED_RETRYING):
                    incomplete = True

            stats['completion_rate'] = (stats['succeeded'] / stats['total_tasks'] * 100) if stats['total_tasks'] > 0 else 0
            # Computed inline rather than via is_complete(), which would try to
            # re-acquire the non-reentrant lock already held here.
            stats['is_complete'] = not incomplete

            return stats

    def get_task_summaries(self) -> List[Dict[str, Any]]:
        """Get summaries of all tasks for detailed progress reporting."""
        with self._lock:
            return [task.get_summary() for task in self.tasks.values()]

    def get_failed_tasks(self) -> List[ReconTask]:
        """Get all permanently failed tasks for analysis."""
        with self._lock:
            return [task for task in self.tasks.values() if task.status == TaskStatus.FAILED_PERMANENT]
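

# Assumed provider interface (inferred from the calls below; not defined in this
# module): each provider exposes get_name(), is_available(),
# get_eligibility() -> {'domains': bool, 'ips': bool}, and query_domain()/query_ip(),
# both returning an iterable of (source, target, rel_type, confidence, raw_data) tuples.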


class TaskExecutor:
    """Executes reconnaissance tasks using providers."""

    def __init__(self, providers: List, graph_manager, logger):
        """Initialize task executor."""
        self.providers = {provider.get_name(): provider for provider in providers}
        self.graph = graph_manager
        self.logger = logger

    def execute_task(self, task: ReconTask) -> TaskResult:
        """
        Execute a single reconnaissance task.

        Args:
            task: Task to execute

        Returns:
            TaskResult with success/failure information
        """
        try:
            print(f"Executing task {task.task_id}: {task.provider_name} query for {task.target}")

            provider = self.providers.get(task.provider_name)
            if not provider:
                return TaskResult(
                    success=False,
                    error=f"Provider {task.provider_name} not available"
                )

            if not provider.is_available():
                return TaskResult(
                    success=False,
                    error=f"Provider {task.provider_name} is not available (missing API key or configuration)"
                )

            # Execute provider query based on task type
            if task.task_type == TaskType.DOMAIN_QUERY:
                if not _is_valid_domain(task.target):
                    return TaskResult(success=False, error=f"Invalid domain: {task.target}")

                relationships = provider.query_domain(task.target)

            elif task.task_type == TaskType.IP_QUERY:
                if not _is_valid_ip(task.target):
                    return TaskResult(success=False, error=f"Invalid IP: {task.target}")

                relationships = provider.query_ip(task.target)

            else:
                return TaskResult(success=False, error=f"Unsupported task type: {task.task_type}")

            # Process results and update graph
            from core.graph_manager import NodeType

            new_targets = set()
            relationships_added = 0

            for source, target, rel_type, confidence, raw_data in relationships:
                # Add nodes to graph
                if _is_valid_ip(target):
                    self.graph.add_node(target, NodeType.IP)
                    new_targets.add(target)
                elif target.startswith('AS') and target[2:].isdigit():
                    self.graph.add_node(target, NodeType.ASN)
                elif _is_valid_domain(target):
                    self.graph.add_node(target, NodeType.DOMAIN)
                    new_targets.add(target)

                # Add edge to graph
                if self.graph.add_edge(source, target, rel_type, confidence, task.provider_name, raw_data):
                    relationships_added += 1

            # Log forensic information
            self.logger.logger.info(
                f"Task {task.task_id} completed: {len(relationships)} relationships found, "
                f"{relationships_added} added to graph, {len(new_targets)} new targets"
            )

            return TaskResult(
                success=True,
                data={
                    'relationships': relationships,
                    'new_targets': list(new_targets),
                    'relationships_added': relationships_added
                },
                metadata={
                    'provider': task.provider_name,
                    'target': task.target,
                    'depth': task.depth,
                    'execution_time': datetime.now(timezone.utc).isoformat()
                }
            )

        except Exception as e:
            error_msg = f"Task execution failed: {str(e)}"
            print(f"ERROR: {error_msg} for task {task.task_id}")
            self.logger.logger.error(error_msg)

            return TaskResult(
                success=False,
                error=error_msg,
                metadata={
                    'provider': task.provider_name,
                    'target': task.target,
                    'exception_type': type(e).__name__
                }
            )


class TaskManager:
    """High-level task management for reconnaissance scans."""

    def __init__(self, providers: List, graph_manager, logger, max_concurrent_tasks: int = 5):
        """Initialize task manager."""
        self.task_queue = TaskQueue(max_concurrent_tasks)
        self.task_executor = TaskExecutor(providers, graph_manager, logger)
        self.logger = logger

        # Execution control
        self._stop_event = threading.Event()
        self._execution_threads: List[threading.Thread] = []
        self._is_running = False

    def create_provider_tasks(self, target: str, depth: int, providers: List) -> List[str]:
        """
        Create tasks for querying all eligible providers for a target.

        Args:
            target: Domain or IP to query
            depth: Current recursion depth
            providers: List of available providers

        Returns:
            List of created task IDs
        """
        task_ids = []
        is_ip = _is_valid_ip(target)
        target_key = 'ips' if is_ip else 'domains'
        task_type = TaskType.IP_QUERY if is_ip else TaskType.DOMAIN_QUERY

        for provider in providers:
            if provider.get_eligibility().get(target_key) and provider.is_available():
                task = ReconTask(
                    task_id=str(uuid.uuid4())[:8],
                    task_type=task_type,
                    target=target,
                    provider_name=provider.get_name(),
                    depth=depth,
                    max_retries=3  # Configure retries per task type/provider
                )

                task_id = self.task_queue.add_task(task)
                task_ids.append(task_id)

        return task_ids

    def start_execution(self, max_workers: int = 3):
        """Start task execution with specified number of worker threads."""
        if self._is_running:
            print("Task execution already running")
            return

        self._is_running = True
        self._stop_event.clear()

        print(f"Starting task execution with {max_workers} workers")

        for i in range(max_workers):
            worker_thread = threading.Thread(
                target=self._worker_loop,
                name=f"TaskWorker-{i+1}",
                daemon=True
            )
            worker_thread.start()
            self._execution_threads.append(worker_thread)

    def stop_execution(self):
        """Stop task execution and cancel all tasks."""
        print("Stopping task execution")
        self._stop_event.set()
        self.task_queue.cancel_all_tasks()
        self._is_running = False

        # Wait for worker threads to finish
        for thread in self._execution_threads:
            thread.join(timeout=5.0)

        self._execution_threads.clear()
        print("Task execution stopped")

    def _worker_loop(self):
        """Worker thread loop for executing tasks."""
        thread_name = threading.current_thread().name
        print(f"{thread_name} started")

        while not self._stop_event.is_set():
            try:
                # Get next task to execute
                task = self.task_queue.get_next_ready_task()

                if task is None:
                    # No tasks ready, check if we should exit
                    if self.task_queue.is_complete() or self._stop_event.is_set():
                        break
                    time.sleep(0.1)  # Brief sleep before checking again
                    continue

                # Execute the task
                result = self.task_executor.execute_task(task)

                # Complete the task in queue
                self.task_queue.complete_task(
                    task.task_id,
                    success=result.success,
                    data=result.data,
                    error=result.error,
                    metadata=result.metadata
                )

            except Exception as e:
                print(f"ERROR: Worker {thread_name} encountered error: {e}")
                # Continue running even if individual task fails
                continue

        print(f"{thread_name} finished")

    def wait_for_completion(self, timeout_seconds: int = 300) -> bool:
        """
        Wait for all tasks to complete.

        Args:
            timeout_seconds: Maximum time to wait

        Returns:
            True if all tasks completed, False if timeout
        """
        start_time = time.time()

        while time.time() - start_time < timeout_seconds:
            if self.task_queue.is_complete():
                return True

            if self._stop_event.is_set():
                return False

            time.sleep(1.0)  # Check every second

        print(f"Timeout waiting for task completion after {timeout_seconds} seconds")
        return False

    def get_progress_report(self) -> Dict[str, Any]:
        """Get detailed progress report for UI updates."""
        stats = self.task_queue.get_statistics()
        failed_tasks = self.task_queue.get_failed_tasks()

        return {
            'statistics': stats,
            'failed_tasks': [task.get_summary() for task in failed_tasks],
            'is_running': self._is_running,
            'worker_count': len(self._execution_threads),
            # Limit detail for performance
            'detailed_tasks': self.task_queue.get_task_summaries() if stats['total_tasks'] < 50 else []
        }
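

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): exercises
    # the queue and task bookkeeping with a single dummy task and a hypothetical
    # provider name. No providers are contacted and no graph is updated.
    # Assumes the project's utils.helpers module is importable, as required by
    # the imports at the top of this file.
    queue = TaskQueue(max_concurrent_tasks=2)
    demo_task = ReconTask(
        task_id="",  # empty ID -> __post_init__ generates a short UUID
        task_type=TaskType.DOMAIN_QUERY,
        target="example.com",
        provider_name="demo_provider",  # hypothetical provider name
        depth=0,
    )
    queue.add_task(demo_task)

    task = queue.get_next_ready_task()
    if task is not None:
        # Simulate a successful provider query
        queue.complete_task(task.task_id, success=True, data={'relationships': []})

    print(queue.get_statistics())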