diff --git a/pilot/helpers/AgentConvo.py b/pilot/helpers/AgentConvo.py index 1b2900d..25e309b 100644 --- a/pilot/helpers/AgentConvo.py +++ b/pilot/helpers/AgentConvo.py @@ -1,4 +1,3 @@ -import json import re import subprocess import uuid @@ -6,9 +5,9 @@ from utils.style import yellow, yellow_bold from database.database import get_saved_development_step, save_development_step, delete_all_subsequent_steps from helpers.exceptions.TokenLimitError import TokenLimitError -from utils.utils import array_of_objects_to_string, get_prompt +from utils.function_calling import parse_agent_response from utils.llm_connection import create_gpt_chat_completion -from utils.utils import get_sys_message, find_role_from_step, capitalize_first_word_with_underscores +from utils.utils import array_of_objects_to_string, get_prompt, get_sys_message, capitalize_first_word_with_underscores from logger.logger import logger from prompts.prompts import ask_user from const.llm import END_RESPONSE @@ -83,7 +82,7 @@ class AgentConvo: if response == {}: raise Exception("OpenAI API error happened.") - response = self.postprocess_response(response, function_calls) + response = parse_agent_response(response, function_calls) # TODO remove this once the database is set up properly message_content = response[0] if type(response) == tuple else response @@ -174,32 +173,6 @@ class AgentConvo: def convo_length(self): return len([msg for msg in self.messages if msg['role'] != 'system']) - def postprocess_response(self, response, function_calls): - """ - Post-processes the response from the agent. - - Args: - response: The response from the agent. - function_calls: Optional function calls associated with the response. - - Returns: - The post-processed response. - """ - if 'function_calls' in response and function_calls is not None: - if 'send_convo' in function_calls: - response['function_calls']['arguments']['convo'] = self - response = function_calls['functions'][response['function_calls']['name']](**response['function_calls']['arguments']) - elif 'text' in response: - if function_calls: - values = list(json.loads(response['text']).values()) - if len(values) == 1: - return values[0] - else: - return tuple(values) - else: - response = response['text'] - - return response def log_message(self, content): """ diff --git a/pilot/utils/function_calling.py b/pilot/utils/function_calling.py index dd36bc9..0ec3360 100644 --- a/pilot/utils/function_calling.py +++ b/pilot/utils/function_calling.py @@ -1,4 +1,5 @@ import json +import re # from local_llm_function_calling import Generator # from local_llm_function_calling.model.llama import LlamaModel # from local_llm_function_calling.model.huggingface import HuggingfaceModel @@ -40,6 +41,29 @@ def add_function_calls_to_request(gpt_data, function_calls: FunctionCallSet | No }) +def parse_agent_response(response, function_calls: FunctionCallSet | None): + """ + Post-processes the response from the agent. + + Args: + response: The response from the agent. + function_calls: Optional function calls associated with the response. + + Returns: + The post-processed response. + """ + + if function_calls: + text = re.sub(r'^```json\n', '', response['text']) + values = list(json.loads(text.strip('` \n')).values()) + if len(values) == 1: + return values[0] + else: + return tuple(values) + + return response['text'] + + class LlamaInstructPrompter: """ A prompter for Llama2 instruct models. diff --git a/pilot/utils/test_function_calling.py b/pilot/utils/test_function_calling.py index 978e68a..635e1c6 100644 --- a/pilot/utils/test_function_calling.py +++ b/pilot/utils/test_function_calling.py @@ -1,7 +1,64 @@ from local_llm_function_calling.prompter import CompletionModelPrompter, InstructModelPrompter from const.function_calls import ARCHITECTURE, DEV_STEPS -from .function_calling import JsonPrompter +from .function_calling import parse_agent_response, LlamaInstructPrompter + + +class TestFunctionCalling: + def test_parse_agent_response_text(self): + # Given + response = {'text': 'Hello world!'} + + # When + response = parse_agent_response(response, None) + + # Then + assert response == 'Hello world!' + + def test_parse_agent_response_json(self): + # Given + response = {'text': '{"greeting": "Hello world!"}'} + function_calls = {'definitions': [], 'functions': {}} + + # When + response = parse_agent_response(response, function_calls) + + # Then + assert response == 'Hello world!' + + def test_parse_agent_response_json_markdown(self): + # Given + response = {'text': '```json\n{"greeting": "Hello world!"}\n```'} + function_calls = {'definitions': [], 'functions': {}} + + # When + response = parse_agent_response(response, function_calls) + + # Then + assert response == 'Hello world!' + + def test_parse_agent_response_markdown(self): + # Given + response = {'text': '```\n{"greeting": "Hello world!"}\n```'} + function_calls = {'definitions': [], 'functions': {}} + + # When + response = parse_agent_response(response, function_calls) + + # Then + assert response == 'Hello world!' + + def test_parse_agent_response_multiple_args(self): + # Given + response = {'text': '{"greeting": "Hello", "name": "John"}'} + function_calls = {'definitions': [], 'functions': {}} + + # When + greeting, name = parse_agent_response(response, function_calls) + + # Then + assert greeting == 'Hello' + assert name == 'John' def test_completion_function_prompt(): @@ -62,23 +119,23 @@ Create a web-based chat app Function call: ''' -def test_json_prompter(): - # Given - prompter = JsonPrompter() - - # When - prompt = prompter.prompt('Create a web-based chat app', ARCHITECTURE['definitions']) # , 'process_technologies') - - # Then - assert prompt == '''[INST] <> -Help choose the appropriate function to call to answer the user's question. -In your response you must only use JSON output and provide no notes or commentary. - -Available functions: -- process_technologies - Print the list of technologies that are created. -<> - -Create a web-based chat app [/INST]''' +# def test_json_prompter(): +# # Given +# prompter = JsonPrompter() +# +# # When +# prompt = prompter.prompt('Create a web-based chat app', ARCHITECTURE['definitions']) # , 'process_technologies') +# +# # Then +# assert prompt == '''[INST] <> +# Help choose the appropriate function to call to answer the user's question. +# In your response you must only use JSON output and provide no notes or commentary. +# +# Available functions: +# - process_technologies - Print the list of technologies that are created. +# <> +# +# Create a web-based chat app [/INST]''' def test_llama_instruct_function_prompter_named():