From 6e977590c391407d8bc2ff25abbb25c16485471b Mon Sep 17 00:00:00 2001 From: Zvonimir Sabljic Date: Tue, 12 Sep 2023 21:03:10 +0200 Subject: [PATCH] Replace file content in all messages each time we load a branch in AgentConvo - REFACTOR eventually so we don't deal with strings but with real data --- pilot/helpers/AgentConvo.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/pilot/helpers/AgentConvo.py b/pilot/helpers/AgentConvo.py index 533960a..53388a5 100644 --- a/pilot/helpers/AgentConvo.py +++ b/pilot/helpers/AgentConvo.py @@ -1,7 +1,10 @@ +import re import subprocess from termcolor import colored from database.database import get_development_step_from_hash_id, save_development_step, delete_all_subsequent_steps +from helpers.files import get_files_content +from const.common import IGNORE_FOLDERS from utils.utils import array_of_objects_to_string from utils.llm_connection import get_prompt, create_gpt_chat_completion from utils.utils import get_sys_message, find_role_from_step, capitalize_first_word_with_underscores @@ -132,8 +135,32 @@ class AgentConvo: def save_branch(self, branch_name): self.branches[branch_name] = self.messages.copy() - def load_branch(self, branch_name): + def load_branch(self, branch_name, reload_files=True): self.messages = self.branches[branch_name].copy() + if reload_files: + # TODO make this more flexible - with every message, save metadata so every time we load a branch, reconstruct all messages from scratch + self.replace_files() + + def replace_files(self): + files = self.agent.project.get_all_coded_files() + for msg in self.messages: + if msg['role'] == 'user': + for file in files: + self.replace_file_content(msg['content'], file['path'], file['content']) + + def replace_file_content(self, message, file_path, new_content): + escaped_file_path = re.escape(file_path) + + pattern = rf'\*\*{{ {escaped_file_path} }}\*\*\n```\n(.*?)\n```' + + new_section_content = f'**{{ {file_path} }}**\n```\n{new_content}\n```' + + updated_message, num_replacements = re.subn(pattern, new_section_content, message, flags=re.DOTALL) + + if num_replacements == 0: + return message + + return updated_message def convo_length(self): return len([msg for msg in self.messages if msg['role'] != 'system'])