From 01b50ab3e3aa8db7ba5338b6614d42ff30b7e804 Mon Sep 17 00:00:00 2001 From: Zvonimir Sabljic Date: Thu, 10 Aug 2023 08:41:46 +0200 Subject: [PATCH] Refactored saving previous steps and made steps be loaded by looking at the current step as well --- euclid/database/database.py | 48 ++++++++++++++----- euclid/database/models/command_runs.py | 4 +- .../models/components/progress_step.py | 2 +- euclid/database/models/development_steps.py | 4 +- euclid/database/models/file_snapshot.py | 2 +- euclid/database/models/files.py | 2 +- euclid/database/models/user_inputs.py | 4 +- euclid/helpers/AgentConvo.py | 2 + euclid/helpers/Project.py | 7 ++- 9 files changed, 54 insertions(+), 21 deletions(-) diff --git a/euclid/database/database.py b/euclid/database/database.py index 4488d3a..6d00d0b 100644 --- a/euclid/database/database.py +++ b/euclid/database/database.py @@ -154,10 +154,10 @@ def get_progress_steps(app_id, step=None): return steps -def get_db_model_from_hash_id(data_to_hash, model, app_id): +def get_db_model_from_hash_id(data_to_hash, model, app_id, previous_step): hash_id = hash_data(data_to_hash) try: - db_row = model.get((model.hash_id == hash_id) & (model.app == app_id)) + db_row = model.get((model.hash_id == hash_id) & (model.app == app_id) & (model.previous_step == previous_step)) except DoesNotExist: return None return db_row @@ -203,7 +203,7 @@ def save_development_step(project, prompt_path, prompt_data, messages, llm_respo data_fields = { 'messages': messages, 'llm_response': llm_response, - 'previous_dev_step': project.checkpoints['last_development_step'], + 'previous_step': project.checkpoints['last_development_step'], } development_step = hash_and_save_step(DevelopmentSteps, project.args['app_id'], hash_data_args, data_fields, "Saved Development Step") @@ -217,7 +217,7 @@ def get_development_step_from_hash_id(project, prompt_path, prompt_data, llm_req 'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if k not in PROMPT_DATA_TO_IGNORE}, 'llm_req_num': llm_req_num } - development_step = get_db_model_from_hash_id(data_to_hash, DevelopmentSteps, project.args['app_id']) + development_step = get_db_model_from_hash_id(data_to_hash, DevelopmentSteps, project.args['app_id'], project.checkpoints['last_development_step']) return development_step def save_command_run(project, command, cli_response): @@ -228,7 +228,7 @@ def save_command_run(project, command, cli_response): data_fields = { 'command': command, 'cli_response': cli_response, - 'previous_command_run': project.checkpoints['last_command_run'], + 'previous_step': project.checkpoints['last_command_run'], } command_run = hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields, "Saved Command Run") project.checkpoints['last_command_run'] = command_run @@ -240,7 +240,7 @@ def get_command_run_from_hash_id(project, command): 'command': command, 'command_runs_count': project.command_runs_count } - command_run = get_db_model_from_hash_id(data_to_hash, CommandRuns, project.args['app_id']) + command_run = get_db_model_from_hash_id(data_to_hash, CommandRuns, project.args['app_id'], project.checkpoints['last_command_run']) return command_run def save_user_input(project, query, user_input): @@ -251,7 +251,7 @@ def save_user_input(project, query, user_input): data_fields = { 'query': query, 'user_input': user_input, - 'previous_user_input': project.checkpoints['last_user_input'], + 'previous_step': project.checkpoints['last_user_input'], } user_input = hash_and_save_step(UserInputs, project.args['app_id'], hash_data_args, data_fields, "Saved User Input") project.checkpoints['last_user_input'] = user_input @@ -262,21 +262,47 @@ def get_user_input_from_hash_id(project, query): 'query': query, 'user_inputs_count': project.user_inputs_count } - user_input = get_db_model_from_hash_id(data_to_hash, UserInputs, project.args['app_id']) + user_input = get_db_model_from_hash_id(data_to_hash, UserInputs, project.args['app_id'], project.checkpoints['last_user_input']) return user_input def delete_all_subsequent_steps(project): - delete_subsequent_steps(DevelopmentSteps, project.checkpoints['last_development_step'], 'previous_dev_step') - delete_subsequent_steps(CommandRuns, project.checkpoints['last_command_run'], 'previous_command_run') - delete_subsequent_steps(UserInputs, project.checkpoints['last_user_input'], 'previous_user_input') + delete_subsequent_steps(DevelopmentSteps, project.checkpoints['last_development_step'], 'previous_step') + delete_subsequent_steps(CommandRuns, project.checkpoints['last_command_run'], 'previous_step') + delete_subsequent_steps(UserInputs, project.checkpoints['last_user_input'], 'previous_step') def delete_subsequent_steps(model, step, step_field_name): + if step is None: + return print(colored(f"Deleting subsequent {model.__name__} steps after {step.id}", "red")) subsequent_step = model.get_or_none(**{step_field_name: step.id}) if subsequent_step: delete_subsequent_steps(model, subsequent_step, step_field_name) subsequent_step.delete_instance() +def get_all_connected_steps(step, previous_step_field_name): + """Recursively get all steps connected to the given step.""" + connected_steps = [step] + prev_step = getattr(step, previous_step_field_name) + while prev_step is not None: + connected_steps.append(prev_step) + prev_step = getattr(prev_step, previous_step_field_name) + return connected_steps + +def delete_unconnected_steps_from(step, previous_step_field_name): + if step is None: + return + connected_steps = get_all_connected_steps(step, previous_step_field_name) + connected_step_ids = [s.id for s in connected_steps] + + unconnected_steps = DevelopmentSteps.select().where( + (DevelopmentSteps.app == step.app) & + (DevelopmentSteps.id.not_in(connected_step_ids)) + ).order_by(DevelopmentSteps.id.desc()) + + for unconnected_step in unconnected_steps: + print(colored(f"Deleting unconnected {step.__class__.__name__} step {unconnected_step.id}", "red")) + unconnected_step.delete_instance() + def create_tables(): with database: database.create_tables([ diff --git a/euclid/database/models/command_runs.py b/euclid/database/models/command_runs.py index d7f77c9..b6c34c4 100644 --- a/euclid/database/models/command_runs.py +++ b/euclid/database/models/command_runs.py @@ -6,11 +6,11 @@ from database.models.app import App class CommandRuns(BaseModel): id = AutoField() - app = ForeignKeyField(App) + app = ForeignKeyField(App, on_delete='CASCADE') hash_id = CharField(null=False) command = TextField(null=True) cli_response = TextField(null=True) - previous_command_run = ForeignKeyField('self', null=True, column_name='previous_command_run') + previous_step = ForeignKeyField('self', null=True, column_name='previous_step') class Meta: db_table = 'command_runs' diff --git a/euclid/database/models/components/progress_step.py b/euclid/database/models/components/progress_step.py index 92cb87b..281659e 100644 --- a/euclid/database/models/components/progress_step.py +++ b/euclid/database/models/components/progress_step.py @@ -7,7 +7,7 @@ from database.models.app import App class ProgressStep(BaseModel): - app = ForeignKeyField(App, primary_key=True) + app = ForeignKeyField(App, primary_key=True, on_delete='CASCADE') step = CharField() data = BinaryJSONField(null=True) messages = BinaryJSONField(null=True) diff --git a/euclid/database/models/development_steps.py b/euclid/database/models/development_steps.py index 6df9433..b0414bd 100644 --- a/euclid/database/models/development_steps.py +++ b/euclid/database/models/development_steps.py @@ -8,11 +8,11 @@ from database.models.app import App class DevelopmentSteps(BaseModel): id = AutoField() # This will serve as the primary key - app = ForeignKeyField(App) + app = ForeignKeyField(App, on_delete='CASCADE') hash_id = CharField(null=False) messages = BinaryJSONField(null=True) llm_response = BinaryJSONField(null=False) - previous_dev_step = ForeignKeyField('self', null=True, column_name='previous_dev_step') + previous_step = ForeignKeyField('self', null=True, column_name='previous_step') class Meta: db_table = 'development_steps' diff --git a/euclid/database/models/file_snapshot.py b/euclid/database/models/file_snapshot.py index 35d5c3a..483f2ee 100644 --- a/euclid/database/models/file_snapshot.py +++ b/euclid/database/models/file_snapshot.py @@ -4,7 +4,7 @@ from database.models.components.base_models import BaseModel from database.models.development_steps import DevelopmentSteps class FileSnapshot(BaseModel): - development_step = ForeignKeyField(DevelopmentSteps, backref='files') + development_step = ForeignKeyField(DevelopmentSteps, backref='files', on_delete='CASCADE') name = CharField() content = TextField() diff --git a/euclid/database/models/files.py b/euclid/database/models/files.py index fc56d2d..9719223 100644 --- a/euclid/database/models/files.py +++ b/euclid/database/models/files.py @@ -5,7 +5,7 @@ from database.models.development_steps import DevelopmentSteps from database.models.app import App class File(BaseModel): - app = ForeignKeyField(App) + app = ForeignKeyField(App, on_delete='CASCADE') name = CharField() path = CharField() description = TextField() diff --git a/euclid/database/models/user_inputs.py b/euclid/database/models/user_inputs.py index 7838872..7d2451c 100644 --- a/euclid/database/models/user_inputs.py +++ b/euclid/database/models/user_inputs.py @@ -6,11 +6,11 @@ from database.models.app import App class UserInputs(BaseModel): id = AutoField() - app = ForeignKeyField(App) + app = ForeignKeyField(App, on_delete='CASCADE') hash_id = CharField(null=False) query = TextField(null=True) user_input = TextField(null=True) - previous_user_input = ForeignKeyField('self', null=True, column_name='previous_user_input') + previous_step = ForeignKeyField('self', null=True, column_name='previous_step') class Meta: db_table = 'user_inputs' diff --git a/euclid/helpers/AgentConvo.py b/euclid/helpers/AgentConvo.py index bc1abc6..6ee00a1 100644 --- a/euclid/helpers/AgentConvo.py +++ b/euclid/helpers/AgentConvo.py @@ -42,6 +42,8 @@ class AgentConvo: if self.agent.project.skip_until_dev_step and str(development_step.id) == self.agent.project.skip_until_dev_step: self.agent.project.skip_steps = False delete_all_subsequent_steps(self.agent.project) + if 'delete_unrelated_steps' in self.agent.project.args and self.agent.project.args['delete_unrelated_steps']: + self.agent.project.delete_all_steps_except_current_branch() else: # if we don't, get the response from LLM response = create_gpt_chat_completion(self.messages, self.high_level_step, function_calls=function_calls) diff --git a/euclid/helpers/Project.py b/euclid/helpers/Project.py index bac15e7..61157de 100644 --- a/euclid/helpers/Project.py +++ b/euclid/helpers/Project.py @@ -3,7 +3,7 @@ import os from termcolor import colored from const.common import IGNORE_FOLDERS from database.models.app import App -from database.database import get_app +from database.database import get_app, delete_unconnected_steps_from from utils.questionary import styled_text from helpers.files import get_files_content, clear_directory from helpers.cli import build_directory_tree @@ -122,6 +122,11 @@ class Project: with open(full_path, 'w', encoding='utf-8') as f: f.write(file_snapshot.content) + def delete_all_steps_except_current_branch(self): + delete_unconnected_steps_from(self.checkpoints['last_development_step'], 'previous_step') + delete_unconnected_steps_from(self.checkpoints['last_command_run'], 'previous_step') + delete_unconnected_steps_from(self.checkpoints['last_user_input'], 'previous_step') + def ask_for_human_intervention(self, message, description): print(colored(message, "yellow")) print(description)