From 6df4c5be07cf7792e4b88276bf1ddeb6ea8544dc Mon Sep 17 00:00:00 2001 From: thomasabishop Date: Sat, 18 Oct 2025 18:47:04 +0100 Subject: [PATCH] refactor: use context manager for database connection --- src/app.py | 15 ++++++--------- ...database_service.py => database_connection.py} | 11 ++++++++++- 2 files changed, 16 insertions(+), 10 deletions(-) rename src/{services/database_service.py => database_connection.py} (76%) diff --git a/src/app.py b/src/app.py index e3f0107..7a6652e 100644 --- a/src/app.py +++ b/src/app.py @@ -1,6 +1,6 @@ import argparse -from services.database_service import DatabaseService +from database_connection import DatabaseConnection from services.parse_file_service import ParseFileService from services.table_service import TableService @@ -29,14 +29,11 @@ def main(): SOURCE_DIRECTORY = args.source TARGET_DIRECTORY = args.target - database_service = DatabaseService("eolas", TARGET_DIRECTORY) - database_connection = database_service.connect() - table_service = TableService(database_connection) - parse_file_service = ParseFileService(SOURCE_DIRECTORY) - - entries = parse_file_service.parse_source_directory() - table_service.populate_tables(entries) - database_service.disconnect() + with DatabaseConnection("eolas", TARGET_DIRECTORY) as conn: + table_service = TableService(conn) + parse_file_service = ParseFileService(SOURCE_DIRECTORY) + entries = parse_file_service.parse_source_directory() + table_service.populate_tables(entries) if __name__ == "__main__": diff --git a/src/services/database_service.py b/src/database_connection.py similarity index 76% rename from src/services/database_service.py rename to src/database_connection.py index bf1c10a..a0cb172 100644 --- a/src/services/database_service.py +++ b/src/database_connection.py @@ -3,7 +3,7 @@ import sqlite3 from typing import Optional -class DatabaseService: +class DatabaseConnection: def __init__(self, db_name, db_path): self.db_name = db_name self.db_path = db_path @@ -31,3 +31,12 @@ class DatabaseService: self.connection = None except Exception as e: raise Exception(f"ERROR Problem disconnecting from database: {e}") + + def __enter__(self) -> sqlite3.Connection: + connection = self.connect() + if connection is None: + raise RuntimeError("Failed to establish database connection") + return connection + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.disconnect()