"""Unit tests for the database module."""

import contextlib
import os
import shutil
import sqlite3
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch, MagicMock

# Add the src directory to the path so we can import forensictrails
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))

from forensictrails.db.database import (
    get_db_connection,
    validate_database_schema,
    create_fresh_database,
    initialize_database,
    show_db_schema,
)

# Schema template shipped with the package; used to build full databases.
SCHEMA_PATH = (
    Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'
)

# Tables a complete database must contain. SQLite may add internal tables
# (e.g. sqlite_sequence for AUTOINCREMENT), so tests check for a subset.
REQUIRED_TABLES = frozenset({
    'case',
    'note',
    'note_revision',
    'investigator',
    'evidence',
    'chain_of_custody',
    'question_definition',
    'note_question_tag',
    'investigator_case',
})


class _TempDatabaseTestCase(unittest.TestCase):
    """Shared fixture: each test gets its own temp dir and a ``test.db`` path.

    Contains no test methods itself, so test runners collect nothing from it.
    """

    def setUp(self):
        """Create an isolated temp directory; the database file is not created."""
        self.temp_dir = tempfile.mkdtemp()
        # rmtree (instead of os.remove + os.rmdir) also removes any SQLite
        # side files left next to the database, and addCleanup runs even if
        # a later step of setUp fails. Cleanup problems must not turn a
        # passing test into an error, hence ignore_errors.
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = SCHEMA_PATH

    def _configure(self, mock_config):
        """Point a patched ``config`` object at this test's database."""
        mock_config.database_path = self.test_db_path
        mock_config.database_template = 'schema.sql'
        mock_config.log_level = 'INFO'

    def _table_names(self):
        """Return the set of table names present in the test database."""
        with contextlib.closing(sqlite3.connect(self.test_db_path)) as conn:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()
        return {row[0] for row in rows}

    def _assert_has_required_tables(self):
        """Fail with the list of missing tables if any required one is absent."""
        missing = REQUIRED_TABLES - self._table_names()
        self.assertFalse(missing, f"missing tables: {sorted(missing)}")


class TestGetDbConnection(_TempDatabaseTestCase):
    """Test cases for get_db_connection function."""

    def test_get_db_connection_creates_connection(self):
        """Test that get_db_connection creates a valid connection."""
        # closing() guarantees the connection is released even if an
        # assertion below fails.
        with contextlib.closing(get_db_connection(self.test_db_path)) as conn:
            self.assertIsInstance(conn, sqlite3.Connection)
            self.assertEqual(conn.row_factory, sqlite3.Row)

    def test_get_db_connection_creates_file(self):
        """Test that get_db_connection creates database file if it doesn't exist."""
        self.assertFalse(os.path.exists(self.test_db_path))
        get_db_connection(self.test_db_path).close()
        self.assertTrue(os.path.exists(self.test_db_path))


class TestValidateDatabaseSchema(_TempDatabaseTestCase):
    """Test cases for validate_database_schema function."""

    def test_validate_empty_database_returns_false(self):
        """Test that an empty database is invalid."""
        sqlite3.connect(self.test_db_path).close()
        self.assertFalse(validate_database_schema(self.test_db_path))

    def test_validate_incomplete_database_returns_false(self):
        """Test that a database with missing tables is invalid."""
        with contextlib.closing(sqlite3.connect(self.test_db_path)) as conn:
            # Create only some of the required tables
            conn.execute("""
                CREATE TABLE "case" (
                    case_id INTEGER PRIMARY KEY,
                    description TEXT
                )
            """)
            conn.execute("""
                CREATE TABLE note (
                    note_id INTEGER PRIMARY KEY,
                    case_id INTEGER NOT NULL
                )
            """)
            conn.commit()

        self.assertFalse(validate_database_schema(self.test_db_path))

    def test_validate_complete_database_returns_true(self):
        """Test that a database with all required tables is valid."""
        create_fresh_database(self.test_db_path, self.schema_path)
        self.assertTrue(validate_database_schema(self.test_db_path))

    def test_validate_nonexistent_database_returns_false(self):
        """Test that validation of non-existent database returns False."""
        # A missing sub-directory of our private temp dir is guaranteed not
        # to exist, unlike a hard-coded absolute path on the host machine.
        missing_path = os.path.join(self.temp_dir, 'nonexistent', 'test.db')
        self.assertFalse(validate_database_schema(missing_path))


class TestCreateFreshDatabase(_TempDatabaseTestCase):
    """Test cases for create_fresh_database function."""

    def test_create_fresh_database_creates_file(self):
        """Test that create_fresh_database creates a database file."""
        self.assertFalse(os.path.exists(self.test_db_path))
        create_fresh_database(self.test_db_path, self.schema_path)
        self.assertTrue(os.path.exists(self.test_db_path))

    def test_create_fresh_database_creates_all_tables(self):
        """Test that create_fresh_database creates all required tables."""
        create_fresh_database(self.test_db_path, self.schema_path)
        self._assert_has_required_tables()

    def test_create_fresh_database_returns_path(self):
        """Test that create_fresh_database returns the database path."""
        result = create_fresh_database(self.test_db_path, self.schema_path)
        self.assertEqual(result, self.test_db_path)

    def test_create_fresh_database_on_clean_path(self):
        """Test that create_fresh_database works correctly on a clean database path."""
        # Ensure no database exists
        self.assertFalse(os.path.exists(self.test_db_path))

        create_fresh_database(self.test_db_path, self.schema_path)

        self._assert_has_required_tables()


class TestInitializeDatabase(_TempDatabaseTestCase):
    """Test cases for initialize_database function."""

    @patch('forensictrails.db.database.config')
    def test_initialize_database_creates_new_database(self, mock_config):
        """Test that initialize_database creates a new database if none exists."""
        self._configure(mock_config)
        self.assertFalse(os.path.exists(self.test_db_path))

        initialize_database(self.test_db_path)

        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

    @patch('forensictrails.db.database.config')
    def test_initialize_database_keeps_valid_database(self, mock_config):
        """Test that initialize_database keeps a valid existing database."""
        self._configure(mock_config)

        # Create a valid database and add a row to it
        create_fresh_database(self.test_db_path, self.schema_path)
        with contextlib.closing(sqlite3.connect(self.test_db_path)) as conn:
            cursor = conn.execute("""
                INSERT INTO "case" (description, status)
                VALUES ('Test Case', 'Open')
            """)
            conn.commit()
            case_id = cursor.lastrowid

        # Initialize again
        initialize_database(self.test_db_path)

        # Verify data still exists
        with contextlib.closing(sqlite3.connect(self.test_db_path)) as conn:
            result = conn.execute(
                'SELECT case_id FROM "case" WHERE case_id = ?', (case_id,)
            ).fetchone()

        self.assertIsNotNone(result)
        self.assertEqual(result[0], case_id)

    @patch('forensictrails.db.database.config')
    def test_initialize_database_recreates_invalid_database(self, mock_config):
        """Test that initialize_database recreates an invalid database."""
        self._configure(mock_config)

        # Create an incomplete database holding a recognizable row
        with contextlib.closing(sqlite3.connect(self.test_db_path)) as conn:
            conn.execute('CREATE TABLE "case" (case_id INTEGER PRIMARY KEY)')
            conn.execute('INSERT INTO "case" (case_id) VALUES (999)')
            conn.commit()

        initialize_database(self.test_db_path)

        # Verify database is now valid and old data is gone
        self.assertTrue(validate_database_schema(self.test_db_path))
        with contextlib.closing(sqlite3.connect(self.test_db_path)) as conn:
            result = conn.execute(
                'SELECT case_id FROM "case" WHERE case_id = 999'
            ).fetchone()
        self.assertIsNone(result)


class TestShowDbSchema(_TempDatabaseTestCase):
    """Test cases for show_db_schema function."""

    @patch('forensictrails.db.database.logging')
    def test_show_db_schema_logs_tables(self, mock_logging):
        """Test that show_db_schema logs all table information."""
        create_fresh_database(self.test_db_path, self.schema_path)

        show_db_schema(self.test_db_path)

        # Verify that logging.debug was called
        self.assertTrue(mock_logging.debug.called)
        # Expect at least one header message plus one call per required table
        self.assertGreaterEqual(
            mock_logging.debug.call_count, len(REQUIRED_TABLES) + 1
        )

    def test_show_db_schema_doesnt_raise_exception(self):
        """Test that show_db_schema handles execution without raising exceptions."""
        create_fresh_database(self.test_db_path, self.schema_path)

        try:
            show_db_schema(self.test_db_path)
        except Exception as e:
            self.fail(f"show_db_schema raised an exception: {e}")


class TestDatabaseIntegration(_TempDatabaseTestCase):
    """Integration tests for the database module."""

    @patch('forensictrails.db.database.config')
    def test_full_database_lifecycle(self, mock_config):
        """Test complete database lifecycle: create, use, invalidate, recreate."""
        self._configure(mock_config)

        # Step 1: Initialize new database
        initialize_database(self.test_db_path)
        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

        # Step 2: Add some data
        with contextlib.closing(get_db_connection(self.test_db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO "case" (description, status)
                VALUES ('Murder Investigation', 'Open')
            """)
            case_id = cursor.lastrowid
            cursor.execute("""
                INSERT INTO investigator (name, role)
                VALUES ('Detective Smith', 'Lead Investigator')
            """)
            conn.commit()

        # Step 3: Verify data exists (closing() releases the connection even
        # if the assertion fails)
        with contextlib.closing(get_db_connection(self.test_db_path)) as conn:
            result = conn.execute(
                'SELECT description FROM "case" WHERE case_id = ?', (case_id,)
            ).fetchone()
            self.assertEqual(result['description'], 'Murder Investigation')

        # Step 4: Corrupt database (remove a required table)
        with contextlib.closing(sqlite3.connect(self.test_db_path)) as conn:
            conn.execute("DROP TABLE investigator")
            conn.commit()

        # Step 5: Verify database is now invalid
        self.assertFalse(validate_database_schema(self.test_db_path))

        # Step 6: Re-initialize (should recreate)
        initialize_database(self.test_db_path)

        # Step 7: Verify database is valid again and old data is gone
        self.assertTrue(validate_database_schema(self.test_db_path))
        with contextlib.closing(get_db_connection(self.test_db_path)) as conn:
            result = conn.execute(
                'SELECT case_id FROM "case" WHERE case_id = ?', (case_id,)
            ).fetchone()
            self.assertIsNone(result)


if __name__ == '__main__':
    unittest.main()