diff --git a/pilot/database/database.py b/pilot/database/database.py index 9ec15a7..eb5020c 100644 --- a/pilot/database/database.py +++ b/pilot/database/database.py @@ -4,8 +4,10 @@ from termcolor import colored from functools import reduce import operator import psycopg2 +import os from const.common import PROMPT_DATA_TO_IGNORE from logger.logger import logger +from psycopg2.extensions import quote_ident from utils.utils import hash_data from database.models.components.base_models import database @@ -23,7 +25,12 @@ from database.models.file_snapshot import FileSnapshot from database.models.command_runs import CommandRuns from database.models.user_inputs import UserInputs from database.models.files import File -from const import db + +DB_NAME = os.getenv("DB_NAME") +DB_HOST = os.getenv("DB_HOST") +DB_PORT = os.getenv("DB_PORT") +DB_USER = os.getenv("DB_USER") +DB_PASSWORD = os.getenv("DB_PASSWORD") def save_user(user_id, email, password): @@ -385,16 +392,23 @@ def database_exists(): def create_database(): + # Connect to the default 'postgres' database to create a new database conn = psycopg2.connect( - dbname="postgres", # default database where we can execute the create database command - user=db.DB_USER, - password=db.DB_PASSWORD, - host=db.DB_HOST, - port=db.DB_PORT + dbname='postgres', + user=DB_USER, + password=DB_PASSWORD, + host=DB_HOST, + port=DB_PORT ) conn.autocommit = True cursor = conn.cursor() - cursor.execute(f"CREATE DATABASE {db.DB_NAME}") + + # Safely quote the database name + safe_db_name = quote_ident(DB_NAME, conn) + + # Use the safely quoted database name in the SQL query + cursor.execute(f"CREATE DATABASE {safe_db_name}") + cursor.close() conn.close()