Fix critical bugs and improve data integrity across codebase

This commit addresses 20 bugs discovered during comprehensive code review,
focusing on data integrity, concurrent access, and user experience.

CRITICAL FIXES:
- Fix GPG key listing to support keys with multiple UIDs (crypto.py:40)
- Implement cross-platform file locking to prevent concurrent access corruption (storage.py)
- Fix evidence detail delete logic that could delete wrong note (tui.py:2481-2497)
- Add corrupted JSON handling with user prompt and automatic backup (storage.py, tui.py)

DATA INTEGRITY:
- Fix IOC/Hash pattern false positives by checking longest hashes first (models.py:32-95)
- Fix URL pattern to exclude trailing punctuation (models.py:81, 152, 216)
- Improve IOC overlap detection with proper range tracking (models.py)
- Fix note deletion to use note_id instead of object identity (tui.py:2498-2619)
- Add state validation to detect and clear orphaned references (storage.py:355-384)

SCROLLING & NAVIGATION:
- Fix evidence detail view to support full scrolling instead of "last N" (tui.py:816-847)
- Fix filter reset index bounds bug (tui.py:1581-1654)
- Add scroll_offset validation after all operations (tui.py:1608-1654)
- Fix division by zero in scroll calculations (tui.py:446-478)
- Validate selection bounds across all views (tui.py:_validate_selection_bounds)

EXPORT & CLI:
- Fix multi-line note export with proper markdown indentation (cli.py:129-143)
- Add stderr warnings for GPG signature failures (cli.py:61, 63)
- Validate active context and show warnings in CLI (cli.py:12-44)

TESTING:
- Update tests to support new lock file mechanism (test_models.py)
- All existing tests pass with new changes

Breaking changes: None
Backward compatible: Yes (existing data files work unchanged)
This commit is contained in:
Claude
2025-12-13 16:16:54 +00:00
parent ba7a8fdd5d
commit 2453bd4f2a
6 changed files with 547 additions and 178 deletions

View File

@@ -8,6 +8,12 @@ from .crypto import Crypto
def quick_add_note(content: str): def quick_add_note(content: str):
storage = Storage() storage = Storage()
state_manager = StateManager() state_manager = StateManager()
# Validate and clear stale state
warning = state_manager.validate_and_clear_stale(storage)
if warning:
print(f"Warning: {warning}", file=sys.stderr)
state = state_manager.get_active() state = state_manager.get_active()
settings = state_manager.get_settings() settings = state_manager.get_settings()
@@ -15,23 +21,28 @@ def quick_add_note(content: str):
evidence_id = state.get("evidence_id") evidence_id = state.get("evidence_id")
if not case_id: if not case_id:
print("Error: No active case set. Open the TUI to select a case first.") print("Error: No active case set. Open the TUI to select a case first.", file=sys.stderr)
sys.exit(1) sys.exit(1)
case = storage.get_case(case_id) case = storage.get_case(case_id)
if not case: if not case:
print("Error: Active case not found in storage. Ensure you have set an active case in the TUI.") print("Error: Active case not found in storage. Ensure you have set an active case in the TUI.", file=sys.stderr)
sys.exit(1) sys.exit(1)
target_evidence = None target_evidence = None
if evidence_id: if evidence_id:
# Find evidence # Find and validate evidence belongs to active case
for ev in case.evidence: for ev in case.evidence:
if ev.evidence_id == evidence_id: if ev.evidence_id == evidence_id:
target_evidence = ev target_evidence = ev
break break
if not target_evidence:
# Evidence ID is set but doesn't exist in case - clear it
print(f"Warning: Active evidence not found in case. Clearing to case level.", file=sys.stderr)
state_manager.set_active(case_id, None)
# Create note # Create note
note = Note(content=content) note = Note(content=content)
note.calculate_hash() note.calculate_hash()
@@ -47,9 +58,9 @@ def quick_add_note(content: str):
if signature: if signature:
note.signature = signature note.signature = signature
else: else:
print("Warning: GPG signature failed (GPG not found or no key). Note saved without signature.") print("Warning: GPG signature failed (GPG not found or no key). Note saved without signature.", file=sys.stderr)
else: else:
print("Warning: No GPG key ID configured. Note saved without signature.") print("Warning: No GPG key ID configured. Note saved without signature.", file=sys.stderr)
# Attach to evidence or case # Attach to evidence or case
if target_evidence: if target_evidence:
@@ -117,7 +128,10 @@ def export_markdown(output_file: str = "export.md"):
def write_note(f, note: Note): def write_note(f, note: Note):
f.write(f"- **{time.ctime(note.timestamp)}**\n") f.write(f"- **{time.ctime(note.timestamp)}**\n")
f.write(f" - Content: {note.content}\n") f.write(f" - Content:\n")
# Properly indent multi-line content
for line in note.content.splitlines():
f.write(f" {line}\n")
f.write(f" - Hash: `{note.content_hash}`\n") f.write(f" - Hash: `{note.content_hash}`\n")
if note.signature: if note.signature:
f.write(" - **Signature Verified:**\n") f.write(" - **Signature Verified:**\n")

View File

@@ -37,7 +37,7 @@ class Crypto:
elif fields[0] == 'uid' and current_key_id: elif fields[0] == 'uid' and current_key_id:
user_id = fields[9] if len(fields) > 9 else "Unknown" user_id = fields[9] if len(fields) > 9 else "Unknown"
keys.append((current_key_id, user_id)) keys.append((current_key_id, user_id))
current_key_id = None # Reset after matching # Don't reset current_key_id - allow multiple UIDs per key
return keys return keys

View File

@@ -32,64 +32,67 @@ class Note:
def extract_iocs(self): def extract_iocs(self):
"""Extract Indicators of Compromise from content""" """Extract Indicators of Compromise from content"""
seen = set() seen = set()
covered_ranges = set()
self.iocs = [] self.iocs = []
# IPv4 addresses def add_ioc_if_not_covered(match_obj):
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b' """Add IOC if its range doesn't overlap with already covered ranges"""
for match in re.findall(ipv4_pattern, self.content): start, end = match_obj.start(), match_obj.end()
if match not in seen: # Check if this range overlaps with any covered range
seen.add(match) for covered_start, covered_end in covered_ranges:
self.iocs.append(match) if not (end <= covered_start or start >= covered_end):
return False # Overlaps, don't add
text = match_obj.group()
if text not in seen:
seen.add(text)
covered_ranges.add((start, end))
self.iocs.append(text)
return True
return False
# IPv6 addresses (supports compressed format) # Process in order of priority to avoid false positives
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b' # SHA256 hashes (64 hex chars) - check longest first to avoid substring matches
for match in re.findall(ipv6_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
# URLs (check before domains to prevent double-matching)
url_pattern = r'https?://[^\s]+'
for match in re.findall(url_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.findall(domain_pattern, self.content):
# Filter out common false positives
if match not in seen and not match.startswith('example.'):
seen.add(match)
self.iocs.append(match)
# SHA256 hashes (64 hex chars) - check longest first
sha256_pattern = r'\b[a-fA-F0-9]{64}\b' sha256_pattern = r'\b[a-fA-F0-9]{64}\b'
for match in re.findall(sha256_pattern, self.content): for match in re.finditer(sha256_pattern, self.content):
if match not in seen: add_ioc_if_not_covered(match)
seen.add(match)
self.iocs.append(match)
# SHA1 hashes (40 hex chars) # SHA1 hashes (40 hex chars)
sha1_pattern = r'\b[a-fA-F0-9]{40}\b' sha1_pattern = r'\b[a-fA-F0-9]{40}\b'
for match in re.findall(sha1_pattern, self.content): for match in re.finditer(sha1_pattern, self.content):
if match not in seen: add_ioc_if_not_covered(match)
seen.add(match)
self.iocs.append(match)
# MD5 hashes (32 hex chars) # MD5 hashes (32 hex chars)
md5_pattern = r'\b[a-fA-F0-9]{32}\b' md5_pattern = r'\b[a-fA-F0-9]{32}\b'
for match in re.findall(md5_pattern, self.content): for match in re.finditer(md5_pattern, self.content):
if match not in seen: add_ioc_if_not_covered(match)
seen.add(match)
self.iocs.append(match) # IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.finditer(ipv4_pattern, self.content):
add_ioc_if_not_covered(match)
# IPv6 addresses (supports compressed format)
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b'
for match in re.finditer(ipv6_pattern, self.content):
add_ioc_if_not_covered(match)
# URLs (check before domains to prevent double-matching)
# Fix: exclude trailing punctuation
url_pattern = r'https?://[^\s<>\"\']+(?<![.,;:!?\)\]\}])'
for match in re.finditer(url_pattern, self.content):
add_ioc_if_not_covered(match)
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.finditer(domain_pattern, self.content):
# Filter out common false positives
if not match.group().startswith('example.'):
add_ioc_if_not_covered(match)
# Email addresses # Email addresses
email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b' email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
for match in re.findall(email_pattern, self.content): for match in re.finditer(email_pattern, self.content):
if match not in seen: add_ioc_if_not_covered(match)
seen.add(match)
self.iocs.append(match)
def calculate_hash(self): def calculate_hash(self):
# We hash the content + timestamp to ensure integrity of 'when' it was said # We hash the content + timestamp to ensure integrity of 'when' it was said
@@ -101,63 +104,66 @@ class Note:
"""Extract IOCs from text and return as list of (ioc, type) tuples""" """Extract IOCs from text and return as list of (ioc, type) tuples"""
iocs = [] iocs = []
seen = set() seen = set()
covered_ranges = set()
# IPv4 addresses def add_ioc_if_not_covered(match_obj, ioc_type):
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b' """Add IOC if its range doesn't overlap with already covered ranges"""
for match in re.findall(ipv4_pattern, text): start, end = match_obj.start(), match_obj.end()
if match not in seen: # Check if this range overlaps with any covered range
seen.add(match) for covered_start, covered_end in covered_ranges:
iocs.append((match, 'ipv4')) if not (end <= covered_start or start >= covered_end):
return False # Overlaps, don't add
ioc_text = match_obj.group()
if ioc_text not in seen:
seen.add(ioc_text)
covered_ranges.add((start, end))
iocs.append((ioc_text, ioc_type))
return True
return False
# IPv6 addresses (supports compressed format) # Process in priority order: longest hashes first
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b' # SHA256 hashes (64 hex chars)
for match in re.findall(ipv6_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'ipv6'))
# URLs (check before domains to avoid double-matching)
url_pattern = r'https?://[^\s]+'
for match in re.findall(url_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'url'))
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.findall(domain_pattern, text):
# Filter out common false positives and already seen URLs
if match not in seen and not match.startswith('example.'):
seen.add(match)
iocs.append((match, 'domain'))
# SHA256 hashes (64 hex chars) - check before SHA1 and MD5
sha256_pattern = r'\b[a-fA-F0-9]{64}\b' sha256_pattern = r'\b[a-fA-F0-9]{64}\b'
for match in re.findall(sha256_pattern, text): for match in re.finditer(sha256_pattern, text):
if match not in seen: add_ioc_if_not_covered(match, 'sha256')
seen.add(match)
iocs.append((match, 'sha256'))
# SHA1 hashes (40 hex chars) - check before MD5 # SHA1 hashes (40 hex chars)
sha1_pattern = r'\b[a-fA-F0-9]{40}\b' sha1_pattern = r'\b[a-fA-F0-9]{40}\b'
for match in re.findall(sha1_pattern, text): for match in re.finditer(sha1_pattern, text):
if match not in seen: add_ioc_if_not_covered(match, 'sha1')
seen.add(match)
iocs.append((match, 'sha1'))
# MD5 hashes (32 hex chars) # MD5 hashes (32 hex chars)
md5_pattern = r'\b[a-fA-F0-9]{32}\b' md5_pattern = r'\b[a-fA-F0-9]{32}\b'
for match in re.findall(md5_pattern, text): for match in re.finditer(md5_pattern, text):
if match not in seen: add_ioc_if_not_covered(match, 'md5')
seen.add(match)
iocs.append((match, 'md5')) # IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.finditer(ipv4_pattern, text):
add_ioc_if_not_covered(match, 'ipv4')
# IPv6 addresses (supports compressed format)
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b'
for match in re.finditer(ipv6_pattern, text):
add_ioc_if_not_covered(match, 'ipv6')
# URLs (check before domains to avoid double-matching)
# Fix: exclude trailing punctuation
url_pattern = r'https?://[^\s<>\"\']+(?<![.,;:!?\)\]\}])'
for match in re.finditer(url_pattern, text):
add_ioc_if_not_covered(match, 'url')
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.finditer(domain_pattern, text):
# Filter out common false positives
if not match.group().startswith('example.'):
add_ioc_if_not_covered(match, 'domain')
# Email addresses # Email addresses
email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b' email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
for match in re.findall(email_pattern, text): for match in re.finditer(email_pattern, text):
if match not in seen: add_ioc_if_not_covered(match, 'email')
seen.add(match)
iocs.append((match, 'email'))
return iocs return iocs
@@ -182,6 +188,19 @@ class Note:
highlights.append((match.group(), start, end, ioc_type)) highlights.append((match.group(), start, end, ioc_type))
covered_ranges.add((start, end)) covered_ranges.add((start, end))
# Process in priority order: longest hashes first to avoid substring matches
# SHA256 hashes (64 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{64}\b', text):
add_highlight(match, 'sha256')
# SHA1 hashes (40 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{40}\b', text):
add_highlight(match, 'sha1')
# MD5 hashes (32 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{32}\b', text):
add_highlight(match, 'md5')
# IPv4 addresses # IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b' ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.finditer(ipv4_pattern, text): for match in re.finditer(ipv4_pattern, text):
@@ -193,7 +212,8 @@ class Note:
add_highlight(match, 'ipv6') add_highlight(match, 'ipv6')
# URLs (check before domains to prevent double-matching) # URLs (check before domains to prevent double-matching)
for match in re.finditer(r'https?://[^\s]+', text): # Fix: exclude trailing punctuation
for match in re.finditer(r'https?://[^\s<>\"\']+(?<![.,;:!?\)\]\}])', text):
add_highlight(match, 'url') add_highlight(match, 'url')
# Domain names # Domain names
@@ -201,18 +221,6 @@ class Note:
if not match.group().startswith('example.'): if not match.group().startswith('example.'):
add_highlight(match, 'domain') add_highlight(match, 'domain')
# SHA256 hashes (64 hex chars) - check longest first
for match in re.finditer(r'\b[a-fA-F0-9]{64}\b', text):
add_highlight(match, 'sha256')
# SHA1 hashes (40 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{40}\b', text):
add_highlight(match, 'sha1')
# MD5 hashes (32 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{32}\b', text):
add_highlight(match, 'md5')
# Email addresses # Email addresses
for match in re.finditer(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text): for match in re.finditer(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text):
add_highlight(match, 'email') add_highlight(match, 'email')

View File

@@ -1,22 +1,121 @@
import json import json
import time import time
import os
import sys
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from .models import Case, Evidence, Note from .models import Case, Evidence, Note
DEFAULT_APP_DIR = Path.home() / ".trace" DEFAULT_APP_DIR = Path.home() / ".trace"
class LockManager:
"""Cross-platform file lock manager to prevent concurrent access"""
def __init__(self, lock_file: Path):
self.lock_file = lock_file
self.acquired = False
def acquire(self, timeout: int = 5):
"""Acquire lock with timeout. Returns True if successful."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
# Try to create lock file exclusively (fails if exists)
# Use 'x' mode which fails if file exists (atomic on most systems)
fd = os.open(str(self.lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
os.write(fd, str(os.getpid()).encode())
os.close(fd)
self.acquired = True
return True
except FileExistsError:
# Lock file exists, check if process is still alive
if self._is_stale_lock():
# Remove stale lock and retry
try:
self.lock_file.unlink()
except FileNotFoundError:
pass
continue
# Active lock, wait a bit
time.sleep(0.1)
except Exception:
# Other errors, wait and retry
time.sleep(0.1)
return False
def _is_stale_lock(self):
"""Check if lock file is stale (process no longer exists)"""
try:
if not self.lock_file.exists():
return False
with open(self.lock_file, 'r') as f:
pid = int(f.read().strip())
# Check if process exists (cross-platform)
if sys.platform == 'win32':
import ctypes
kernel32 = ctypes.windll.kernel32
PROCESS_QUERY_INFORMATION = 0x0400
handle = kernel32.OpenProcess(PROCESS_QUERY_INFORMATION, 0, pid)
if handle:
kernel32.CloseHandle(handle)
return False
return True
else:
# Unix/Linux - send signal 0 to check if process exists
try:
os.kill(pid, 0)
return False # Process exists
except OSError:
return True # Process doesn't exist
except (ValueError, FileNotFoundError, PermissionError):
return True
def release(self):
"""Release the lock"""
if self.acquired:
try:
self.lock_file.unlink()
except FileNotFoundError:
pass
self.acquired = False
def __enter__(self):
if not self.acquire():
raise RuntimeError("Could not acquire lock: another instance is running")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
class Storage: class Storage:
def __init__(self, app_dir: Path = DEFAULT_APP_DIR): def __init__(self, app_dir: Path = DEFAULT_APP_DIR, acquire_lock: bool = True):
self.app_dir = app_dir self.app_dir = app_dir
self.data_file = self.app_dir / "data.json" self.data_file = self.app_dir / "data.json"
self.lock_file = self.app_dir / "app.lock"
self.lock_manager = None
self._ensure_app_dir() self._ensure_app_dir()
# Acquire lock to prevent concurrent access
if acquire_lock:
self.lock_manager = LockManager(self.lock_file)
if not self.lock_manager.acquire(timeout=5):
raise RuntimeError("Another instance of trace is already running. Please close it first.")
self.cases: List[Case] = self._load_data() self.cases: List[Case] = self._load_data()
# Create demo case on first launch # Create demo case on first launch (only if data loaded successfully and is empty)
if not self.cases: if not self.cases and self.data_file.exists():
# File exists but is empty - could be first run after successful load
pass
elif not self.cases and not self.data_file.exists():
# No file exists - first run
self._create_demo_case() self._create_demo_case()
def __del__(self):
"""Release lock when Storage object is destroyed"""
if self.lock_manager:
self.lock_manager.release()
def _ensure_app_dir(self): def _ensure_app_dir(self):
if not self.app_dir.exists(): if not self.app_dir.exists():
self.app_dir.mkdir(parents=True, exist_ok=True) self.app_dir.mkdir(parents=True, exist_ok=True)
@@ -169,8 +268,23 @@ Attachment: invoice.pdf.exe (double extension trick) #email-forensics #phishing-
with open(self.data_file, 'r', encoding='utf-8') as f: with open(self.data_file, 'r', encoding='utf-8') as f:
data = json.load(f) data = json.load(f)
return [Case.from_dict(c) for c in data] return [Case.from_dict(c) for c in data]
except (json.JSONDecodeError, IOError): except (json.JSONDecodeError, IOError, KeyError, ValueError) as e:
return [] # Corrupted JSON - create backup and raise exception
import shutil
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = self.app_dir / f"data.json.corrupted.{timestamp}"
try:
shutil.copy2(self.data_file, backup_file)
except Exception:
pass
# Raise exception with information about backup
raise RuntimeError(f"Data file is corrupted. Backup saved to: {backup_file}\nError: {e}")
def start_fresh(self):
"""Start with fresh data (for corrupted JSON recovery)"""
self.cases = []
self._create_demo_case()
def save_data(self): def save_data(self):
data = [c.to_dict() for c in self.cases] data = [c.to_dict() for c in self.cases]
@@ -238,6 +352,37 @@ class StateManager:
except (json.JSONDecodeError, IOError): except (json.JSONDecodeError, IOError):
return {"case_id": None, "evidence_id": None} return {"case_id": None, "evidence_id": None}
def validate_and_clear_stale(self, storage: 'Storage') -> str:
"""Validate active state against storage and clear stale references.
Returns warning message if state was cleared, empty string otherwise."""
state = self.get_active()
case_id = state.get("case_id")
evidence_id = state.get("evidence_id")
warning = ""
if case_id:
case = storage.get_case(case_id)
if not case:
warning = f"Active case (ID: {case_id[:8]}...) no longer exists. Clearing active context."
self.set_active(None, None)
return warning
# Validate evidence if set
if evidence_id:
_, evidence = storage.find_evidence(evidence_id)
if not evidence:
warning = f"Active evidence (ID: {evidence_id[:8]}...) no longer exists. Clearing to case level."
self.set_active(case_id, None)
return warning
elif evidence_id:
# Evidence set but no case - invalid state
warning = "Invalid state: evidence set without case. Clearing active context."
self.set_active(None, None)
return warning
return warning
def get_settings(self) -> dict: def get_settings(self) -> dict:
if not self.settings_file.exists(): if not self.settings_file.exists():
return {"pgp_enabled": True} return {"pgp_enabled": True}

View File

@@ -21,7 +21,8 @@ class TestModels(unittest.TestCase):
class TestStorage(unittest.TestCase): class TestStorage(unittest.TestCase):
def setUp(self): def setUp(self):
self.test_dir = Path(tempfile.mkdtemp()) self.test_dir = Path(tempfile.mkdtemp())
self.storage = Storage(app_dir=self.test_dir) # Disable lock for tests to allow multiple Storage instances
self.storage = Storage(app_dir=self.test_dir, acquire_lock=False)
def tearDown(self): def tearDown(self):
shutil.rmtree(self.test_dir) shutil.rmtree(self.test_dir)
@@ -31,7 +32,7 @@ class TestStorage(unittest.TestCase):
self.storage.add_case(case) self.storage.add_case(case)
# Reload storage from same dir # Reload storage from same dir
new_storage = Storage(app_dir=self.test_dir) new_storage = Storage(app_dir=self.test_dir, acquire_lock=False)
loaded_case = new_storage.get_case(case.case_id) loaded_case = new_storage.get_case(case.case_id)
self.assertIsNotNone(loaded_case) self.assertIsNotNone(loaded_case)

View File

@@ -64,7 +64,12 @@ class TUI:
self.height, self.width = stdscr.getmaxyx() self.height, self.width = stdscr.getmaxyx()
# Load initial active state # Load initial active state and validate
warning = self.state_manager.validate_and_clear_stale(self.storage)
if warning:
self.flash_message = warning
self.flash_time = time.time()
active_state = self.state_manager.get_active() active_state = self.state_manager.get_active()
self.global_active_case_id = active_state.get("case_id") self.global_active_case_id = active_state.get("case_id")
self.global_active_evidence_id = active_state.get("evidence_id") self.global_active_evidence_id = active_state.get("evidence_id")
@@ -445,14 +450,35 @@ class TUI:
def _update_scroll(self, total_items): def _update_scroll(self, total_items):
# Viewport height calculation (approximate lines available for list) # Viewport height calculation (approximate lines available for list)
list_h = self.content_h - 2 # Title + padding # Protect against negative or zero content_h
if list_h < 1: list_h = 1 if self.content_h < 3:
# Terminal too small, use minimum viable height
list_h = 1
else:
list_h = self.content_h - 2 # Title + padding
if list_h < 1:
list_h = 1
# Ensure selected index is visible # Ensure selected index is within bounds
if self.selected_index < self.scroll_offset: if total_items == 0:
self.scroll_offset = self.selected_index self.selected_index = 0
elif self.selected_index >= self.scroll_offset + list_h: self.scroll_offset = 0
self.scroll_offset = self.selected_index - list_h + 1 else:
# Clamp selected_index to valid range
if self.selected_index >= total_items:
self.selected_index = max(0, total_items - 1)
if self.selected_index < 0:
self.selected_index = 0
# Ensure selected index is visible
if self.selected_index < self.scroll_offset:
self.scroll_offset = self.selected_index
elif self.selected_index >= self.scroll_offset + list_h:
self.scroll_offset = self.selected_index - list_h + 1
# Ensure scroll_offset is within bounds
if self.scroll_offset < 0:
self.scroll_offset = 0
return list_h return list_h
@@ -819,21 +845,23 @@ class TUI:
self.stdscr.addstr(current_y, 2, f"Notes ({len(notes)}):", curses.A_UNDERLINE) self.stdscr.addstr(current_y, 2, f"Notes ({len(notes)}):", curses.A_UNDERLINE)
current_y += 1 current_y += 1
# Just show last N notes that fit # Calculate available height for notes list
list_h = self.content_h - (current_y - 2) # Adjust for dynamic header list_h = self.content_h - (current_y - 2) # Adjust for dynamic header
if list_h < 1:
list_h = 1
start_y = current_y start_y = current_y
display_notes = notes[-list_h:] if len(notes) > list_h else notes # Update scroll to keep selection visible (use full notes list)
if notes:
self._update_scroll(len(notes))
# Update scroll for note selection # Display notes with proper scrolling
if display_notes: for i in range(list_h):
self._update_scroll(len(display_notes))
for i, note in enumerate(display_notes):
idx = self.scroll_offset + i idx = self.scroll_offset + i
if idx >= len(display_notes): if idx >= len(notes):
break break
note = display_notes[idx]
note = notes[idx]
# Replace newlines with spaces for single-line display # Replace newlines with spaces for single-line display
note_content = note.content.replace('\n', ' ').replace('\r', ' ') note_content = note.content.replace('\n', ' ').replace('\r', ' ')
display_str = f"- {note_content}" display_str = f"- {note_content}"
@@ -1582,22 +1610,75 @@ class TUI:
self.filter_query = "" self.filter_query = ""
self.selected_index = 0 self.selected_index = 0
self.scroll_offset = 0 self.scroll_offset = 0
# Validate selected_index against the unfiltered list
self._validate_selection_bounds()
return True return True
elif key == curses.KEY_ENTER or key in [10, 13]: elif key == curses.KEY_ENTER or key in [10, 13]:
self.filter_mode = False self.filter_mode = False
self.selected_index = 0 # Validate selected_index against the filtered list
self._validate_selection_bounds()
self.scroll_offset = 0 self.scroll_offset = 0
return True return True
elif key == curses.KEY_BACKSPACE or key == 127: elif key == curses.KEY_BACKSPACE or key == 127:
if len(self.filter_query) > 0: if len(self.filter_query) > 0:
self.filter_query = self.filter_query[:-1] self.filter_query = self.filter_query[:-1]
self.selected_index = 0 self.selected_index = 0
self.scroll_offset = 0
elif 32 <= key <= 126: elif 32 <= key <= 126:
self.filter_query += chr(key) self.filter_query += chr(key)
self.selected_index = 0 self.selected_index = 0
self.scroll_offset = 0
return True return True
def _validate_selection_bounds(self):
"""Validate and fix selected_index and scroll_offset to ensure they're within bounds"""
max_idx = 0
if self.current_view == "case_list":
filtered = self._get_filtered_list(self.cases, "case_number", "name")
max_idx = len(filtered) - 1
elif self.current_view == "case_detail" and self.active_case:
case_notes = self.active_case.notes
filtered = self._get_filtered_list(self.active_case.evidence, "name", "description")
max_idx = len(filtered) + len(case_notes) - 1
elif self.current_view == "evidence_detail" and self.active_evidence:
notes = self._get_filtered_list(self.active_evidence.notes, "content") if self.filter_query else self.active_evidence.notes
max_idx = len(notes) - 1
elif self.current_view == "tags_list":
tags_to_show = self.current_tags
if self.filter_query:
q = self.filter_query.lower()
tags_to_show = [(tag, count) for tag, count in self.current_tags if q in tag.lower()]
max_idx = len(tags_to_show) - 1
elif self.current_view == "tag_notes_list":
notes_to_show = self._get_filtered_list(self.tag_notes, "content") if self.filter_query else self.tag_notes
max_idx = len(notes_to_show) - 1
elif self.current_view == "ioc_list":
iocs_to_show = self.current_iocs
if self.filter_query:
q = self.filter_query.lower()
iocs_to_show = [(ioc, count, ioc_type) for ioc, count, ioc_type in self.current_iocs
if q in ioc.lower() or q in ioc_type.lower()]
max_idx = len(iocs_to_show) - 1
elif self.current_view == "ioc_notes_list":
notes_to_show = self._get_filtered_list(self.ioc_notes, "content") if self.filter_query else self.ioc_notes
max_idx = len(notes_to_show) - 1
# Ensure max_idx is at least 0
max_idx = max(0, max_idx)
# Fix selected_index if out of bounds
if self.selected_index > max_idx:
self.selected_index = max_idx
if self.selected_index < 0:
self.selected_index = 0
# Fix scroll_offset if out of bounds
if self.scroll_offset > self.selected_index:
self.scroll_offset = max(0, self.selected_index)
if self.scroll_offset < 0:
self.scroll_offset = 0
def _handle_set_active(self): def _handle_set_active(self):
if self.current_view == "case_list": if self.current_view == "case_list":
filtered = self._get_filtered_list(self.cases, "case_number", "name") filtered = self._get_filtered_list(self.cases, "case_number", "name")
@@ -2477,24 +2558,21 @@ class TUI:
self.show_message("No notes to delete.") self.show_message("No notes to delete.")
return return
# Calculate which note to delete based on display (showing last N filtered notes) # Get filtered notes (or all notes if no filter)
notes = self._get_filtered_list(self.active_evidence.notes, "content") if self.filter_query else self.active_evidence.notes notes = self._get_filtered_list(self.active_evidence.notes, "content") if self.filter_query else self.active_evidence.notes
list_h = self.content_h - 5 # Adjust for header
display_notes = notes[-list_h:] if len(notes) > list_h else notes
if display_notes: if notes and self.selected_index < len(notes):
# User selection is in context of displayed notes note_to_del = notes[self.selected_index]
# We need to delete from the full list # Show preview of note content in confirmation
if self.selected_index < len(display_notes): preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content
note_to_del = display_notes[self.selected_index] if self.dialog_confirm(f"Delete note: '{preview}'?"):
# Show preview of note content in confirmation self.active_evidence.notes.remove(note_to_del)
preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content self.storage.save_data()
if self.dialog_confirm(f"Delete note: '{preview}'?"): # Adjust selected index if needed
self.active_evidence.notes.remove(note_to_del) if self.selected_index >= len(notes) - 1:
self.storage.save_data() self.selected_index = max(0, len(notes) - 2)
self.selected_index = 0 self.scroll_offset = max(0, min(self.scroll_offset, self.selected_index))
self.scroll_offset = 0 self.show_message("Note deleted.")
self.show_message("Note deleted.")
elif self.current_view == "note_detail": elif self.current_view == "note_detail":
# Delete the currently viewed note # Delete the currently viewed note
@@ -2503,18 +2581,25 @@ class TUI:
preview = self.current_note.content[:50] + "..." if len(self.current_note.content) > 50 else self.current_note.content preview = self.current_note.content[:50] + "..." if len(self.current_note.content) > 50 else self.current_note.content
if self.dialog_confirm(f"Delete note: '{preview}'?"): if self.dialog_confirm(f"Delete note: '{preview}'?"):
# Find and delete the note from its parent (case or evidence) # Find and delete the note from its parent (case or evidence) using note_id
deleted = False deleted = False
note_id = self.current_note.note_id
# Check all cases and their evidence for this note # Check all cases and their evidence for this note
for case in self.cases: for case in self.cases:
if self.current_note in case.notes: for note in case.notes:
case.notes.remove(self.current_note) if note.note_id == note_id:
deleted = True case.notes.remove(note)
deleted = True
break
if deleted:
break break
for ev in case.evidence: for ev in case.evidence:
if self.current_note in ev.notes: for note in ev.notes:
ev.notes.remove(self.current_note) if note.note_id == note_id:
deleted = True ev.notes.remove(note)
deleted = True
break
if deleted:
break break
if deleted: if deleted:
break break
@@ -2539,17 +2624,24 @@ class TUI:
note_to_del = notes_to_show[self.selected_index] note_to_del = notes_to_show[self.selected_index]
preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content
if self.dialog_confirm(f"Delete note: '{preview}'?"): if self.dialog_confirm(f"Delete note: '{preview}'?"):
# Find and delete the note from its parent # Find and delete the note from its parent using note_id
deleted = False deleted = False
note_id = note_to_del.note_id
for case in self.cases: for case in self.cases:
if note_to_del in case.notes: for note in case.notes:
case.notes.remove(note_to_del) if note.note_id == note_id:
deleted = True case.notes.remove(note)
deleted = True
break
if deleted:
break break
for ev in case.evidence: for ev in case.evidence:
if note_to_del in ev.notes: for note in ev.notes:
ev.notes.remove(note_to_del) if note.note_id == note_id:
deleted = True ev.notes.remove(note)
deleted = True
break
if deleted:
break break
if deleted: if deleted:
break break
@@ -2557,9 +2649,9 @@ class TUI:
if deleted: if deleted:
self.storage.save_data() self.storage.save_data()
# Remove from tag_notes list as well # Remove from tag_notes list as well
self.tag_notes.remove(note_to_del) self.tag_notes = [n for n in self.tag_notes if n.note_id != note_id]
self.selected_index = min(self.selected_index, len(self.tag_notes) - 1) if self.tag_notes else 0 self.selected_index = min(self.selected_index, len(self.tag_notes) - 1) if self.tag_notes else 0
self.scroll_offset = 0 self.scroll_offset = max(0, min(self.scroll_offset, self.selected_index))
self.show_message("Note deleted.") self.show_message("Note deleted.")
else: else:
self.show_message("Error: Note not found.") self.show_message("Error: Note not found.")
@@ -2573,17 +2665,24 @@ class TUI:
note_to_del = notes_to_show[self.selected_index] note_to_del = notes_to_show[self.selected_index]
preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content
if self.dialog_confirm(f"Delete note: '{preview}'?"): if self.dialog_confirm(f"Delete note: '{preview}'?"):
# Find and delete the note from its parent # Find and delete the note from its parent using note_id
deleted = False deleted = False
note_id = note_to_del.note_id
for case in self.cases: for case in self.cases:
if note_to_del in case.notes: for note in case.notes:
case.notes.remove(note_to_del) if note.note_id == note_id:
deleted = True case.notes.remove(note)
deleted = True
break
if deleted:
break break
for ev in case.evidence: for ev in case.evidence:
if note_to_del in ev.notes: for note in ev.notes:
ev.notes.remove(note_to_del) if note.note_id == note_id:
deleted = True ev.notes.remove(note)
deleted = True
break
if deleted:
break break
if deleted: if deleted:
break break
@@ -2591,9 +2690,9 @@ class TUI:
if deleted: if deleted:
self.storage.save_data() self.storage.save_data()
# Remove from ioc_notes list as well # Remove from ioc_notes list as well
self.ioc_notes.remove(note_to_del) self.ioc_notes = [n for n in self.ioc_notes if n.note_id != note_id]
self.selected_index = min(self.selected_index, len(self.ioc_notes) - 1) if self.ioc_notes else 0 self.selected_index = min(self.selected_index, len(self.ioc_notes) - 1) if self.ioc_notes else 0
self.scroll_offset = 0 self.scroll_offset = max(0, min(self.scroll_offset, self.selected_index))
self.show_message("Note deleted.") self.show_message("Note deleted.")
else: else:
self.show_message("Error: Note not found.") self.show_message("Error: Note not found.")
@@ -3041,7 +3140,109 @@ def run_tui(open_active=False):
open_active: If True, navigate directly to the active case/evidence view open_active: If True, navigate directly to the active case/evidence view
""" """
def tui_wrapper(stdscr): def tui_wrapper(stdscr):
tui = TUI(stdscr) try:
tui = TUI(stdscr)
except RuntimeError as e:
# Handle corrupted JSON data
error_msg = str(e)
if "corrupted" in error_msg.lower():
# Show corruption dialog
stdscr.clear()
h, w = stdscr.getmaxyx()
# Display error message
lines = [
"╔══════════════════════════════════════════════════════════╗",
"║ DATA FILE CORRUPTION DETECTED ║",
"╚══════════════════════════════════════════════════════════╝",
"",
"Your data file appears to be corrupted.",
"",
] + error_msg.split('\n')[:5] + [
"",
"Options:",
" [1] Start fresh (backup already created)",
" [2] Exit and manually recover from backup",
"",
"Press 1 or 2 to continue..."
]
start_y = max(0, (h - len(lines)) // 2)
for i, line in enumerate(lines):
if start_y + i < h - 1:
display_line = line[:w-2] if len(line) > w-2 else line
try:
stdscr.addstr(start_y + i, 2, display_line)
except curses.error:
pass
stdscr.refresh()
# Wait for user choice
while True:
key = stdscr.getch()
if key == ord('1'):
# Start fresh - need to create storage with empty data
from .storage import Storage, DEFAULT_APP_DIR
storage = Storage.__new__(Storage)
storage.app_dir = DEFAULT_APP_DIR
storage.data_file = storage.app_dir / "data.json"
storage.lock_file = storage.app_dir / "app.lock"
storage.lock_manager = None
storage._ensure_app_dir()
# Acquire lock
from .storage import LockManager
storage.lock_manager = LockManager(storage.lock_file)
if not storage.lock_manager.acquire(timeout=5):
raise RuntimeError("Another instance is running")
storage.start_fresh()
# Create TUI with the fresh storage
tui = TUI.__new__(TUI)
tui.stdscr = stdscr
tui.storage = storage
tui.state_manager = StateManager()
tui.current_view = "case_list"
tui.selected_index = 0
tui.scroll_offset = 0
tui.cases = tui.storage.cases
tui.active_case = None
tui.active_evidence = None
tui.current_tags = []
tui.current_tag = None
tui.tag_notes = []
tui.current_note = None
tui.current_iocs = []
tui.current_ioc = None
tui.ioc_notes = []
tui.filter_mode = False
tui.filter_query = ""
tui.flash_message = "Started with fresh data. Backup of corrupted file was created."
tui.flash_time = time.time()
curses.curs_set(0)
curses.start_color()
if curses.has_colors():
curses.init_pair(1, curses.COLOR_BLACK, curses.COLOR_CYAN)
curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK)
curses.init_pair(3, curses.COLOR_YELLOW, curses.COLOR_BLACK)
curses.init_pair(4, curses.COLOR_RED, curses.COLOR_BLACK)
curses.init_pair(5, curses.COLOR_CYAN, curses.COLOR_BLACK)
curses.init_pair(6, curses.COLOR_WHITE, curses.COLOR_BLACK)
curses.init_pair(7, curses.COLOR_BLUE, curses.COLOR_BLACK)
curses.init_pair(8, curses.COLOR_MAGENTA, curses.COLOR_BLACK)
curses.init_pair(9, curses.COLOR_RED, curses.COLOR_CYAN)
curses.init_pair(10, curses.COLOR_YELLOW, curses.COLOR_CYAN)
tui.height, tui.width = stdscr.getmaxyx()
active_state = tui.state_manager.get_active()
tui.global_active_case_id = active_state.get("case_id")
tui.global_active_evidence_id = active_state.get("evidence_id")
break
elif key == ord('2') or key == ord('q'):
# Exit
return
else:
# Re-raise if it's a different RuntimeError
raise
# If requested, navigate to active case/evidence # If requested, navigate to active case/evidence
if open_active: if open_active: