bugfix db

This commit is contained in:
overcuriousity 2025-10-08 21:23:02 +02:00
parent c6d0192fa8
commit aa61bfabc1
2 changed files with 21 additions and 34 deletions

View File

@ -1,10 +1,10 @@
import config import utils.config as config
from db.database import get_db_connection from db.database import get_db_connection
import logging import logging
import datetime from datetime import datetime
class CaseManager(case_id=None, case_title=None, investigator=None, classification=None, summary=None): class CaseManager():
def __init__(self, db_path=None, case_id=None, case_title=None, investigator=None, classification=None, summary=None): def __init__(self, db_path=None):
if db_path is None: if db_path is None:
db_path = config.database_path db_path = config.database_path
self.db_path = db_path self.db_path = db_path
@ -12,13 +12,6 @@ class CaseManager(case_id=None, case_title=None, investigator=None, classificati
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
logging.debug(f"Connected to database at {self.db_path}") logging.debug(f"Connected to database at {self.db_path}")
self.case_id = case_id
self.case_title = case_title
self.investigator = investigator
self.classification = classification
self.summary = summary
def create_case(self, case_id, case_title, investigator, classification=None, summary=None): def create_case(self, case_id, case_title, investigator, classification=None, summary=None):
with self.conn: with self.conn:
self.cursor.execute(""" self.cursor.execute("""
@ -27,6 +20,16 @@ class CaseManager(case_id=None, case_title=None, investigator=None, classificati
""", (case_id, case_title, investigator, classification, summary)) """, (case_id, case_title, investigator, classification, summary))
logging.info(f"Created new case with ID: {case_id}") logging.info(f"Created new case with ID: {case_id}")
def get_case(self, case_id):
with self.conn:
self.cursor.execute("SELECT * FROM cases WHERE case_id = ?", (case_id,))
case = self.cursor.fetchone()
if case:
return dict(case)
else:
logging.warning(f"No case found with ID: {case_id}")
return None
def list_cases(self, status=None, search_term=None): def list_cases(self, status=None, search_term=None):
with self.conn: with self.conn:
query = "SELECT * FROM cases WHERE 1=1" query = "SELECT * FROM cases WHERE 1=1"
@ -80,28 +83,12 @@ class CaseManager(case_id=None, case_title=None, investigator=None, classificati
return self.update_case(case_id, status='archived') return self.update_case(case_id, status='archived')
def export_case_db(self, case_id, export_path): def export_case_db(self, case_id, export_path):
with self.conn: # TODO: Implement export functionality
self.cursor.execute("SELECT * FROM cases WHERE case_id = ?", (case_id,)) # should export a .sqlite file with only the data related to the specified case_id
case = self.cursor.fetchone() pass
if not case:
logging.error(f"No case found with ID: {case_id}")
return False
with open(export_path, 'w') as f:
for key in case.keys():
f.write(f"{key}: {case[key]}\n")
logging.info(f"Exported case {case_id} to {export_path}")
return True
def import_case_db(self, import_path): def import_case_db(self, import_path):
with open(import_path, 'r') as f: # TODO: Implement import functionality
try: # should import a .sqlite file and merge its data into the main database
with open(config.database_path, 'w') as db_file: pass
db_file.write(f.read())
get_db_connection(config.database_path)
logging.info(f"Imported case database from {import_path}")
return True
except Exception as e:
logging.error(f"Failed to import case database: {e}")
return False

View File

@ -31,7 +31,7 @@ def create_db_if_not_exists(db_path=None, schema_path=None):
def initialize_database(db_path=None): def initialize_database(db_path=None):
if db_path is None: if db_path is None:
db_path = config.database_path db_path = config.database_path
get_db_connection(db_path) create_db_if_not_exists(db_path)
if config.log_level == 'DEBUG': if config.log_level == 'DEBUG':
show_db_schema(db_path) show_db_schema(db_path)
logging.info(f"Database initialized at {db_path}") logging.info(f"Database initialized at {db_path}")