Refactored saving previous steps and made steps be loaded by looking at the current step as well

This commit is contained in:
Zvonimir Sabljic
2023-08-10 08:41:46 +02:00
parent 6b7e77b46a
commit 01b50ab3e3
9 changed files with 54 additions and 21 deletions

View File

@@ -154,10 +154,10 @@ def get_progress_steps(app_id, step=None):
return steps
def get_db_model_from_hash_id(data_to_hash, model, app_id):
def get_db_model_from_hash_id(data_to_hash, model, app_id, previous_step):
hash_id = hash_data(data_to_hash)
try:
db_row = model.get((model.hash_id == hash_id) & (model.app == app_id))
db_row = model.get((model.hash_id == hash_id) & (model.app == app_id) & (model.previous_step == previous_step))
except DoesNotExist:
return None
return db_row
@@ -203,7 +203,7 @@ def save_development_step(project, prompt_path, prompt_data, messages, llm_respo
data_fields = {
'messages': messages,
'llm_response': llm_response,
'previous_dev_step': project.checkpoints['last_development_step'],
'previous_step': project.checkpoints['last_development_step'],
}
development_step = hash_and_save_step(DevelopmentSteps, project.args['app_id'], hash_data_args, data_fields, "Saved Development Step")
@@ -217,7 +217,7 @@ def get_development_step_from_hash_id(project, prompt_path, prompt_data, llm_req
'prompt_data': {} if prompt_data is None else {k: v for k, v in prompt_data.items() if k not in PROMPT_DATA_TO_IGNORE},
'llm_req_num': llm_req_num
}
development_step = get_db_model_from_hash_id(data_to_hash, DevelopmentSteps, project.args['app_id'])
development_step = get_db_model_from_hash_id(data_to_hash, DevelopmentSteps, project.args['app_id'], project.checkpoints['last_development_step'])
return development_step
def save_command_run(project, command, cli_response):
@@ -228,7 +228,7 @@ def save_command_run(project, command, cli_response):
data_fields = {
'command': command,
'cli_response': cli_response,
'previous_command_run': project.checkpoints['last_command_run'],
'previous_step': project.checkpoints['last_command_run'],
}
command_run = hash_and_save_step(CommandRuns, project.args['app_id'], hash_data_args, data_fields, "Saved Command Run")
project.checkpoints['last_command_run'] = command_run
@@ -240,7 +240,7 @@ def get_command_run_from_hash_id(project, command):
'command': command,
'command_runs_count': project.command_runs_count
}
command_run = get_db_model_from_hash_id(data_to_hash, CommandRuns, project.args['app_id'])
command_run = get_db_model_from_hash_id(data_to_hash, CommandRuns, project.args['app_id'], project.checkpoints['last_command_run'])
return command_run
def save_user_input(project, query, user_input):
@@ -251,7 +251,7 @@ def save_user_input(project, query, user_input):
data_fields = {
'query': query,
'user_input': user_input,
'previous_user_input': project.checkpoints['last_user_input'],
'previous_step': project.checkpoints['last_user_input'],
}
user_input = hash_and_save_step(UserInputs, project.args['app_id'], hash_data_args, data_fields, "Saved User Input")
project.checkpoints['last_user_input'] = user_input
@@ -262,21 +262,47 @@ def get_user_input_from_hash_id(project, query):
'query': query,
'user_inputs_count': project.user_inputs_count
}
user_input = get_db_model_from_hash_id(data_to_hash, UserInputs, project.args['app_id'])
user_input = get_db_model_from_hash_id(data_to_hash, UserInputs, project.args['app_id'], project.checkpoints['last_user_input'])
return user_input
def delete_all_subsequent_steps(project):
delete_subsequent_steps(DevelopmentSteps, project.checkpoints['last_development_step'], 'previous_dev_step')
delete_subsequent_steps(CommandRuns, project.checkpoints['last_command_run'], 'previous_command_run')
delete_subsequent_steps(UserInputs, project.checkpoints['last_user_input'], 'previous_user_input')
delete_subsequent_steps(DevelopmentSteps, project.checkpoints['last_development_step'], 'previous_step')
delete_subsequent_steps(CommandRuns, project.checkpoints['last_command_run'], 'previous_step')
delete_subsequent_steps(UserInputs, project.checkpoints['last_user_input'], 'previous_step')
def delete_subsequent_steps(model, step, step_field_name):
if step is None:
return
print(colored(f"Deleting subsequent {model.__name__} steps after {step.id}", "red"))
subsequent_step = model.get_or_none(**{step_field_name: step.id})
if subsequent_step:
delete_subsequent_steps(model, subsequent_step, step_field_name)
subsequent_step.delete_instance()
def get_all_connected_steps(step, previous_step_field_name):
"""Recursively get all steps connected to the given step."""
connected_steps = [step]
prev_step = getattr(step, previous_step_field_name)
while prev_step is not None:
connected_steps.append(prev_step)
prev_step = getattr(prev_step, previous_step_field_name)
return connected_steps
def delete_unconnected_steps_from(step, previous_step_field_name):
if step is None:
return
connected_steps = get_all_connected_steps(step, previous_step_field_name)
connected_step_ids = [s.id for s in connected_steps]
unconnected_steps = DevelopmentSteps.select().where(
(DevelopmentSteps.app == step.app) &
(DevelopmentSteps.id.not_in(connected_step_ids))
).order_by(DevelopmentSteps.id.desc())
for unconnected_step in unconnected_steps:
print(colored(f"Deleting unconnected {step.__class__.__name__} step {unconnected_step.id}", "red"))
unconnected_step.delete_instance()
def create_tables():
with database:
database.create_tables([