diff --git a/pilot/database/database.py b/pilot/database/database.py index 71bcf3a..9ec15a7 100644 --- a/pilot/database/database.py +++ b/pilot/database/database.py @@ -3,6 +3,7 @@ from peewee import * from termcolor import colored from functools import reduce import operator +import psycopg2 from const.common import PROMPT_DATA_TO_IGNORE from logger.logger import logger @@ -22,6 +23,7 @@ from database.models.file_snapshot import FileSnapshot from database.models.command_runs import CommandRuns from database.models.user_inputs import UserInputs from database.models.files import File +from const import db def save_user(user_id, email, password): @@ -36,7 +38,6 @@ def save_user(user_id, email, password): return User.create(id=user_id, email=email, password=password) - def get_user(user_id=None, email=None): if not user_id and not email: raise ValueError("Either user_id or email must be provided") @@ -179,11 +180,11 @@ def hash_and_save_step(Model, app_id, hash_data_args, data_fields, message): try: inserted_id = (Model - .insert(**data_to_insert) - .on_conflict(conflict_target=[Model.app, Model.hash_id], - preserve=fields_to_preserve, - update=data_fields) - .execute()) + .insert(**data_to_insert) + .on_conflict(conflict_target=[Model.app, Model.hash_id], + preserve=fields_to_preserve, + update=data_fields) + .execute()) record = Model.get_by_id(inserted_id) logger.debug(colored(f"{message} with id {record.id}", "yellow")) @@ -196,7 +197,8 @@ def hash_and_save_step(Model, app_id, hash_data_args, data_fields, message): def save_development_step(project, prompt_path, prompt_data, messages, llm_response): hash_data_args = { 'prompt_path': prompt_path, - 'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if k not in PROMPT_DATA_TO_IGNORE}, + 'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if + k not in PROMPT_DATA_TO_IGNORE}, 'llm_req_num': project.llm_req_num } @@ -206,7 +208,8 @@ def save_development_step(project, prompt_path, prompt_data, messages, llm_respo 'previous_step': project.checkpoints['last_development_step'], } - development_step = hash_and_save_step(DevelopmentSteps, project.args['app_id'], hash_data_args, data_fields, "Saved Development Step") + development_step = hash_and_save_step(DevelopmentSteps, project.args['app_id'], hash_data_args, data_fields, + "Saved Development Step") project.checkpoints['last_development_step'] = development_step return development_step @@ -214,12 +217,15 @@ def save_development_step(project, prompt_path, prompt_data, messages, llm_respo def get_development_step_from_hash_id(project, prompt_path, prompt_data, llm_req_num): data_to_hash = { 'prompt_path': prompt_path, - 'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if k not in PROMPT_DATA_TO_IGNORE}, + 'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if + k not in PROMPT_DATA_TO_IGNORE}, 'llm_req_num': llm_req_num } - development_step = get_db_model_from_hash_id(DevelopmentSteps, project.args['app_id'], project.checkpoints['last_development_step']) + development_step = get_db_model_from_hash_id(DevelopmentSteps, project.args['app_id'], + project.checkpoints['last_development_step']) return development_step + def save_command_run(project, command, cli_response): hash_data_args = { 'command': command, @@ -230,7 +236,8 @@ def save_command_run(project, command, cli_response): 'cli_response': cli_response, 'previous_step': project.checkpoints['last_command_run'], } - command_run = hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields, "Saved Command Run") + command_run = hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields, + "Saved Command Run") project.checkpoints['last_command_run'] = command_run return command_run @@ -240,9 +247,11 @@ def get_command_run_from_hash_id(project, command): 'command': command, 'command_runs_count': project.command_runs_count } - command_run = get_db_model_from_hash_id(CommandRuns, project.args['app_id'], project.checkpoints['last_command_run']) + command_run = get_db_model_from_hash_id(CommandRuns, project.args['app_id'], + project.checkpoints['last_command_run']) return command_run + def save_user_input(project, query, user_input): hash_data_args = { 'query': query, @@ -257,6 +266,7 @@ def save_user_input(project, query, user_input): project.checkpoints['last_user_input'] = user_input return user_input + def get_user_input_from_hash_id(project, query): data_to_hash = { 'query': query, @@ -265,11 +275,13 @@ def get_user_input_from_hash_id(project, query): user_input = get_db_model_from_hash_id(UserInputs, project.args['app_id'], project.checkpoints['last_user_input']) return user_input + def delete_all_subsequent_steps(project): delete_subsequent_steps(DevelopmentSteps, project.checkpoints['last_development_step']) delete_subsequent_steps(CommandRuns, project.checkpoints['last_command_run']) delete_subsequent_steps(UserInputs, project.checkpoints['last_user_input']) + def delete_subsequent_steps(model, step): if step is None: return @@ -280,6 +292,7 @@ def delete_subsequent_steps(model, step): delete_subsequent_steps(model, subsequent_step) subsequent_step.delete_instance() + def get_all_connected_steps(step, previous_step_field_name): """Recursively get all steps connected to the given step.""" connected_steps = [step] @@ -289,11 +302,13 @@ def get_all_connected_steps(step, previous_step_field_name): prev_step = getattr(prev_step, previous_step_field_name) return connected_steps + def delete_all_app_development_data(app): models = [DevelopmentSteps, CommandRuns, UserInputs, File, FileSnapshot] for model in models: model.delete().where(model.app == app).execute() + def delete_unconnected_steps_from(step, previous_step_field_name): if step is None: return @@ -309,13 +324,15 @@ def delete_unconnected_steps_from(step, previous_step_field_name): print(colored(f"Deleting unconnected {step.__class__.__name__} step {unconnected_step.id}", "red")) unconnected_step.delete_instance() + def save_file_description(project, path, name, description): (File.insert(app=project.app, path=path, name=name, description=description) - .on_conflict( - conflict_target=[File.app, File.name, File.path], - preserve=[], - update={'description': description}) - .execute()) + .on_conflict( + conflict_target=[File.app, File.name, File.path], + preserve=[], + update={'description': description}) + .execute()) + def create_tables(): with database: @@ -354,10 +371,45 @@ def drop_tables(): CommandRuns, UserInputs, File, - ]: + ]: database.execute_sql(f'DROP TABLE IF EXISTS "{table._meta.table_name}" CASCADE') +def database_exists(): + try: + database.connect() + database.close() + return True + except Exception: + return False + + +def create_database(): + conn = psycopg2.connect( + dbname="postgres", # default database where we can execute the create database command + user=db.DB_USER, + password=db.DB_PASSWORD, + host=db.DB_HOST, + port=db.DB_PORT + ) + conn.autocommit = True + cursor = conn.cursor() + cursor.execute(f"CREATE DATABASE {db.DB_NAME}") + cursor.close() + conn.close() + + +def tables_exist(): + tables = [User, App, ProjectDescription, UserStories, UserTasks, Architecture, DevelopmentPlanning, + DevelopmentSteps, EnvironmentSetup, Development, FileSnapshot, CommandRuns, UserInputs, File] + for table in tables: + try: + database.get_tables().index(table._meta.table_name) + except ValueError: + return False + return True + + if __name__ == "__main__": drop_tables() create_tables() diff --git a/pilot/main.py b/pilot/main.py index f51edb9..b483ce7 100644 --- a/pilot/main.py +++ b/pilot/main.py @@ -8,9 +8,18 @@ from helpers.Project import Project from utils.arguments import get_arguments from logger.logger import logger +from database.database import database_exists, create_database, tables_exist, create_tables def init(): + # Check if the "euclid" database exists, if not, create it + if not database_exists(): + create_database() + + # Check if the tables exist, if not, create them + if not tables_exist(): + create_tables() + arguments = get_arguments() logger.info(f"Starting with args: {arguments}")