diff --git a/euclid/database/database.py b/euclid/database/database.py index d869e7c..ee4ca8e 100644 --- a/euclid/database/database.py +++ b/euclid/database/database.py @@ -137,7 +137,7 @@ def get_progress_steps(app_id, step=None): return steps -def save_development_step(app_id, prompt_path, prompt_data, llm_req_num, messages, response): +def save_development_step(app_id, prompt_path, prompt_data, llm_req_num, messages, response, previous_step=None): app = get_app(app_id) hash_id = hash_data({ 'prompt_path': prompt_path, @@ -146,11 +146,11 @@ def save_development_step(app_id, prompt_path, prompt_data, llm_req_num, message }) try: inserted_id = (DevelopmentSteps - .insert(app=app, hash_id=hash_id, messages=messages, llm_response=response) - .on_conflict(conflict_target=[DevelopmentSteps.app, DevelopmentSteps.hash_id], - preserve=[DevelopmentSteps.messages, DevelopmentSteps.llm_response], - update={}) - .execute()) + .insert(app=app, hash_id=hash_id, messages=messages, llm_response=response, previous_dev_step=previous_step) + .on_conflict(conflict_target=[DevelopmentSteps.app, DevelopmentSteps.hash_id], + preserve=[DevelopmentSteps.messages, DevelopmentSteps.llm_response], + update={}) + .execute()) dev_step = DevelopmentSteps.get_by_id(inserted_id) print(colored(f"Saved DEV step => {dev_step.id}", "yellow")) @@ -184,8 +184,8 @@ def hash_and_save_step(Model, app_id, hash_data_args, data_fields, message): inserted_id = (Model .insert(**data_to_insert) .on_conflict(conflict_target=[Model.app, Model.hash_id], - preserve=[field for field in data_fields.keys()], - update={}) + preserve=[], + update=data_fields) .execute()) record = Model.get_by_id(inserted_id) @@ -204,8 +204,11 @@ def save_command_run(project, command, cli_response): data_fields = { 'command': command, 'cli_response': cli_response, + 'previous_command_run': project.checkpoints['last_command_run'], } - return hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields, "Saved Command Run") + command_run = hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields, "Saved Command Run") + project.checkpoints['last_command_run'] = command_run + return command_run def get_command_run_from_hash_id(project, command): @@ -213,7 +216,8 @@ def get_command_run_from_hash_id(project, command): 'command': command, 'command_runs_count': project.command_runs_count } - return get_db_model_from_hash_id(data_to_hash, CommandRuns, project.args['app_id']) + command_run = get_db_model_from_hash_id(data_to_hash, CommandRuns, project.args['app_id']) + return command_run def save_user_input(project, query, user_input): hash_data_args = { @@ -223,15 +227,19 @@ def save_user_input(project, query, user_input): data_fields = { 'query': query, 'user_input': user_input, + 'previous_user_input': project.checkpoints['last_user_input'], } - return hash_and_save_step(UserInputs, project.args['app_id'], hash_data_args, data_fields, "Saved User Input") + user_input = hash_and_save_step(UserInputs, project.args['app_id'], hash_data_args, data_fields, "Saved User Input") + project.checkpoints['last_user_input'] = user_input + return user_input def get_user_input_from_hash_id(project, query): data_to_hash = { 'query': query, 'user_inputs_count': project.user_inputs_count } - return get_db_model_from_hash_id(data_to_hash, UserInputs, project.args['app_id']) + user_input = get_db_model_from_hash_id(data_to_hash, UserInputs, project.args['app_id']) + return user_input def get_development_step_from_hash_id(app_id, prompt_path, prompt_data, llm_req_num): diff --git a/euclid/database/models/command_runs.py b/euclid/database/models/command_runs.py index 673ff61..d7f77c9 100644 --- a/euclid/database/models/command_runs.py +++ b/euclid/database/models/command_runs.py @@ -10,6 +10,7 @@ class CommandRuns(BaseModel): hash_id = CharField(null=False) command = TextField(null=True) cli_response = TextField(null=True) + previous_command_run = ForeignKeyField('self', null=True, column_name='previous_command_run') class Meta: db_table = 'command_runs' diff --git a/euclid/database/models/development_steps.py b/euclid/database/models/development_steps.py index 3ba05b5..6df9433 100644 --- a/euclid/database/models/development_steps.py +++ b/euclid/database/models/development_steps.py @@ -12,6 +12,7 @@ class DevelopmentSteps(BaseModel): hash_id = CharField(null=False) messages = BinaryJSONField(null=True) llm_response = BinaryJSONField(null=False) + previous_dev_step = ForeignKeyField('self', null=True, column_name='previous_dev_step') class Meta: db_table = 'development_steps' diff --git a/euclid/database/models/user_inputs.py b/euclid/database/models/user_inputs.py index 6705bd4..7838872 100644 --- a/euclid/database/models/user_inputs.py +++ b/euclid/database/models/user_inputs.py @@ -10,6 +10,7 @@ class UserInputs(BaseModel): hash_id = CharField(null=False) query = TextField(null=True) user_input = TextField(null=True) + previous_user_input = ForeignKeyField('self', null=True, column_name='previous_user_input') class Meta: db_table = 'user_inputs' diff --git a/euclid/helpers/AgentConvo.py b/euclid/helpers/AgentConvo.py index e3e59e4..962f9ba 100644 --- a/euclid/helpers/AgentConvo.py +++ b/euclid/helpers/AgentConvo.py @@ -36,13 +36,15 @@ class AgentConvo: if self.agent.project.skip_until_dev_step and str(development_step.id) == self.agent.project.skip_until_dev_step: self.agent.project.skip_steps = False print(colored(f'Restoring development step with id {development_step.id}', 'yellow')) + self.agent.project.checkpoints['last_development_step'] = development_step self.agent.project.restore_files(development_step.id) response = development_step.llm_response self.messages = development_step.messages else: # if we don't, get the response from LLM response = create_gpt_chat_completion(self.messages, self.high_level_step, function_calls=function_calls) - development_step = save_development_step(self.agent.project.args['app_id'], prompt_path, prompt_data, self.agent.project.llm_req_num, self.messages, response) + development_step = save_development_step(self.agent.project.args['app_id'], prompt_path, prompt_data, self.agent.project.llm_req_num, self.messages, response, self.agent.project.checkpoints['last_development_step']) + self.agent.project.checkpoints['last_development_step'] = development_step self.agent.project.save_files_snapshot(development_step.id) # TODO handle errors from OpenAI diff --git a/euclid/helpers/Project.py b/euclid/helpers/Project.py index 8010bfa..366e65e 100644 --- a/euclid/helpers/Project.py +++ b/euclid/helpers/Project.py @@ -24,6 +24,11 @@ class Project: self.llm_req_num = 0 self.command_runs_count = 0 self.user_inputs_count = 0 + self.checkpoints = { + 'last_user_input': None, + 'last_command_run': None, + 'last_development_step': None, + } self.skip_steps = False if ('skip_until_dev_step' in args and args['skip_until_dev_step'] == '0') else True self.skip_until_dev_step = args['skip_until_dev_step'] if 'skip_until_dev_step' in args else None # TODO make flexible diff --git a/euclid/helpers/cli.py b/euclid/helpers/cli.py index dba9eb5..d5fb30e 100644 --- a/euclid/helpers/cli.py +++ b/euclid/helpers/cli.py @@ -48,6 +48,7 @@ def execute_command(project, command, timeout=5000): command_run = get_command_run_from_hash_id(project, command) if command_run is not None and project.skip_steps: # if we do, use it + project.checkpoints['last_command_run'] = command_run print(colored(f'Restoring command run response id {command_run.id}:\n```\n{command_run.cli_response}```', 'yellow')) return command_run.cli_response diff --git a/euclid/utils/questionary.py b/euclid/utils/questionary.py index d9fc463..efe845d 100644 --- a/euclid/utils/questionary.py +++ b/euclid/utils/questionary.py @@ -23,6 +23,7 @@ def styled_text(project, question): user_input = get_user_input_from_hash_id(project, question) if user_input is not None and project.skip_steps: # if we do, use it + project.checkpoints['last_user_input'] = user_input print(colored(f'Restoring user input id {user_input.id}: ', 'yellow'), end='') print(colored(f'{user_input.user_input}', 'yellow', attrs=['bold'])) return user_input.user_input