Merge pull request #7 from overcuriousity/claude/debug-code-issues-01ANayVVF2LaNAabfcL6G49y

Fix critical bugs and improve data integrity across codebase
This commit is contained in:
overcuriousity
2025-12-13 18:18:44 +01:00
committed by GitHub
6 changed files with 547 additions and 178 deletions

View File

@@ -8,6 +8,12 @@ from .crypto import Crypto
def quick_add_note(content: str):
storage = Storage()
state_manager = StateManager()
# Validate and clear stale state
warning = state_manager.validate_and_clear_stale(storage)
if warning:
print(f"Warning: {warning}", file=sys.stderr)
state = state_manager.get_active()
settings = state_manager.get_settings()
@@ -15,23 +21,28 @@ def quick_add_note(content: str):
evidence_id = state.get("evidence_id")
if not case_id:
print("Error: No active case set. Open the TUI to select a case first.")
print("Error: No active case set. Open the TUI to select a case first.", file=sys.stderr)
sys.exit(1)
case = storage.get_case(case_id)
if not case:
print("Error: Active case not found in storage. Ensure you have set an active case in the TUI.")
print("Error: Active case not found in storage. Ensure you have set an active case in the TUI.", file=sys.stderr)
sys.exit(1)
target_evidence = None
if evidence_id:
# Find evidence
# Find and validate evidence belongs to active case
for ev in case.evidence:
if ev.evidence_id == evidence_id:
target_evidence = ev
break
if not target_evidence:
# Evidence ID is set but doesn't exist in case - clear it
print(f"Warning: Active evidence not found in case. Clearing to case level.", file=sys.stderr)
state_manager.set_active(case_id, None)
# Create note
note = Note(content=content)
note.calculate_hash()
@@ -47,9 +58,9 @@ def quick_add_note(content: str):
if signature:
note.signature = signature
else:
print("Warning: GPG signature failed (GPG not found or no key). Note saved without signature.")
print("Warning: GPG signature failed (GPG not found or no key). Note saved without signature.", file=sys.stderr)
else:
print("Warning: No GPG key ID configured. Note saved without signature.")
print("Warning: No GPG key ID configured. Note saved without signature.", file=sys.stderr)
# Attach to evidence or case
if target_evidence:
@@ -117,7 +128,10 @@ def export_markdown(output_file: str = "export.md"):
def write_note(f, note: Note):
f.write(f"- **{time.ctime(note.timestamp)}**\n")
f.write(f" - Content: {note.content}\n")
f.write(f" - Content:\n")
# Properly indent multi-line content
for line in note.content.splitlines():
f.write(f" {line}\n")
f.write(f" - Hash: `{note.content_hash}`\n")
if note.signature:
f.write(" - **Signature Verified:**\n")

View File

@@ -37,7 +37,7 @@ class Crypto:
elif fields[0] == 'uid' and current_key_id:
user_id = fields[9] if len(fields) > 9 else "Unknown"
keys.append((current_key_id, user_id))
current_key_id = None # Reset after matching
# Don't reset current_key_id - allow multiple UIDs per key
return keys

View File

@@ -32,64 +32,67 @@ class Note:
def extract_iocs(self):
"""Extract Indicators of Compromise from content"""
seen = set()
covered_ranges = set()
self.iocs = []
# IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.findall(ipv4_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
def add_ioc_if_not_covered(match_obj):
"""Add IOC if its range doesn't overlap with already covered ranges"""
start, end = match_obj.start(), match_obj.end()
# Check if this range overlaps with any covered range
for covered_start, covered_end in covered_ranges:
if not (end <= covered_start or start >= covered_end):
return False # Overlaps, don't add
text = match_obj.group()
if text not in seen:
seen.add(text)
covered_ranges.add((start, end))
self.iocs.append(text)
return True
return False
# IPv6 addresses (supports compressed format)
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b'
for match in re.findall(ipv6_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
# URLs (check before domains to prevent double-matching)
url_pattern = r'https?://[^\s]+'
for match in re.findall(url_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.findall(domain_pattern, self.content):
# Filter out common false positives
if match not in seen and not match.startswith('example.'):
seen.add(match)
self.iocs.append(match)
# SHA256 hashes (64 hex chars) - check longest first
# Process in order of priority to avoid false positives
# SHA256 hashes (64 hex chars) - check longest first to avoid substring matches
sha256_pattern = r'\b[a-fA-F0-9]{64}\b'
for match in re.findall(sha256_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
for match in re.finditer(sha256_pattern, self.content):
add_ioc_if_not_covered(match)
# SHA1 hashes (40 hex chars)
sha1_pattern = r'\b[a-fA-F0-9]{40}\b'
for match in re.findall(sha1_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
for match in re.finditer(sha1_pattern, self.content):
add_ioc_if_not_covered(match)
# MD5 hashes (32 hex chars)
md5_pattern = r'\b[a-fA-F0-9]{32}\b'
for match in re.findall(md5_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
for match in re.finditer(md5_pattern, self.content):
add_ioc_if_not_covered(match)
# IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.finditer(ipv4_pattern, self.content):
add_ioc_if_not_covered(match)
# IPv6 addresses (supports compressed format)
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b'
for match in re.finditer(ipv6_pattern, self.content):
add_ioc_if_not_covered(match)
# URLs (check before domains to prevent double-matching)
# Fix: exclude trailing punctuation
url_pattern = r'https?://[^\s<>\"\']+(?<![.,;:!?\)\]\}])'
for match in re.finditer(url_pattern, self.content):
add_ioc_if_not_covered(match)
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.finditer(domain_pattern, self.content):
# Filter out common false positives
if not match.group().startswith('example.'):
add_ioc_if_not_covered(match)
# Email addresses
email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
for match in re.findall(email_pattern, self.content):
if match not in seen:
seen.add(match)
self.iocs.append(match)
for match in re.finditer(email_pattern, self.content):
add_ioc_if_not_covered(match)
def calculate_hash(self):
# We hash the content + timestamp to ensure integrity of 'when' it was said
@@ -101,63 +104,66 @@ class Note:
"""Extract IOCs from text and return as list of (ioc, type) tuples"""
iocs = []
seen = set()
covered_ranges = set()
# IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.findall(ipv4_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'ipv4'))
def add_ioc_if_not_covered(match_obj, ioc_type):
"""Add IOC if its range doesn't overlap with already covered ranges"""
start, end = match_obj.start(), match_obj.end()
# Check if this range overlaps with any covered range
for covered_start, covered_end in covered_ranges:
if not (end <= covered_start or start >= covered_end):
return False # Overlaps, don't add
ioc_text = match_obj.group()
if ioc_text not in seen:
seen.add(ioc_text)
covered_ranges.add((start, end))
iocs.append((ioc_text, ioc_type))
return True
return False
# IPv6 addresses (supports compressed format)
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b'
for match in re.findall(ipv6_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'ipv6'))
# URLs (check before domains to avoid double-matching)
url_pattern = r'https?://[^\s]+'
for match in re.findall(url_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'url'))
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.findall(domain_pattern, text):
# Filter out common false positives and already seen URLs
if match not in seen and not match.startswith('example.'):
seen.add(match)
iocs.append((match, 'domain'))
# SHA256 hashes (64 hex chars) - check before SHA1 and MD5
# Process in priority order: longest hashes first
# SHA256 hashes (64 hex chars)
sha256_pattern = r'\b[a-fA-F0-9]{64}\b'
for match in re.findall(sha256_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'sha256'))
for match in re.finditer(sha256_pattern, text):
add_ioc_if_not_covered(match, 'sha256')
# SHA1 hashes (40 hex chars) - check before MD5
# SHA1 hashes (40 hex chars)
sha1_pattern = r'\b[a-fA-F0-9]{40}\b'
for match in re.findall(sha1_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'sha1'))
for match in re.finditer(sha1_pattern, text):
add_ioc_if_not_covered(match, 'sha1')
# MD5 hashes (32 hex chars)
md5_pattern = r'\b[a-fA-F0-9]{32}\b'
for match in re.findall(md5_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'md5'))
for match in re.finditer(md5_pattern, text):
add_ioc_if_not_covered(match, 'md5')
# IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.finditer(ipv4_pattern, text):
add_ioc_if_not_covered(match, 'ipv4')
# IPv6 addresses (supports compressed format)
ipv6_pattern = r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b|\b(?:[0-9a-fA-F]{1,4}:)*::(?:[0-9a-fA-F]{1,4}:)*[0-9a-fA-F]{0,4}\b'
for match in re.finditer(ipv6_pattern, text):
add_ioc_if_not_covered(match, 'ipv6')
# URLs (check before domains to avoid double-matching)
# Fix: exclude trailing punctuation
url_pattern = r'https?://[^\s<>\"\']+(?<![.,;:!?\)\]\}])'
for match in re.finditer(url_pattern, text):
add_ioc_if_not_covered(match, 'url')
# Domain names (basic pattern)
domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
for match in re.finditer(domain_pattern, text):
# Filter out common false positives
if not match.group().startswith('example.'):
add_ioc_if_not_covered(match, 'domain')
# Email addresses
email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
for match in re.findall(email_pattern, text):
if match not in seen:
seen.add(match)
iocs.append((match, 'email'))
for match in re.finditer(email_pattern, text):
add_ioc_if_not_covered(match, 'email')
return iocs
@@ -182,6 +188,19 @@ class Note:
highlights.append((match.group(), start, end, ioc_type))
covered_ranges.add((start, end))
# Process in priority order: longest hashes first to avoid substring matches
# SHA256 hashes (64 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{64}\b', text):
add_highlight(match, 'sha256')
# SHA1 hashes (40 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{40}\b', text):
add_highlight(match, 'sha1')
# MD5 hashes (32 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{32}\b', text):
add_highlight(match, 'md5')
# IPv4 addresses
ipv4_pattern = r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
for match in re.finditer(ipv4_pattern, text):
@@ -193,7 +212,8 @@ class Note:
add_highlight(match, 'ipv6')
# URLs (check before domains to prevent double-matching)
for match in re.finditer(r'https?://[^\s]+', text):
# Fix: exclude trailing punctuation
for match in re.finditer(r'https?://[^\s<>\"\']+(?<![.,;:!?\)\]\}])', text):
add_highlight(match, 'url')
# Domain names
@@ -201,18 +221,6 @@ class Note:
if not match.group().startswith('example.'):
add_highlight(match, 'domain')
# SHA256 hashes (64 hex chars) - check longest first
for match in re.finditer(r'\b[a-fA-F0-9]{64}\b', text):
add_highlight(match, 'sha256')
# SHA1 hashes (40 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{40}\b', text):
add_highlight(match, 'sha1')
# MD5 hashes (32 hex chars)
for match in re.finditer(r'\b[a-fA-F0-9]{32}\b', text):
add_highlight(match, 'md5')
# Email addresses
for match in re.finditer(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text):
add_highlight(match, 'email')

View File

@@ -1,22 +1,121 @@
import json
import time
import os
import sys
from pathlib import Path
from typing import List, Optional, Tuple
from .models import Case, Evidence, Note
DEFAULT_APP_DIR = Path.home() / ".trace"
class LockManager:
    """Cross-platform advisory file lock to prevent concurrent access.

    The lock is a file created with O_CREAT|O_EXCL containing the owner's
    PID.  If that PID no longer refers to a live process, the lock is
    treated as stale and reclaimed.
    """

    def __init__(self, lock_file: Path):
        self.lock_file = lock_file
        self.acquired = False

    def acquire(self, timeout: int = 5) -> bool:
        """Try to acquire the lock, retrying for up to *timeout* seconds.

        Returns True on success, False if a live process held the lock for
        the whole timeout window.
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                # O_CREAT | O_EXCL is atomic: creation fails if the file
                # already exists, which is what makes this usable as a lock.
                fd = os.open(str(self.lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
                os.write(fd, str(os.getpid()).encode())
                os.close(fd)
                self.acquired = True
                return True
            except FileExistsError:
                # Lock is held.  If its owner is dead, reclaim it and retry
                # immediately; otherwise back off briefly.
                if self._is_stale_lock():
                    try:
                        self.lock_file.unlink()
                    except FileNotFoundError:
                        pass  # Another process reclaimed it first.
                    continue
                time.sleep(0.1)
            except OSError:
                # Fix: was a bare `except Exception`, which also swallowed
                # programming errors.  Only filesystem errors are transient
                # and worth retrying.
                time.sleep(0.1)
        return False

    def _is_stale_lock(self) -> bool:
        """Return True if the PID recorded in the lock file is no longer alive."""
        try:
            if not self.lock_file.exists():
                return False
            with open(self.lock_file, 'r') as f:
                pid = int(f.read().strip())
            if sys.platform == 'win32':
                import ctypes
                kernel32 = ctypes.windll.kernel32
                PROCESS_QUERY_INFORMATION = 0x0400
                handle = kernel32.OpenProcess(PROCESS_QUERY_INFORMATION, 0, pid)
                if handle:
                    kernel32.CloseHandle(handle)
                    return False  # Process exists -> lock is live.
                # NOTE(review): OpenProcess can also fail with access-denied
                # for a live process owned by another user; treated as stale
                # here — confirm acceptable on Windows deployments.
                return True
            else:
                # Unix: signal 0 probes for existence without delivering
                # anything to the target process.
                try:
                    os.kill(pid, 0)
                    return False  # Process exists.
                except ProcessLookupError:
                    return True  # No such process -> stale.
                except PermissionError:
                    # Fix: EPERM means the process EXISTS but belongs to
                    # another user.  The original caught bare OSError and
                    # would have deleted a live lock here.
                    return False
        except (ValueError, FileNotFoundError, PermissionError):
            # Unreadable or garbage lock file: treat as stale so it can
            # be reclaimed rather than deadlocking forever.
            return True

    def release(self):
        """Release the lock if we hold it (idempotent)."""
        if self.acquired:
            try:
                self.lock_file.unlink()
            except FileNotFoundError:
                pass
            self.acquired = False

    def __enter__(self):
        if not self.acquire():
            raise RuntimeError("Could not acquire lock: another instance is running")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
class Storage:
def __init__(self, app_dir: Path = DEFAULT_APP_DIR):
def __init__(self, app_dir: Path = DEFAULT_APP_DIR, acquire_lock: bool = True):
self.app_dir = app_dir
self.data_file = self.app_dir / "data.json"
self.lock_file = self.app_dir / "app.lock"
self.lock_manager = None
self._ensure_app_dir()
# Acquire lock to prevent concurrent access
if acquire_lock:
self.lock_manager = LockManager(self.lock_file)
if not self.lock_manager.acquire(timeout=5):
raise RuntimeError("Another instance of trace is already running. Please close it first.")
self.cases: List[Case] = self._load_data()
# Create demo case on first launch
if not self.cases:
# Create demo case on first launch (only if data loaded successfully and is empty)
if not self.cases and self.data_file.exists():
# File exists but is empty - could be first run after successful load
pass
elif not self.cases and not self.data_file.exists():
# No file exists - first run
self._create_demo_case()
def __del__(self):
"""Release lock when Storage object is destroyed"""
if self.lock_manager:
self.lock_manager.release()
def _ensure_app_dir(self):
if not self.app_dir.exists():
self.app_dir.mkdir(parents=True, exist_ok=True)
@@ -169,8 +268,23 @@ Attachment: invoice.pdf.exe (double extension trick) #email-forensics #phishing-
with open(self.data_file, 'r', encoding='utf-8') as f:
data = json.load(f)
return [Case.from_dict(c) for c in data]
except (json.JSONDecodeError, IOError):
return []
except (json.JSONDecodeError, IOError, KeyError, ValueError) as e:
# Corrupted JSON - create backup and raise exception
import shutil
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = self.app_dir / f"data.json.corrupted.{timestamp}"
try:
shutil.copy2(self.data_file, backup_file)
except Exception:
pass
# Raise exception with information about backup
raise RuntimeError(f"Data file is corrupted. Backup saved to: {backup_file}\nError: {e}")
def start_fresh(self):
"""Start with fresh data (for corrupted JSON recovery).

Discards all in-memory cases and reseeds with the demo case.  The
corrupted data.json on disk is left untouched (a timestamped backup
was already created by _load_data before it raised); the file is only
rewritten on the next save_data() call.
"""
self.cases = []  # drop whatever partial state was loaded
self._create_demo_case()  # NOTE(review): presumably also persists the demo case — confirm against helper
def save_data(self):
data = [c.to_dict() for c in self.cases]
@@ -238,6 +352,37 @@ class StateManager:
except (json.JSONDecodeError, IOError):
return {"case_id": None, "evidence_id": None}
def validate_and_clear_stale(self, storage: 'Storage') -> str:
"""Validate active state against storage and clear stale references.
Returns warning message if state was cleared, empty string otherwise."""
state = self.get_active()
case_id = state.get("case_id")
evidence_id = state.get("evidence_id")
warning = ""
if case_id:
case = storage.get_case(case_id)
if not case:
# Active case was deleted out from under us: clear both the case
# and evidence pointers (evidence cannot be valid without its case).
warning = f"Active case (ID: {case_id[:8]}...) no longer exists. Clearing active context."
self.set_active(None, None)
return warning
# Case is valid; independently validate the evidence pointer if set.
if evidence_id:
_, evidence = storage.find_evidence(evidence_id)
if not evidence:
# Evidence was deleted: fall back to case-level context.
warning = f"Active evidence (ID: {evidence_id[:8]}...) no longer exists. Clearing to case level."
self.set_active(case_id, None)
return warning
elif evidence_id:
# Evidence set without a case should be impossible; repair by
# clearing the whole active context.
warning = "Invalid state: evidence set without case. Clearing active context."
self.set_active(None, None)
return warning
return warning
def get_settings(self) -> dict:
if not self.settings_file.exists():
return {"pgp_enabled": True}

View File

@@ -21,7 +21,8 @@ class TestModels(unittest.TestCase):
class TestStorage(unittest.TestCase):
def setUp(self):
self.test_dir = Path(tempfile.mkdtemp())
self.storage = Storage(app_dir=self.test_dir)
# Disable lock for tests to allow multiple Storage instances
self.storage = Storage(app_dir=self.test_dir, acquire_lock=False)
def tearDown(self):
shutil.rmtree(self.test_dir)
@@ -31,7 +32,7 @@ class TestStorage(unittest.TestCase):
self.storage.add_case(case)
# Reload storage from same dir
new_storage = Storage(app_dir=self.test_dir)
new_storage = Storage(app_dir=self.test_dir, acquire_lock=False)
loaded_case = new_storage.get_case(case.case_id)
self.assertIsNotNone(loaded_case)

View File

@@ -64,7 +64,12 @@ class TUI:
self.height, self.width = stdscr.getmaxyx()
# Load initial active state
# Load initial active state and validate
warning = self.state_manager.validate_and_clear_stale(self.storage)
if warning:
self.flash_message = warning
self.flash_time = time.time()
active_state = self.state_manager.get_active()
self.global_active_case_id = active_state.get("case_id")
self.global_active_evidence_id = active_state.get("evidence_id")
@@ -445,8 +450,25 @@ class TUI:
def _update_scroll(self, total_items):
# Viewport height calculation (approximate lines available for list)
# Protect against negative or zero content_h
if self.content_h < 3:
# Terminal too small, use minimum viable height
list_h = 1
else:
list_h = self.content_h - 2 # Title + padding
if list_h < 1: list_h = 1
if list_h < 1:
list_h = 1
# Ensure selected index is within bounds
if total_items == 0:
self.selected_index = 0
self.scroll_offset = 0
else:
# Clamp selected_index to valid range
if self.selected_index >= total_items:
self.selected_index = max(0, total_items - 1)
if self.selected_index < 0:
self.selected_index = 0
# Ensure selected index is visible
if self.selected_index < self.scroll_offset:
@@ -454,6 +476,10 @@ class TUI:
elif self.selected_index >= self.scroll_offset + list_h:
self.scroll_offset = self.selected_index - list_h + 1
# Ensure scroll_offset is within bounds
if self.scroll_offset < 0:
self.scroll_offset = 0
return list_h
def _get_filtered_list(self, items, key_attr=None, key_attr2=None):
@@ -819,21 +845,23 @@ class TUI:
self.stdscr.addstr(current_y, 2, f"Notes ({len(notes)}):", curses.A_UNDERLINE)
current_y += 1
# Just show last N notes that fit
# Calculate available height for notes list
list_h = self.content_h - (current_y - 2) # Adjust for dynamic header
if list_h < 1:
list_h = 1
start_y = current_y
display_notes = notes[-list_h:] if len(notes) > list_h else notes
# Update scroll to keep selection visible (use full notes list)
if notes:
self._update_scroll(len(notes))
# Update scroll for note selection
if display_notes:
self._update_scroll(len(display_notes))
for i, note in enumerate(display_notes):
# Display notes with proper scrolling
for i in range(list_h):
idx = self.scroll_offset + i
if idx >= len(display_notes):
if idx >= len(notes):
break
note = display_notes[idx]
note = notes[idx]
# Replace newlines with spaces for single-line display
note_content = note.content.replace('\n', ' ').replace('\r', ' ')
display_str = f"- {note_content}"
@@ -1582,22 +1610,75 @@ class TUI:
self.filter_query = ""
self.selected_index = 0
self.scroll_offset = 0
# Validate selected_index against the unfiltered list
self._validate_selection_bounds()
return True
elif key == curses.KEY_ENTER or key in [10, 13]:
self.filter_mode = False
self.selected_index = 0
# Validate selected_index against the filtered list
self._validate_selection_bounds()
self.scroll_offset = 0
return True
elif key == curses.KEY_BACKSPACE or key == 127:
if len(self.filter_query) > 0:
self.filter_query = self.filter_query[:-1]
self.selected_index = 0
self.scroll_offset = 0
elif 32 <= key <= 126:
self.filter_query += chr(key)
self.selected_index = 0
self.scroll_offset = 0
return True
def _validate_selection_bounds(self):
"""Validate and fix selected_index and scroll_offset to ensure they're within bounds"""
# Compute the highest selectable index for the CURRENT view, applying
# the same filter that the corresponding render path applies so the
# clamp matches what the user actually sees on screen.
max_idx = 0
if self.current_view == "case_list":
filtered = self._get_filtered_list(self.cases, "case_number", "name")
max_idx = len(filtered) - 1
elif self.current_view == "case_detail" and self.active_case:
# Case detail selects across filtered evidence plus case-level notes
# combined, hence the summed length.
case_notes = self.active_case.notes
filtered = self._get_filtered_list(self.active_case.evidence, "name", "description")
max_idx = len(filtered) + len(case_notes) - 1
elif self.current_view == "evidence_detail" and self.active_evidence:
notes = self._get_filtered_list(self.active_evidence.notes, "content") if self.filter_query else self.active_evidence.notes
max_idx = len(notes) - 1
elif self.current_view == "tags_list":
tags_to_show = self.current_tags
if self.filter_query:
q = self.filter_query.lower()
tags_to_show = [(tag, count) for tag, count in self.current_tags if q in tag.lower()]
max_idx = len(tags_to_show) - 1
elif self.current_view == "tag_notes_list":
notes_to_show = self._get_filtered_list(self.tag_notes, "content") if self.filter_query else self.tag_notes
max_idx = len(notes_to_show) - 1
elif self.current_view == "ioc_list":
iocs_to_show = self.current_iocs
if self.filter_query:
q = self.filter_query.lower()
# IOC filter matches either the indicator text or its type label.
iocs_to_show = [(ioc, count, ioc_type) for ioc, count, ioc_type in self.current_iocs
if q in ioc.lower() or q in ioc_type.lower()]
max_idx = len(iocs_to_show) - 1
elif self.current_view == "ioc_notes_list":
notes_to_show = self._get_filtered_list(self.ioc_notes, "content") if self.filter_query else self.ioc_notes
max_idx = len(notes_to_show) - 1
# An empty list leaves max_idx at -1 above; clamp to 0 so the cursor
# rests on row 0 rather than going negative.
max_idx = max(0, max_idx)
# Clamp the cursor into [0, max_idx].
if self.selected_index > max_idx:
self.selected_index = max_idx
if self.selected_index < 0:
self.selected_index = 0
# The scroll offset must never be past the cursor, nor negative.
if self.scroll_offset > self.selected_index:
self.scroll_offset = max(0, self.selected_index)
if self.scroll_offset < 0:
self.scroll_offset = 0
def _handle_set_active(self):
if self.current_view == "case_list":
filtered = self._get_filtered_list(self.cases, "case_number", "name")
@@ -2477,23 +2558,20 @@ class TUI:
self.show_message("No notes to delete.")
return
# Calculate which note to delete based on display (showing last N filtered notes)
# Get filtered notes (or all notes if no filter)
notes = self._get_filtered_list(self.active_evidence.notes, "content") if self.filter_query else self.active_evidence.notes
list_h = self.content_h - 5 # Adjust for header
display_notes = notes[-list_h:] if len(notes) > list_h else notes
if display_notes:
# User selection is in context of displayed notes
# We need to delete from the full list
if self.selected_index < len(display_notes):
note_to_del = display_notes[self.selected_index]
if notes and self.selected_index < len(notes):
note_to_del = notes[self.selected_index]
# Show preview of note content in confirmation
preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content
if self.dialog_confirm(f"Delete note: '{preview}'?"):
self.active_evidence.notes.remove(note_to_del)
self.storage.save_data()
self.selected_index = 0
self.scroll_offset = 0
# Adjust selected index if needed
if self.selected_index >= len(notes) - 1:
self.selected_index = max(0, len(notes) - 2)
self.scroll_offset = max(0, min(self.scroll_offset, self.selected_index))
self.show_message("Note deleted.")
elif self.current_view == "note_detail":
@@ -2503,21 +2581,28 @@ class TUI:
preview = self.current_note.content[:50] + "..." if len(self.current_note.content) > 50 else self.current_note.content
if self.dialog_confirm(f"Delete note: '{preview}'?"):
# Find and delete the note from its parent (case or evidence)
# Find and delete the note from its parent (case or evidence) using note_id
deleted = False
note_id = self.current_note.note_id
# Check all cases and their evidence for this note
for case in self.cases:
if self.current_note in case.notes:
case.notes.remove(self.current_note)
for note in case.notes:
if note.note_id == note_id:
case.notes.remove(note)
deleted = True
break
if deleted:
break
for ev in case.evidence:
if self.current_note in ev.notes:
ev.notes.remove(self.current_note)
for note in ev.notes:
if note.note_id == note_id:
ev.notes.remove(note)
deleted = True
break
if deleted:
break
if deleted:
break
if deleted:
self.storage.save_data()
@@ -2539,27 +2624,34 @@ class TUI:
note_to_del = notes_to_show[self.selected_index]
preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content
if self.dialog_confirm(f"Delete note: '{preview}'?"):
# Find and delete the note from its parent
# Find and delete the note from its parent using note_id
deleted = False
note_id = note_to_del.note_id
for case in self.cases:
if note_to_del in case.notes:
case.notes.remove(note_to_del)
for note in case.notes:
if note.note_id == note_id:
case.notes.remove(note)
deleted = True
break
if deleted:
break
for ev in case.evidence:
if note_to_del in ev.notes:
ev.notes.remove(note_to_del)
for note in ev.notes:
if note.note_id == note_id:
ev.notes.remove(note)
deleted = True
break
if deleted:
break
if deleted:
break
if deleted:
self.storage.save_data()
# Remove from tag_notes list as well
self.tag_notes.remove(note_to_del)
self.tag_notes = [n for n in self.tag_notes if n.note_id != note_id]
self.selected_index = min(self.selected_index, len(self.tag_notes) - 1) if self.tag_notes else 0
self.scroll_offset = 0
self.scroll_offset = max(0, min(self.scroll_offset, self.selected_index))
self.show_message("Note deleted.")
else:
self.show_message("Error: Note not found.")
@@ -2573,27 +2665,34 @@ class TUI:
note_to_del = notes_to_show[self.selected_index]
preview = note_to_del.content[:50] + "..." if len(note_to_del.content) > 50 else note_to_del.content
if self.dialog_confirm(f"Delete note: '{preview}'?"):
# Find and delete the note from its parent
# Find and delete the note from its parent using note_id
deleted = False
note_id = note_to_del.note_id
for case in self.cases:
if note_to_del in case.notes:
case.notes.remove(note_to_del)
for note in case.notes:
if note.note_id == note_id:
case.notes.remove(note)
deleted = True
break
if deleted:
break
for ev in case.evidence:
if note_to_del in ev.notes:
ev.notes.remove(note_to_del)
for note in ev.notes:
if note.note_id == note_id:
ev.notes.remove(note)
deleted = True
break
if deleted:
break
if deleted:
break
if deleted:
self.storage.save_data()
# Remove from ioc_notes list as well
self.ioc_notes.remove(note_to_del)
self.ioc_notes = [n for n in self.ioc_notes if n.note_id != note_id]
self.selected_index = min(self.selected_index, len(self.ioc_notes) - 1) if self.ioc_notes else 0
self.scroll_offset = 0
self.scroll_offset = max(0, min(self.scroll_offset, self.selected_index))
self.show_message("Note deleted.")
else:
self.show_message("Error: Note not found.")
@@ -3041,7 +3140,109 @@ def run_tui(open_active=False):
open_active: If True, navigate directly to the active case/evidence view
"""
def tui_wrapper(stdscr):
try:
tui = TUI(stdscr)
except RuntimeError as e:
# Handle corrupted JSON data
error_msg = str(e)
if "corrupted" in error_msg.lower():
# Show corruption dialog
stdscr.clear()
h, w = stdscr.getmaxyx()
# Display error message
lines = [
"╔══════════════════════════════════════════════════════════╗",
"║ DATA FILE CORRUPTION DETECTED ║",
"╚══════════════════════════════════════════════════════════╝",
"",
"Your data file appears to be corrupted.",
"",
] + error_msg.split('\n')[:5] + [
"",
"Options:",
" [1] Start fresh (backup already created)",
" [2] Exit and manually recover from backup",
"",
"Press 1 or 2 to continue..."
]
start_y = max(0, (h - len(lines)) // 2)
for i, line in enumerate(lines):
if start_y + i < h - 1:
display_line = line[:w-2] if len(line) > w-2 else line
try:
stdscr.addstr(start_y + i, 2, display_line)
except curses.error:
pass
stdscr.refresh()
# Wait for user choice
while True:
key = stdscr.getch()
if key == ord('1'):
# Start fresh - need to create storage with empty data
from .storage import Storage, DEFAULT_APP_DIR
storage = Storage.__new__(Storage)
storage.app_dir = DEFAULT_APP_DIR
storage.data_file = storage.app_dir / "data.json"
storage.lock_file = storage.app_dir / "app.lock"
storage.lock_manager = None
storage._ensure_app_dir()
# Acquire lock
from .storage import LockManager
storage.lock_manager = LockManager(storage.lock_file)
if not storage.lock_manager.acquire(timeout=5):
raise RuntimeError("Another instance is running")
storage.start_fresh()
# Create TUI with the fresh storage
tui = TUI.__new__(TUI)
tui.stdscr = stdscr
tui.storage = storage
tui.state_manager = StateManager()
tui.current_view = "case_list"
tui.selected_index = 0
tui.scroll_offset = 0
tui.cases = tui.storage.cases
tui.active_case = None
tui.active_evidence = None
tui.current_tags = []
tui.current_tag = None
tui.tag_notes = []
tui.current_note = None
tui.current_iocs = []
tui.current_ioc = None
tui.ioc_notes = []
tui.filter_mode = False
tui.filter_query = ""
tui.flash_message = "Started with fresh data. Backup of corrupted file was created."
tui.flash_time = time.time()
curses.curs_set(0)
curses.start_color()
if curses.has_colors():
curses.init_pair(1, curses.COLOR_BLACK, curses.COLOR_CYAN)
curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK)
curses.init_pair(3, curses.COLOR_YELLOW, curses.COLOR_BLACK)
curses.init_pair(4, curses.COLOR_RED, curses.COLOR_BLACK)
curses.init_pair(5, curses.COLOR_CYAN, curses.COLOR_BLACK)
curses.init_pair(6, curses.COLOR_WHITE, curses.COLOR_BLACK)
curses.init_pair(7, curses.COLOR_BLUE, curses.COLOR_BLACK)
curses.init_pair(8, curses.COLOR_MAGENTA, curses.COLOR_BLACK)
curses.init_pair(9, curses.COLOR_RED, curses.COLOR_CYAN)
curses.init_pair(10, curses.COLOR_YELLOW, curses.COLOR_CYAN)
tui.height, tui.width = stdscr.getmaxyx()
active_state = tui.state_manager.get_active()
tui.global_active_case_id = active_state.get("case_id")
tui.global_active_evidence_id = active_state.get("evidence_id")
break
elif key == ord('2') or key == ord('q'):
# Exit
return
else:
# Re-raise if it's a different RuntimeError
raise
# If requested, navigate to active case/evidence
if open_active: