# forensic-trails/tests/test_database.py
# overcuriousity 86359ec850 updates
# 2025-10-08 21:49:39 +02:00
#
# 384 lines
# 14 KiB
# Python
#
"""Unit tests for the database module."""
import unittest
import sqlite3
import tempfile
import os
from pathlib import Path
from unittest.mock import patch, MagicMock
# Add the src directory to the path so we can import forensictrails
import sys
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
from forensictrails.db.database import (
get_db_connection,
validate_database_schema,
create_fresh_database,
initialize_database,
show_db_schema
)
class TestGetDbConnection(unittest.TestCase):
    """Test cases for get_db_connection function."""

    def setUp(self):
        """Give each test a private temporary directory and database path.

        The directory is removed by a registered cleanup instead of
        ``tearDown``. Cleanups run even when the test (or the rest of
        ``setUp``) fails. ``TemporaryDirectory.cleanup`` also removes any
        extra files SQLite may leave beside the database, which a bare
        ``os.rmdir`` would reject.
        """
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.temp_dir = self._tmpdir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')

    def test_get_db_connection_creates_connection(self):
        """Test that get_db_connection creates a valid connection."""
        conn = get_db_connection(self.test_db_path)
        # Close through a cleanup so a failing assertion cannot leak the
        # handle. Cleanups run LIFO, so this happens before the temporary
        # directory is deleted.
        self.addCleanup(conn.close)
        self.assertIsInstance(conn, sqlite3.Connection)
        self.assertIs(conn.row_factory, sqlite3.Row)

    def test_get_db_connection_creates_file(self):
        """Test that get_db_connection creates database file if it doesn't exist."""
        self.assertFalse(os.path.exists(self.test_db_path))
        conn = get_db_connection(self.test_db_path)
        conn.close()
        self.assertTrue(os.path.exists(self.test_db_path))
class TestValidateDatabaseSchema(unittest.TestCase):
    """Test cases for validate_database_schema function."""

    def setUp(self):
        """Give each test a private temporary directory and database path.

        Cleanup is registered with ``addCleanup`` so the directory, and any
        extra files SQLite leaves in it, is removed even if the test fails.
        """
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.temp_dir = self._tmpdir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    def test_validate_empty_database_returns_false(self):
        """Test that an empty database is invalid."""
        conn = sqlite3.connect(self.test_db_path)
        conn.close()
        result = validate_database_schema(self.test_db_path)
        self.assertFalse(result)

    def test_validate_incomplete_database_returns_false(self):
        """Test that a database with missing tables is invalid."""
        conn = sqlite3.connect(self.test_db_path)
        try:
            # Create only some of the required tables.
            conn.execute("""
                CREATE TABLE cases (
                    case_id TEXT PRIMARY KEY,
                    case_title TEXT NOT NULL
                )
            """)
            conn.execute("""
                CREATE TABLE notes (
                    note_id TEXT PRIMARY KEY,
                    case_id TEXT NOT NULL
                )
            """)
            conn.commit()
        finally:
            # Close even if the DDL fails, so the file is not held open.
            conn.close()
        result = validate_database_schema(self.test_db_path)
        self.assertFalse(result)

    def test_validate_complete_database_returns_true(self):
        """Test that a database with all required tables is valid."""
        # Create database with full schema.
        create_fresh_database(self.test_db_path, self.schema_path)
        result = validate_database_schema(self.test_db_path)
        self.assertTrue(result)

    def test_validate_nonexistent_database_returns_false(self):
        """Test that validation of non-existent database returns False."""
        # Use a file inside a missing subdirectory of our private temp dir.
        # It is guaranteed absent on every platform, and the test touches
        # nothing outside its sandbox (unlike a hard-coded absolute path).
        missing_path = os.path.join(self.temp_dir, 'missing', 'test.db')
        result = validate_database_schema(missing_path)
        self.assertFalse(result)
class TestCreateFreshDatabase(unittest.TestCase):
    """Test cases for create_fresh_database function."""

    # Tables the bundled schema.sql is expected to define.
    REQUIRED_TABLES = frozenset({
        'cases', 'notes', 'evidence', 'chain_of_custody',
        'attachments', 'question_entries', 'users', 'tasks',
    })

    def setUp(self):
        """Give each test a private temporary directory and database path.

        Cleanup is registered with ``addCleanup`` so the directory, and any
        extra files SQLite leaves in it, is removed even if the test fails.
        """
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.temp_dir = self._tmpdir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    def _table_names(self):
        """Return the set of user-defined table names in the test database.

        SQLite reserves the ``sqlite_`` prefix for its own bookkeeping tables
        (e.g. ``sqlite_sequence``, which appears when a table uses
        AUTOINCREMENT). Those are excluded so the comparison covers only
        tables the schema defines.
        """
        conn = sqlite3.connect(self.test_db_path)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()
        finally:
            conn.close()
        return {name for (name,) in rows if not name.startswith('sqlite_')}

    def test_create_fresh_database_creates_file(self):
        """Test that create_fresh_database creates a database file."""
        self.assertFalse(os.path.exists(self.test_db_path))
        create_fresh_database(self.test_db_path, self.schema_path)
        self.assertTrue(os.path.exists(self.test_db_path))

    def test_create_fresh_database_creates_all_tables(self):
        """Test that create_fresh_database creates all required tables."""
        create_fresh_database(self.test_db_path, self.schema_path)
        self.assertEqual(self._table_names(), set(self.REQUIRED_TABLES))

    def test_create_fresh_database_returns_path(self):
        """Test that create_fresh_database returns the database path."""
        result = create_fresh_database(self.test_db_path, self.schema_path)
        self.assertEqual(result, self.test_db_path)

    def test_create_fresh_database_on_clean_path(self):
        """Test that create_fresh_database works correctly on a clean database path."""
        # Ensure no database exists.
        self.assertFalse(os.path.exists(self.test_db_path))
        create_fresh_database(self.test_db_path, self.schema_path)
        # Verify all tables exist.
        self.assertEqual(self._table_names(), set(self.REQUIRED_TABLES))
class TestInitializeDatabase(unittest.TestCase):
    """Test cases for initialize_database function."""

    def setUp(self):
        """Give each test a private temporary directory and database path.

        Cleanup is registered with ``addCleanup`` so the directory, and any
        extra files SQLite leaves in it, is removed even if the test fails.
        """
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.temp_dir = self._tmpdir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    def _configure(self, mock_config):
        """Point the patched module-level ``config`` at this test's database."""
        mock_config.database_path = self.test_db_path
        mock_config.database_template = 'schema.sql'
        mock_config.log_level = 'INFO'

    def _fetch_case(self, case_id):
        """Return the ``(case_id,)`` row for *case_id* from ``cases``, or None.

        The connection is closed before returning, so callers assert on the
        result without holding the database file open.
        """
        conn = sqlite3.connect(self.test_db_path)
        try:
            return conn.execute(
                "SELECT case_id FROM cases WHERE case_id = ?", (case_id,)
            ).fetchone()
        finally:
            conn.close()

    @patch('forensictrails.db.database.config')
    def test_initialize_database_creates_new_database(self, mock_config):
        """Test that initialize_database creates a new database if none exists."""
        self._configure(mock_config)
        self.assertFalse(os.path.exists(self.test_db_path))
        initialize_database(self.test_db_path)
        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

    @patch('forensictrails.db.database.config')
    def test_initialize_database_keeps_valid_database(self, mock_config):
        """Test that initialize_database keeps a valid existing database."""
        self._configure(mock_config)
        # Create a valid database and add some data.
        create_fresh_database(self.test_db_path, self.schema_path)
        conn = sqlite3.connect(self.test_db_path)
        try:
            conn.execute("""
                INSERT INTO cases (case_id, case_title, investigator)
                VALUES ('TEST-001', 'Test Case', 'Test Investigator')
            """)
            conn.commit()
        finally:
            conn.close()
        # Initialize again; existing data must survive.
        initialize_database(self.test_db_path)
        result = self._fetch_case('TEST-001')
        self.assertIsNotNone(result)
        self.assertEqual(result[0], 'TEST-001')

    @patch('forensictrails.db.database.config')
    def test_initialize_database_recreates_invalid_database(self, mock_config):
        """Test that initialize_database recreates an invalid database."""
        self._configure(mock_config)
        # Create an incomplete database that holds some data.
        conn = sqlite3.connect(self.test_db_path)
        try:
            conn.execute("CREATE TABLE cases (case_id TEXT PRIMARY KEY)")
            conn.execute("INSERT INTO cases VALUES ('TEST-001')")
            conn.commit()
        finally:
            conn.close()
        initialize_database(self.test_db_path)
        # The database is now valid and the old data is gone.
        self.assertTrue(validate_database_schema(self.test_db_path))
        self.assertIsNone(self._fetch_case('TEST-001'))
class TestShowDbSchema(unittest.TestCase):
    """Test cases for show_db_schema function."""

    # Number of tables defined by the bundled schema.sql.
    EXPECTED_TABLE_COUNT = 8

    def setUp(self):
        """Give each test a private temporary directory and database path.

        Cleanup is registered with ``addCleanup`` so the directory, and any
        extra files SQLite leaves in it, is removed even if the test fails.
        """
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.temp_dir = self._tmpdir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    @patch('forensictrails.db.database.logging')
    def test_show_db_schema_logs_tables(self, mock_logging):
        """Test that show_db_schema logs all table information."""
        create_fresh_database(self.test_db_path, self.schema_path)
        show_db_schema(self.test_db_path)
        # Expect at least one header message plus one debug call per table.
        # A count of at least 1 also proves logging.debug was called at all.
        self.assertGreaterEqual(
            mock_logging.debug.call_count, self.EXPECTED_TABLE_COUNT + 1
        )

    def test_show_db_schema_doesnt_raise_exception(self):
        """Test that show_db_schema handles execution without raising exceptions."""
        create_fresh_database(self.test_db_path, self.schema_path)
        try:
            show_db_schema(self.test_db_path)
        except Exception as e:
            self.fail(f"show_db_schema raised an exception: {e}")
class TestDatabaseIntegration(unittest.TestCase):
    """Integration tests for the database module."""

    def setUp(self):
        """Give each test a private temporary directory and database path.

        Cleanup is registered with ``addCleanup`` so the directory, and any
        extra files SQLite leaves in it, is removed even if the test fails.
        """
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.temp_dir = self._tmpdir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    @patch('forensictrails.db.database.config')
    def test_full_database_lifecycle(self, mock_config):
        """Test complete database lifecycle: create, use, invalidate, recreate."""
        mock_config.database_path = self.test_db_path
        mock_config.database_template = 'schema.sql'
        mock_config.log_level = 'INFO'

        # Step 1: Initialize new database.
        initialize_database(self.test_db_path)
        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

        # Step 2: Add some data.
        conn = get_db_connection(self.test_db_path)
        try:
            conn.execute("""
                INSERT INTO cases (case_id, case_title, investigator)
                VALUES ('CASE-001', 'Murder Investigation', 'Detective Smith')
            """)
            conn.execute("""
                INSERT INTO users (user_id, username, full_name)
                VALUES ('USER-001', 'dsmith', 'Detective Smith')
            """)
            conn.commit()
        finally:
            conn.close()

        # Step 3: Verify data exists. Close the connection before asserting,
        # so a failure cannot leave the database file open.
        conn = get_db_connection(self.test_db_path)
        try:
            result = conn.execute(
                "SELECT case_title FROM cases WHERE case_id = 'CASE-001'"
            ).fetchone()
        finally:
            conn.close()
        self.assertIsNotNone(result)
        self.assertEqual(result['case_title'], 'Murder Investigation')

        # Step 4: Corrupt database (remove a required table).
        conn = sqlite3.connect(self.test_db_path)
        try:
            conn.execute("DROP TABLE users")
            conn.commit()
        finally:
            conn.close()

        # Step 5: Verify database is now invalid.
        self.assertFalse(validate_database_schema(self.test_db_path))

        # Step 6: Re-initialize (should recreate).
        initialize_database(self.test_db_path)

        # Step 7: Verify database is valid again and old data is gone.
        self.assertTrue(validate_database_schema(self.test_db_path))
        conn = get_db_connection(self.test_db_path)
        try:
            result = conn.execute(
                "SELECT case_id FROM cases WHERE case_id = 'CASE-001'"
            ).fetchone()
        finally:
            conn.close()
        self.assertIsNone(result)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()