384 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			384 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Unit tests for the database module."""

import os
import sqlite3
import sys
import tempfile
import unittest
from contextlib import closing
from pathlib import Path
from unittest.mock import patch, MagicMock

# Add the src directory to the path so we can import forensictrails
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))

from forensictrails.db.database import (
    get_db_connection,
    validate_database_schema,
    create_fresh_database,
    initialize_database,
    show_db_schema
)
class TestGetDbConnection(unittest.TestCase):
    """Test cases for get_db_connection function."""

    def setUp(self):
        """Create a private temporary directory holding the test database path.

        The directory is registered with ``addCleanup`` so it is removed
        together with everything inside it (including any SQLite side files
        such as ``-journal``/``-wal``), even if the test fails part-way.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')

    def test_get_db_connection_creates_connection(self):
        """Test that get_db_connection creates a valid connection."""
        # closing() guarantees the handle is released even if an assertion fails.
        with closing(get_db_connection(self.test_db_path)) as conn:
            self.assertIsInstance(conn, sqlite3.Connection)
            # sqlite3.Row makes result columns addressable by name.
            self.assertEqual(conn.row_factory, sqlite3.Row)

    def test_get_db_connection_creates_file(self):
        """Test that get_db_connection creates database file if it doesn't exist."""
        self.assertFalse(os.path.exists(self.test_db_path))
        with closing(get_db_connection(self.test_db_path)):
            pass
        self.assertTrue(os.path.exists(self.test_db_path))


class TestValidateDatabaseSchema(unittest.TestCase):
    """Test cases for validate_database_schema function."""

    def setUp(self):
        """Create a private temporary directory and locate the schema file.

        ``addCleanup`` removes the directory and all of its contents (including
        any SQLite side files) even if the test fails part-way.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    def test_validate_empty_database_returns_false(self):
        """Test that an empty database is invalid."""
        # Open and close a connection without creating any tables.
        sqlite3.connect(self.test_db_path).close()

        self.assertFalse(validate_database_schema(self.test_db_path))

    def test_validate_incomplete_database_returns_false(self):
        """Test that a database with missing tables is invalid."""
        with closing(sqlite3.connect(self.test_db_path)) as conn:
            # Create only some of the required tables
            conn.execute("""
                CREATE TABLE cases (
                    case_id TEXT PRIMARY KEY,
                    case_title TEXT NOT NULL
                )
            """)
            conn.execute("""
                CREATE TABLE notes (
                    note_id TEXT PRIMARY KEY,
                    case_id TEXT NOT NULL
                )
            """)
            conn.commit()

        self.assertFalse(validate_database_schema(self.test_db_path))

    def test_validate_complete_database_returns_true(self):
        """Test that a database with all required tables is valid."""
        # Create database with full schema
        create_fresh_database(self.test_db_path, self.schema_path)

        self.assertTrue(validate_database_schema(self.test_db_path))

    def test_validate_nonexistent_database_returns_false(self):
        """Test that validation of non-existent database returns False."""
        # A path under a directory that does not exist inside our private
        # sandbox is guaranteed missing on every platform, and the test never
        # touches anything outside its own temp dir.
        missing_path = os.path.join(self.temp_dir, 'nonexistent', 'test.db')
        self.assertFalse(validate_database_schema(missing_path))


class TestCreateFreshDatabase(unittest.TestCase):
    """Test cases for create_fresh_database function."""

    # Application tables the schema is expected to define (and nothing else).
    REQUIRED_TABLES = frozenset({
        'cases', 'notes', 'evidence', 'chain_of_custody',
        'attachments', 'question_entries', 'users', 'tasks',
    })

    def setUp(self):
        """Create a private temporary directory and locate the schema file.

        ``addCleanup`` removes the directory and all of its contents (including
        any SQLite side files) even if the test fails part-way.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    def _table_names(self):
        """Return the set of user-defined table names in the test database.

        SQLite-internal tables (``sqlite_sequence``, created for AUTOINCREMENT
        columns, and ``sqlite_stat*``, created by ANALYZE) are excluded so the
        comparison only covers tables the schema itself declares.
        """
        with closing(sqlite3.connect(self.test_db_path)) as conn:
            rows = conn.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
        return {name for (name,) in rows}

    def test_create_fresh_database_creates_file(self):
        """Test that create_fresh_database creates a database file."""
        self.assertFalse(os.path.exists(self.test_db_path))

        create_fresh_database(self.test_db_path, self.schema_path)

        self.assertTrue(os.path.exists(self.test_db_path))

    def test_create_fresh_database_creates_all_tables(self):
        """Test that create_fresh_database creates all required tables."""
        create_fresh_database(self.test_db_path, self.schema_path)

        self.assertEqual(self._table_names(), set(self.REQUIRED_TABLES))

    def test_create_fresh_database_returns_path(self):
        """Test that create_fresh_database returns the database path."""
        result = create_fresh_database(self.test_db_path, self.schema_path)
        self.assertEqual(result, self.test_db_path)

    def test_create_fresh_database_on_clean_path(self):
        """Test that create_fresh_database works correctly on a clean database path."""
        # Ensure no database exists
        self.assertFalse(os.path.exists(self.test_db_path))

        # Create fresh database
        create_fresh_database(self.test_db_path, self.schema_path)

        # Verify all tables exist
        self.assertEqual(self._table_names(), set(self.REQUIRED_TABLES))


class TestInitializeDatabase(unittest.TestCase):
    """Test cases for initialize_database function."""

    def setUp(self):
        """Create a private temporary directory and locate the schema file.

        ``addCleanup`` removes the directory and all of its contents (including
        any SQLite side files) even if the test fails part-way.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    def _configure(self, mock_config):
        """Populate the patched module-level ``config`` with the values these tests use."""
        mock_config.database_path = self.test_db_path
        mock_config.database_template = 'schema.sql'
        mock_config.log_level = 'INFO'

    @patch('forensictrails.db.database.config')
    def test_initialize_database_creates_new_database(self, mock_config):
        """Test that initialize_database creates a new database if none exists."""
        self._configure(mock_config)
        self.assertFalse(os.path.exists(self.test_db_path))

        initialize_database(self.test_db_path)

        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

    @patch('forensictrails.db.database.config')
    def test_initialize_database_keeps_valid_database(self, mock_config):
        """Test that initialize_database keeps a valid existing database."""
        self._configure(mock_config)

        # Create a valid database and add a row that must survive re-initialization.
        create_fresh_database(self.test_db_path, self.schema_path)
        with closing(sqlite3.connect(self.test_db_path)) as conn:
            conn.execute("""
                INSERT INTO cases (case_id, case_title, investigator)
                VALUES ('TEST-001', 'Test Case', 'Test Investigator')
            """)
            conn.commit()

        # Initialize again
        initialize_database(self.test_db_path)

        # Verify data still exists
        with closing(sqlite3.connect(self.test_db_path)) as conn:
            result = conn.execute(
                "SELECT case_id FROM cases WHERE case_id = 'TEST-001'"
            ).fetchone()

        self.assertIsNotNone(result)
        self.assertEqual(result[0], 'TEST-001')

    @patch('forensictrails.db.database.config')
    def test_initialize_database_recreates_invalid_database(self, mock_config):
        """Test that initialize_database recreates an invalid database."""
        self._configure(mock_config)

        # Create an incomplete database containing one row.
        with closing(sqlite3.connect(self.test_db_path)) as conn:
            conn.execute("CREATE TABLE cases (case_id TEXT PRIMARY KEY)")
            conn.execute("INSERT INTO cases VALUES ('TEST-001')")
            conn.commit()

        # Initialize
        initialize_database(self.test_db_path)

        # Verify database is now valid and old data is gone
        self.assertTrue(validate_database_schema(self.test_db_path))
        with closing(sqlite3.connect(self.test_db_path)) as conn:
            result = conn.execute(
                "SELECT case_id FROM cases WHERE case_id = 'TEST-001'"
            ).fetchone()

        self.assertIsNone(result)


class TestShowDbSchema(unittest.TestCase):
    """Test cases for show_db_schema function."""

    def setUp(self):
        """Create a private temporary directory and locate the schema file.

        ``addCleanup`` removes the directory and all of its contents (including
        any SQLite side files) even if the test fails part-way.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    @patch('forensictrails.db.database.logging')
    def test_show_db_schema_logs_tables(self, mock_logging):
        """Test that show_db_schema logs all table information."""
        create_fresh_database(self.test_db_path, self.schema_path)

        show_db_schema(self.test_db_path)

        # Verify that logging.debug was called
        self.assertTrue(mock_logging.debug.called)
        # Expect at least one call per table (8 tables) plus one header message.
        self.assertGreaterEqual(mock_logging.debug.call_count, 9)

    def test_show_db_schema_doesnt_raise_exception(self):
        """Test that show_db_schema handles execution without raising exceptions."""
        create_fresh_database(self.test_db_path, self.schema_path)
        # Called directly on purpose: any exception is reported by unittest as a
        # test error with its full traceback, whereas wrapping it in
        # try/except + self.fail() would keep only the message.
        show_db_schema(self.test_db_path)


class TestDatabaseIntegration(unittest.TestCase):
    """Integration tests for the database module."""

    def setUp(self):
        """Create a private temporary directory and locate the schema file.

        ``addCleanup`` removes the directory and all of its contents (including
        any SQLite side files) even if the test fails part-way.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'

    @patch('forensictrails.db.database.config')
    def test_full_database_lifecycle(self, mock_config):
        """Test complete database lifecycle: create, use, invalidate, recreate.

        Every connection is closed via ``contextlib.closing`` before the next
        step runs, even when an assertion inside it fails, so no open handle
        remains on the file when ``initialize_database`` is called again
        (an open handle can block deleting the file on some platforms).
        """
        mock_config.database_path = self.test_db_path
        mock_config.database_template = 'schema.sql'
        mock_config.log_level = 'INFO'

        # Step 1: Initialize new database
        initialize_database(self.test_db_path)
        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

        # Step 2: Add some data
        with closing(get_db_connection(self.test_db_path)) as conn:
            conn.execute("""
                INSERT INTO cases (case_id, case_title, investigator)
                VALUES ('CASE-001', 'Murder Investigation', 'Detective Smith')
            """)
            conn.execute("""
                INSERT INTO users (user_id, username, full_name)
                VALUES ('USER-001', 'dsmith', 'Detective Smith')
            """)
            conn.commit()

        # Step 3: Verify data exists (column access by name relies on the
        # connection's row factory).
        with closing(get_db_connection(self.test_db_path)) as conn:
            result = conn.execute(
                "SELECT case_title FROM cases WHERE case_id = 'CASE-001'"
            ).fetchone()
        # Fail clearly instead of raising TypeError on None['case_title'].
        self.assertIsNotNone(result)
        self.assertEqual(result['case_title'], 'Murder Investigation')

        # Step 4: Corrupt database (remove a required table)
        with closing(sqlite3.connect(self.test_db_path)) as conn:
            conn.execute("DROP TABLE users")
            conn.commit()

        # Step 5: Verify database is now invalid
        self.assertFalse(validate_database_schema(self.test_db_path))

        # Step 6: Re-initialize (should recreate)
        initialize_database(self.test_db_path)

        # Step 7: Verify database is valid again and old data is gone
        self.assertTrue(validate_database_schema(self.test_db_path))
        with closing(get_db_connection(self.test_db_path)) as conn:
            result = conn.execute(
                "SELECT case_id FROM cases WHERE case_id = 'CASE-001'"
            ).fetchone()
        self.assertIsNone(result)


if __name__ == '__main__':
    # Entry point when this file is executed directly as a script; test
    # runners that import the module (pytest, `python -m unittest`) skip it.
    unittest.main()