check if database exists if not create it, check if tables exist if not create them

This commit is contained in:
LeonOstrez
2023-08-16 19:47:45 +02:00
parent 719d55deef
commit 3c5f0cde48
2 changed files with 79 additions and 18 deletions

View File

@@ -3,6 +3,7 @@ from peewee import *
from termcolor import colored
from functools import reduce
import operator
import psycopg2
from const.common import PROMPT_DATA_TO_IGNORE
from logger.logger import logger
@@ -22,6 +23,7 @@ from database.models.file_snapshot import FileSnapshot
from database.models.command_runs import CommandRuns
from database.models.user_inputs import UserInputs
from database.models.files import File
from const import db
def save_user(user_id, email, password):
@@ -36,7 +38,6 @@ def save_user(user_id, email, password):
return User.create(id=user_id, email=email, password=password)
def get_user(user_id=None, email=None):
if not user_id and not email:
raise ValueError("Either user_id or email must be provided")
@@ -179,11 +180,11 @@ def hash_and_save_step(Model, app_id, hash_data_args, data_fields, message):
try:
inserted_id = (Model
.insert(**data_to_insert)
.on_conflict(conflict_target=[Model.app, Model.hash_id],
preserve=fields_to_preserve,
update=data_fields)
.execute())
.insert(**data_to_insert)
.on_conflict(conflict_target=[Model.app, Model.hash_id],
preserve=fields_to_preserve,
update=data_fields)
.execute())
record = Model.get_by_id(inserted_id)
logger.debug(colored(f"{message} with id {record.id}", "yellow"))
@@ -196,7 +197,8 @@ def hash_and_save_step(Model, app_id, hash_data_args, data_fields, message):
def save_development_step(project, prompt_path, prompt_data, messages, llm_response):
hash_data_args = {
'prompt_path': prompt_path,
'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if k not in PROMPT_DATA_TO_IGNORE},
'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if
k not in PROMPT_DATA_TO_IGNORE},
'llm_req_num': project.llm_req_num
}
@@ -206,7 +208,8 @@ def save_development_step(project, prompt_path, prompt_data, messages, llm_respo
'previous_step': project.checkpoints['last_development_step'],
}
development_step = hash_and_save_step(DevelopmentSteps, project.args['app_id'], hash_data_args, data_fields, "Saved Development Step")
development_step = hash_and_save_step(DevelopmentSteps, project.args['app_id'], hash_data_args, data_fields,
"Saved Development Step")
project.checkpoints['last_development_step'] = development_step
return development_step
@@ -214,12 +217,15 @@ def save_development_step(project, prompt_path, prompt_data, messages, llm_respo
def get_development_step_from_hash_id(project, prompt_path, prompt_data, llm_req_num):
data_to_hash = {
'prompt_path': prompt_path,
'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if k not in PROMPT_DATA_TO_IGNORE},
'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if
k not in PROMPT_DATA_TO_IGNORE},
'llm_req_num': llm_req_num
}
development_step = get_db_model_from_hash_id(DevelopmentSteps, project.args['app_id'], project.checkpoints['last_development_step'])
development_step = get_db_model_from_hash_id(DevelopmentSteps, project.args['app_id'],
project.checkpoints['last_development_step'])
return development_step
def save_command_run(project, command, cli_response):
hash_data_args = {
'command': command,
@@ -230,7 +236,8 @@ def save_command_run(project, command, cli_response):
'cli_response': cli_response,
'previous_step': project.checkpoints['last_command_run'],
}
command_run = hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields, "Saved Command Run")
command_run = hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields,
"Saved Command Run")
project.checkpoints['last_command_run'] = command_run
return command_run
@@ -240,9 +247,11 @@ def get_command_run_from_hash_id(project, command):
'command': command,
'command_runs_count': project.command_runs_count
}
command_run = get_db_model_from_hash_id(CommandRuns, project.args['app_id'], project.checkpoints['last_command_run'])
command_run = get_db_model_from_hash_id(CommandRuns, project.args['app_id'],
project.checkpoints['last_command_run'])
return command_run
def save_user_input(project, query, user_input):
hash_data_args = {
'query': query,
@@ -257,6 +266,7 @@ def save_user_input(project, query, user_input):
project.checkpoints['last_user_input'] = user_input
return user_input
def get_user_input_from_hash_id(project, query):
data_to_hash = {
'query': query,
@@ -265,11 +275,13 @@ def get_user_input_from_hash_id(project, query):
user_input = get_db_model_from_hash_id(UserInputs, project.args['app_id'], project.checkpoints['last_user_input'])
return user_input
def delete_all_subsequent_steps(project):
delete_subsequent_steps(DevelopmentSteps, project.checkpoints['last_development_step'])
delete_subsequent_steps(CommandRuns, project.checkpoints['last_command_run'])
delete_subsequent_steps(UserInputs, project.checkpoints['last_user_input'])
def delete_subsequent_steps(model, step):
if step is None:
return
@@ -280,6 +292,7 @@ def delete_subsequent_steps(model, step):
delete_subsequent_steps(model, subsequent_step)
subsequent_step.delete_instance()
def get_all_connected_steps(step, previous_step_field_name):
"""Recursively get all steps connected to the given step."""
connected_steps = [step]
@@ -289,11 +302,13 @@ def get_all_connected_steps(step, previous_step_field_name):
prev_step = getattr(prev_step, previous_step_field_name)
return connected_steps
def delete_all_app_development_data(app):
models = [DevelopmentSteps, CommandRuns, UserInputs, File, FileSnapshot]
for model in models:
model.delete().where(model.app == app).execute()
def delete_unconnected_steps_from(step, previous_step_field_name):
if step is None:
return
@@ -309,13 +324,15 @@ def delete_unconnected_steps_from(step, previous_step_field_name):
print(colored(f"Deleting unconnected {step.__class__.__name__} step {unconnected_step.id}", "red"))
unconnected_step.delete_instance()
def save_file_description(project, path, name, description):
(File.insert(app=project.app, path=path, name=name, description=description)
.on_conflict(
conflict_target=[File.app, File.name, File.path],
preserve=[],
update={'description': description})
.execute())
.on_conflict(
conflict_target=[File.app, File.name, File.path],
preserve=[],
update={'description': description})
.execute())
def create_tables():
with database:
@@ -354,10 +371,45 @@ def drop_tables():
CommandRuns,
UserInputs,
File,
]:
]:
database.execute_sql(f'DROP TABLE IF EXISTS "{table._meta.table_name}" CASCADE')
def database_exists():
try:
database.connect()
database.close()
return True
except Exception:
return False
def create_database():
conn = psycopg2.connect(
dbname="postgres", # default database where we can execute the create database command
user=db.DB_USER,
password=db.DB_PASSWORD,
host=db.DB_HOST,
port=db.DB_PORT
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute(f"CREATE DATABASE {db.DB_NAME}")
cursor.close()
conn.close()
def tables_exist():
tables = [User, App, ProjectDescription, UserStories, UserTasks, Architecture, DevelopmentPlanning,
DevelopmentSteps, EnvironmentSetup, Development, FileSnapshot, CommandRuns, UserInputs, File]
for table in tables:
try:
database.get_tables().index(table._meta.table_name)
except ValueError:
return False
return True
if __name__ == "__main__":
drop_tables()
create_tables()

View File

@@ -8,9 +8,18 @@ from helpers.Project import Project
from utils.arguments import get_arguments
from logger.logger import logger
from database.database import database_exists, create_database, tables_exist, create_tables
def init():
# Check if the "euclid" database exists, if not, create it
if not database_exists():
create_database()
# Check if the tables exist, if not, create them
if not tables_exist():
create_tables()
arguments = get_arguments()
logger.info(f"Starting with args: {arguments}")