384 lines
14 KiB
Python
384 lines
14 KiB
Python
"""Unit tests for the database module."""
|
|
import unittest
|
|
import sqlite3
|
|
import tempfile
|
|
import os
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
# Make the forensictrails package importable without installing it: the
# ``src`` directory one level above this file's directory is put at the very
# front of the module search path.
import sys

_SRC_DIR = Path(__file__).parent.parent.joinpath('src')
sys.path.insert(0, str(_SRC_DIR))
|
|
|
|
from forensictrails.db.database import (
|
|
get_db_connection,
|
|
validate_database_schema,
|
|
create_fresh_database,
|
|
initialize_database,
|
|
show_db_schema
|
|
)
|
|
|
|
|
|
class TestGetDbConnection(unittest.TestCase):
    """Test cases for get_db_connection function."""

    def setUp(self):
        """Create a private temporary directory and a database path inside it."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Registered as a cleanup instead of an os.rmdir in tearDown: the whole
        # directory tree is removed even when a test fails midway or leaves
        # extra files next to the database (os.rmdir raises on a non-empty dir).
        self.addCleanup(tmp_dir.cleanup)
        self.temp_dir = tmp_dir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')

    def test_get_db_connection_creates_connection(self):
        """Test that get_db_connection creates a valid connection."""
        conn = get_db_connection(self.test_db_path)
        # Cleanups run LIFO, so the connection is closed before the temporary
        # directory is removed, even if one of the assertions below fails.
        self.addCleanup(conn.close)

        self.assertIsInstance(conn, sqlite3.Connection)
        self.assertIs(conn.row_factory, sqlite3.Row)

    def test_get_db_connection_creates_file(self):
        """Test that get_db_connection creates database file if it doesn't exist."""
        self.assertFalse(os.path.exists(self.test_db_path))

        conn = get_db_connection(self.test_db_path)
        conn.close()

        self.assertTrue(os.path.exists(self.test_db_path))
|
|
|
|
|
|
class TestValidateDatabaseSchema(unittest.TestCase):
    """Test cases for validate_database_schema function."""

    def setUp(self):
        """Create a temporary directory, a database path and the schema path."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Removes the whole directory even if a test fails or leaves extra
        # files behind; a plain os.rmdir would raise on a non-empty directory.
        self.addCleanup(tmp_dir.cleanup)
        self.temp_dir = tmp_dir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = (
            Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'
        )

    def _run_sql(self, *statements):
        """Execute *statements* on the test database, commit, and always close."""
        conn = sqlite3.connect(self.test_db_path)
        try:
            for statement in statements:
                conn.execute(statement)
            conn.commit()
        finally:
            conn.close()

    def test_validate_empty_database_returns_false(self):
        """Test that an empty database is invalid."""
        # Open and close the database without creating any tables.
        self._run_sql()

        self.assertFalse(validate_database_schema(self.test_db_path))

    def test_validate_incomplete_database_returns_false(self):
        """Test that a database with missing tables is invalid."""
        # Create only some of the required tables
        self._run_sql(
            """
            CREATE TABLE cases (
                case_id TEXT PRIMARY KEY,
                case_title TEXT NOT NULL
            )
            """,
            """
            CREATE TABLE notes (
                note_id TEXT PRIMARY KEY,
                case_id TEXT NOT NULL
            )
            """,
        )

        self.assertFalse(validate_database_schema(self.test_db_path))

    def test_validate_complete_database_returns_true(self):
        """Test that a database with all required tables is valid."""
        # Create database with full schema
        create_fresh_database(self.test_db_path, self.schema_path)

        self.assertTrue(validate_database_schema(self.test_db_path))

    def test_validate_nonexistent_database_returns_false(self):
        """Test that validation of non-existent database returns False."""
        # A path under our private temp dir whose parent directory does not
        # exist. Unlike a hard-coded absolute path, it is guaranteed to be
        # absent on every platform and machine.
        missing_path = os.path.join(self.temp_dir, 'nonexistent', 'test.db')

        self.assertFalse(validate_database_schema(missing_path))
|
|
|
|
|
|
class TestCreateFreshDatabase(unittest.TestCase):
    """Test cases for create_fresh_database function."""

    # Application tables the schema file is expected to define.
    REQUIRED_TABLES = frozenset({
        'cases', 'notes', 'evidence', 'chain_of_custody',
        'attachments', 'question_entries', 'users', 'tasks',
    })

    def setUp(self):
        """Create a temporary directory, a database path and the schema path."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Removes the whole directory even if a test fails or leaves extra
        # files behind; a plain os.rmdir would raise on a non-empty directory.
        self.addCleanup(tmp_dir.cleanup)
        self.temp_dir = tmp_dir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = (
            Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'
        )

    def _table_names(self):
        """Return the set of application table names in the test database.

        Names starting with ``sqlite_`` are reserved by SQLite for its own
        bookkeeping tables (e.g. ``sqlite_sequence``, created for AUTOINCREMENT
        columns). They are not part of the application schema, so they are
        excluded to keep the exact comparison meaningful. The connection is
        always closed, even if the query fails.
        """
        conn = sqlite3.connect(self.test_db_path)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()
        finally:
            conn.close()
        return {name for (name,) in rows if not name.startswith('sqlite_')}

    def test_create_fresh_database_creates_file(self):
        """Test that create_fresh_database creates a database file."""
        self.assertFalse(os.path.exists(self.test_db_path))

        create_fresh_database(self.test_db_path, self.schema_path)

        self.assertTrue(os.path.exists(self.test_db_path))

    def test_create_fresh_database_creates_all_tables(self):
        """Test that create_fresh_database creates all required tables."""
        create_fresh_database(self.test_db_path, self.schema_path)

        self.assertSetEqual(self._table_names(), self.REQUIRED_TABLES)

    def test_create_fresh_database_returns_path(self):
        """Test that create_fresh_database returns the database path."""
        result = create_fresh_database(self.test_db_path, self.schema_path)
        self.assertEqual(result, self.test_db_path)

    def test_create_fresh_database_on_clean_path(self):
        """Test that create_fresh_database works correctly on a clean database path."""
        # Ensure no database exists
        self.assertFalse(os.path.exists(self.test_db_path))

        # Create fresh database
        create_fresh_database(self.test_db_path, self.schema_path)

        # Verify all tables exist
        self.assertSetEqual(self._table_names(), self.REQUIRED_TABLES)
|
|
|
|
|
|
class TestInitializeDatabase(unittest.TestCase):
    """Test cases for initialize_database function."""

    def setUp(self):
        """Create a temporary directory, a database path and the schema path."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Removes the whole directory even if a test fails or leaves extra
        # files behind; a plain os.rmdir would raise on a non-empty directory.
        self.addCleanup(tmp_dir.cleanup)
        self.temp_dir = tmp_dir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = (
            Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'
        )

    def _configure(self, mock_config):
        """Fill the patched module-level ``config`` with test settings."""
        mock_config.database_path = self.test_db_path
        mock_config.database_template = 'schema.sql'
        mock_config.log_level = 'INFO'

    def _run_sql(self, *statements):
        """Execute *statements* on the test database, commit, and always close."""
        conn = sqlite3.connect(self.test_db_path)
        try:
            for statement in statements:
                conn.execute(statement)
            conn.commit()
        finally:
            conn.close()

    def _fetch_case_id(self, case_id):
        """Return the ``(case_id,)`` row for *case_id* in ``cases``, or None."""
        conn = sqlite3.connect(self.test_db_path)
        try:
            return conn.execute(
                "SELECT case_id FROM cases WHERE case_id = ?", (case_id,)
            ).fetchone()
        finally:
            conn.close()

    @patch('forensictrails.db.database.config')
    def test_initialize_database_creates_new_database(self, mock_config):
        """Test that initialize_database creates a new database if none exists."""
        self._configure(mock_config)

        self.assertFalse(os.path.exists(self.test_db_path))

        initialize_database(self.test_db_path)

        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

    @patch('forensictrails.db.database.config')
    def test_initialize_database_keeps_valid_database(self, mock_config):
        """Test that initialize_database keeps a valid existing database."""
        self._configure(mock_config)

        # Create a valid database
        create_fresh_database(self.test_db_path, self.schema_path)

        # Add some data
        self._run_sql(
            """
            INSERT INTO cases (case_id, case_title, investigator)
            VALUES ('TEST-001', 'Test Case', 'Test Investigator')
            """
        )

        # Initialize again
        initialize_database(self.test_db_path)

        # Verify data still exists
        result = self._fetch_case_id('TEST-001')
        self.assertIsNotNone(result)
        self.assertEqual(result[0], 'TEST-001')

    @patch('forensictrails.db.database.config')
    def test_initialize_database_recreates_invalid_database(self, mock_config):
        """Test that initialize_database recreates an invalid database."""
        self._configure(mock_config)

        # Create an incomplete database that holds one row
        self._run_sql(
            "CREATE TABLE cases (case_id TEXT PRIMARY KEY)",
            "INSERT INTO cases VALUES ('TEST-001')",
        )

        # Initialize
        initialize_database(self.test_db_path)

        # Verify database is now valid and old data is gone
        self.assertTrue(validate_database_schema(self.test_db_path))
        self.assertIsNone(self._fetch_case_id('TEST-001'))
|
|
|
|
|
|
class TestShowDbSchema(unittest.TestCase):
    """Test cases for show_db_schema function."""

    # Number of application tables created from schema.sql.
    EXPECTED_TABLE_COUNT = 8

    def setUp(self):
        """Create a temporary directory, a database path and the schema path."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Removes the whole directory even if a test fails or leaves extra
        # files behind; a plain os.rmdir would raise on a non-empty directory.
        self.addCleanup(tmp_dir.cleanup)
        self.temp_dir = tmp_dir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = (
            Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'
        )

    @patch('forensictrails.db.database.logging')
    def test_show_db_schema_logs_tables(self, mock_logging):
        """Test that show_db_schema logs all table information."""
        create_fresh_database(self.test_db_path, self.schema_path)

        show_db_schema(self.test_db_path)

        # Verify that logging.debug was called
        self.assertTrue(mock_logging.debug.called)

        # Expect at least one header message plus one message per table.
        self.assertGreaterEqual(
            mock_logging.debug.call_count, self.EXPECTED_TABLE_COUNT + 1
        )

    def test_show_db_schema_doesnt_raise_exception(self):
        """Test that show_db_schema handles execution without raising exceptions."""
        create_fresh_database(self.test_db_path, self.schema_path)

        try:
            show_db_schema(self.test_db_path)
        except Exception as e:
            self.fail(f"show_db_schema raised an exception: {e}")
|
|
|
|
|
|
class TestDatabaseIntegration(unittest.TestCase):
    """Integration tests for the database module."""

    def setUp(self):
        """Create a temporary directory, a database path and the schema path."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Removes the whole directory even if a test fails or leaves extra
        # files behind; a plain os.rmdir would raise on a non-empty directory.
        self.addCleanup(tmp_dir.cleanup)
        self.temp_dir = tmp_dir.name
        self.test_db_path = os.path.join(self.temp_dir, 'test.db')
        self.schema_path = (
            Path(__file__).parent.parent / 'src' / 'forensictrails' / 'db' / 'schema.sql'
        )

    @patch('forensictrails.db.database.config')
    def test_full_database_lifecycle(self, mock_config):
        """Test complete database lifecycle: create, use, invalidate, recreate."""
        mock_config.database_path = self.test_db_path
        mock_config.database_template = 'schema.sql'
        mock_config.log_level = 'INFO'

        # Step 1: Initialize new database
        initialize_database(self.test_db_path)
        self.assertTrue(os.path.exists(self.test_db_path))
        self.assertTrue(validate_database_schema(self.test_db_path))

        # Step 2: Add some data
        conn = get_db_connection(self.test_db_path)
        try:
            conn.execute("""
                INSERT INTO cases (case_id, case_title, investigator)
                VALUES ('CASE-001', 'Murder Investigation', 'Detective Smith')
            """)
            conn.execute("""
                INSERT INTO users (user_id, username, full_name)
                VALUES ('USER-001', 'dsmith', 'Detective Smith')
            """)
            conn.commit()
        finally:
            conn.close()

        # Step 3: Verify data exists. Fetch, close, then assert, so a failing
        # assertion cannot leave the connection open.
        conn = get_db_connection(self.test_db_path)
        try:
            row = conn.execute(
                "SELECT case_title FROM cases WHERE case_id = 'CASE-001'"
            ).fetchone()
        finally:
            conn.close()
        self.assertIsNotNone(row)
        self.assertEqual(row['case_title'], 'Murder Investigation')

        # Step 4: Corrupt database (remove a required table)
        conn = sqlite3.connect(self.test_db_path)
        try:
            conn.execute("DROP TABLE users")
            conn.commit()
        finally:
            conn.close()

        # Step 5: Verify database is now invalid
        self.assertFalse(validate_database_schema(self.test_db_path))

        # Step 6: Re-initialize (should recreate)
        initialize_database(self.test_db_path)

        # Step 7: Verify database is valid again and old data is gone
        self.assertTrue(validate_database_schema(self.test_db_path))
        conn = get_db_connection(self.test_db_path)
        try:
            row = conn.execute(
                "SELECT case_id FROM cases WHERE case_id = 'CASE-001'"
            ).fetchone()
        finally:
            conn.close()
        self.assertIsNone(row)
|
|
|
|
|
|
if __name__ == '__main__':
    # Executing this file directly runs every TestCase defined in it.
    unittest.main(module=__name__)
|