fix race condition

This commit is contained in:
overcuriousity 2025-09-14 01:40:17 +02:00
parent 2185177a84
commit b26002eff9

View File

@ -50,6 +50,7 @@ class Scanner:
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.scan_thread = None self.scan_thread = None
self.session_id = None # Will be set by session manager self.session_id = None # Will be set by session manager
self.current_scan_id = None # NEW: Track current scan ID
# Scanning progress tracking # Scanning progress tracking
self.total_indicators_found = 0 self.total_indicators_found = 0
@ -193,24 +194,43 @@ class Scanner:
print(f"=== STARTING SCAN IN SCANNER {id(self)} ===") print(f"=== STARTING SCAN IN SCANNER {id(self)} ===")
print(f"Session ID: {self.session_id}") print(f"Session ID: {self.session_id}")
print(f"Initial scanner status: {self.status}") print(f"Initial scanner status: {self.status}")
print(f"Clear graph: {clear_graph}")
# Clean up previous scan thread if needed # Generate scan ID based on clear_graph behavior
if self.scan_thread and self.scan_thread.is_alive(): import uuid
print("A previous scan thread is still alive. Sending termination signal and waiting...")
self.stop_scan()
self.scan_thread.join(10.0)
if self.scan_thread.is_alive(): if clear_graph:
print("ERROR: The previous scan thread is unresponsive and could not be stopped.") # NEW SCAN: Generate new ID and terminate existing scan
self.status = ScanStatus.FAILED print("NEW SCAN: Generating new scan ID and terminating existing scan")
self._update_session_state() self.current_scan_id = str(uuid.uuid4())[:8]
return False
print("Previous scan thread terminated successfully.")
# Reset state for new scan # Aggressive cleanup of previous scan
self.status = ScanStatus.IDLE if self.scan_thread and self.scan_thread.is_alive():
self._update_session_state() # Update GUI immediately print("Terminating previous scan thread...")
print("Scanner state is now clean for a new scan.") self._set_stop_signal()
if self.executor:
self.executor.shutdown(wait=False, cancel_futures=True)
self.scan_thread.join(timeout=8.0)
if self.scan_thread.is_alive():
print("WARNING: Previous scan thread did not terminate cleanly")
else:
# ADD TO GRAPH: Keep existing scan ID if scan is running, or generate new one
if self.status == ScanStatus.RUNNING and self.current_scan_id:
print(f"ADD TO GRAPH: Keeping existing scan ID {self.current_scan_id}")
# Don't terminate existing scan - we're adding to it
else:
print("ADD TO GRAPH: No active scan, generating new scan ID")
self.current_scan_id = str(uuid.uuid4())[:8]
print(f"Using scan ID: {self.current_scan_id}")
# Reset state for new scan (but preserve graph if clear_graph=False)
if clear_graph or self.status != ScanStatus.RUNNING:
self.status = ScanStatus.IDLE
self._update_session_state()
try: try:
if not hasattr(self, 'providers') or not self.providers: if not hasattr(self, 'providers') or not self.providers:
@ -221,32 +241,33 @@ class Scanner:
if clear_graph: if clear_graph:
self.graph.clear() self.graph.clear()
self.current_target = target_domain.lower().strip() self.current_target = target_domain.lower().strip()
self.max_depth = max_depth self.max_depth = max_depth
self.current_depth = 0 self.current_depth = 0
# Clear both local and Redis stop signals # Clear stop signals only if starting new scan
self.stop_event.clear() if clear_graph or self.status != ScanStatus.RUNNING:
if self.session_id: self.stop_event.clear()
from core.session_manager import session_manager if self.session_id:
session_manager.clear_stop_signal(self.session_id) from core.session_manager import session_manager
session_manager.clear_stop_signal(self.session_id)
self.total_indicators_found = 0 self.total_indicators_found = 0
self.indicators_processed = 0 self.indicators_processed = 0
self.current_indicator = self.current_target self.current_indicator = self.current_target
# Update GUI with scan preparation
self._update_session_state() self._update_session_state()
# Start new forensic session # Initialize forensic session only for new scans
print(f"Starting new forensic session for scanner {id(self)}...") if clear_graph:
self.logger = new_session() self.logger = new_session()
# Start scan in separate thread # Start scan thread (original behavior allows concurrent threads for "Add to Graph")
print(f"Starting scan thread for scanner {id(self)}...") print(f"Starting scan thread with scan ID {self.current_scan_id}...")
self.scan_thread = threading.Thread( self.scan_thread = threading.Thread(
target=self._execute_scan, target=self._execute_scan,
args=(self.current_target, max_depth), args=(self.current_target, max_depth, self.current_scan_id),
daemon=True daemon=True
) )
self.scan_thread.start() self.scan_thread.start()
@ -258,12 +279,13 @@ class Scanner:
print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}") print(f"ERROR: Exception in start_scan for scanner {id(self)}: {e}")
traceback.print_exc() traceback.print_exc()
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self._update_session_state() # Update failed status immediately self._update_session_state()
return False return False
def _execute_scan(self, target_domain: str, max_depth: int) -> None: def _execute_scan(self, target_domain: str, max_depth: int, scan_id: str) -> None:
"""Execute the reconnaissance scan using a task queue-based approach.""" """Execute the reconnaissance scan using a task queue-based approach."""
print(f"_execute_scan started for {target_domain} with depth {max_depth}") print(f"_execute_scan started for {target_domain} with depth {max_depth}, scan ID {scan_id}")
self.executor = ThreadPoolExecutor(max_workers=self.max_workers) self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
processed_targets = set() processed_targets = set()
@ -279,16 +301,18 @@ class Scanner:
self._initialize_provider_states(target_domain) self._initialize_provider_states(target_domain)
while task_queue: while task_queue:
# Abort if scan ID changed (new scan started)
if self.current_scan_id != scan_id:
print(f"Scan aborted - ID mismatch (current: {self.current_scan_id}, expected: {scan_id})")
break
if self._is_stop_requested(): if self._is_stop_requested():
print("Stop requested, terminating scan.") print("Stop requested, terminating scan.")
break break
target, depth, is_large_entity_member = task_queue.popleft() target, depth, is_large_entity_member = task_queue.popleft()
if target in processed_targets: if target in processed_targets or depth > max_depth:
continue
if depth > max_depth:
continue continue
self.current_depth = depth self.current_depth = depth
@ -298,14 +322,15 @@ class Scanner:
new_targets, large_entity_members = self._query_providers_for_target(target, depth, is_large_entity_member) new_targets, large_entity_members = self._query_providers_for_target(target, depth, is_large_entity_member)
processed_targets.add(target) processed_targets.add(target)
for new_target in new_targets: # Only add new targets if scan ID still matches (prevents stale updates)
if new_target not in processed_targets: if self.current_scan_id == scan_id:
task_queue.append((new_target, depth + 1, False)) for new_target in new_targets:
if new_target not in processed_targets:
for member in large_entity_members: task_queue.append((new_target, depth + 1, False))
if member not in processed_targets:
task_queue.append((member, depth, True))
for member in large_entity_members:
if member not in processed_targets:
task_queue.append((member, depth, True))
except Exception as e: except Exception as e:
print(f"ERROR: Scan execution failed with error: {e}") print(f"ERROR: Scan execution failed with error: {e}")
@ -313,13 +338,18 @@ class Scanner:
self.status = ScanStatus.FAILED self.status = ScanStatus.FAILED
self.logger.logger.error(f"Scan failed: {e}") self.logger.logger.error(f"Scan failed: {e}")
finally: finally:
if self._is_stop_requested(): # Only update final status if scan ID still matches (prevents stale status updates)
self.status = ScanStatus.STOPPED if self.current_scan_id == scan_id:
else: if self._is_stop_requested():
self.status = ScanStatus.COMPLETED self.status = ScanStatus.STOPPED
else:
self.status = ScanStatus.COMPLETED
self._update_session_state()
self.logger.log_scan_complete()
else:
print(f"Scan completed but ID mismatch - not updating final status")
self._update_session_state()
self.logger.log_scan_complete()
if self.executor: if self.executor:
self.executor.shutdown(wait=False, cancel_futures=True) self.executor.shutdown(wait=False, cancel_futures=True)
stats = self.graph.get_statistics() stats = self.graph.get_statistics()
@ -621,7 +651,6 @@ class Scanner:
if target not in attributes[record_type_name]: if target not in attributes[record_type_name]:
attributes[record_type_name].append(target) attributes[record_type_name].append(target)
def _log_target_processing_error(self, target: str, error: str) -> None: def _log_target_processing_error(self, target: str, error: str) -> None:
"""Log target processing errors for forensic trail.""" """Log target processing errors for forensic trail."""
self.logger.logger.error(f"Target processing failed for {target}: {error}") self.logger.logger.error(f"Target processing failed for {target}: {error}")
@ -641,7 +670,12 @@ class Scanner:
print("=== INITIATING IMMEDIATE SCAN TERMINATION ===") print("=== INITIATING IMMEDIATE SCAN TERMINATION ===")
self.logger.logger.info("Scan termination requested by user") self.logger.logger.info("Scan termination requested by user")
# Set both local and Redis stop signals # Invalidate current scan ID to prevent stale updates
old_scan_id = self.current_scan_id
self.current_scan_id = None
print(f"Invalidated scan ID {old_scan_id}")
# Set stop signals
self._set_stop_signal() self._set_stop_signal()
self.status = ScanStatus.STOPPED self.status = ScanStatus.STOPPED