mirror of
https://github.com/OMGeeky/gpt-pilot.git
synced 2026-02-23 15:49:50 +01:00
Improved JSON prompting for GPT-4 and recover incomplete JSON responses from Code Llama
This commit is contained in:
@@ -70,8 +70,7 @@ def parse_agent_response(response, function_calls: FunctionCallSet | None):
|
||||
"""
|
||||
|
||||
if function_calls:
|
||||
text = re.sub(r'^.*```json\s*', '', response['text'], flags=re.DOTALL)
|
||||
text = text.strip('` \n')
|
||||
text = response['text']
|
||||
values = list(json.loads(text).values())
|
||||
if len(values) == 1:
|
||||
return values[0]
|
||||
@@ -140,7 +139,7 @@ class JsonPrompter:
|
||||
return "\n".join(
|
||||
self.function_descriptions(functions, function_to_call)
|
||||
+ [
|
||||
"The response MUST be a JSON object matching this schema:",
|
||||
"Here is the schema for the expected JSON object:",
|
||||
"```json",
|
||||
self.function_parameters(functions, function_to_call),
|
||||
"```",
|
||||
@@ -194,7 +193,7 @@ class JsonPrompter:
|
||||
system = (
|
||||
"Help choose the appropriate function to call to answer the user's question."
|
||||
if function_to_call is None
|
||||
else f"Define the arguments for {function_to_call} to answer the user's question."
|
||||
else f"Please provide a JSON object that defines the arguments for the `{function_to_call}` function to answer the user's question."
|
||||
) + "\nThe response must contain ONLY the JSON object, with NO additional text or explanation."
|
||||
|
||||
data = (
|
||||
@@ -202,11 +201,6 @@ class JsonPrompter:
|
||||
if function_to_call
|
||||
else self.functions_summary(functions)
|
||||
)
|
||||
response_start = (
|
||||
f"Here are the arguments for the `{function_to_call}` function: ```json\n"
|
||||
if function_to_call
|
||||
else "Here's the function the user should call: "
|
||||
)
|
||||
|
||||
if self.is_instruct:
|
||||
return f"[INST] <<SYS>>\n{system}\n\n{data}\n<</SYS>>\n\n{prompt} [/INST]"
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import List
|
||||
from const.llm import MIN_TOKENS_FOR_GPT_RESPONSE, MAX_GPT_MODEL_TOKENS
|
||||
from logger.logger import logger
|
||||
from helpers.exceptions.TokenLimitError import TokenLimitError
|
||||
from utils.utils import fix_json
|
||||
from utils.utils import fix_json, get_prompt
|
||||
from utils.function_calling import add_function_calls_to_request, FunctionCallSet, FunctionType
|
||||
|
||||
|
||||
@@ -148,6 +148,11 @@ def retry_on_exception(func):
|
||||
err_str = str(e)
|
||||
|
||||
# If the specific error "context_length_exceeded" is present, raise TokenLimitError instead of retrying
|
||||
if isinstance(e, json.JSONDecodeError):
|
||||
# codellama-34b-instruct seems to send incomplete JSON responses
|
||||
if e.msg == 'Expecting value':
|
||||
args[0]['function_buffer'] = e.doc
|
||||
continue
|
||||
if "context_length_exceeded" in err_str:
|
||||
raise TokenLimitError(get_tokens_in_messages_from_openai_error(err_str), MAX_GPT_MODEL_TOKENS)
|
||||
if "rate_limit_exceeded" in err_str:
|
||||
@@ -187,14 +192,20 @@ def stream_gpt_completion(data, req_type):
|
||||
# TODO add type dynamically - this isn't working when connected to the external process
|
||||
terminal_width = 50 # os.get_terminal_size().columns
|
||||
lines_printed = 2
|
||||
gpt_response = ''
|
||||
buffer = '' # A buffer to accumulate incoming data
|
||||
expecting_json = False
|
||||
expecting_json = None
|
||||
received_json = False
|
||||
|
||||
if 'functions' in data:
|
||||
expecting_json = data['functions']
|
||||
if 'function_buffer' in data:
|
||||
incomplete_json = get_prompt('utils/incomplete_json.prompt', {'received_json': data['function_buffer']})
|
||||
data['messages'].append({'role': 'user', 'content': incomplete_json})
|
||||
gpt_response = data['function_buffer']
|
||||
received_json = True
|
||||
# Don't send the `functions` parameter to Open AI, but don't remove it from `data` in case we need to retry
|
||||
data = {key: value for key, value in data.items() if key != "functions"}
|
||||
data = {key: value for key, value in data.items() if not key.startswith('function')}
|
||||
|
||||
def return_result(result_data, lines_printed):
|
||||
if buffer:
|
||||
@@ -251,7 +262,6 @@ def stream_gpt_completion(data, req_type):
|
||||
logger.debug(f'problem with request: {response.text}')
|
||||
raise Exception(f"API responded with status code: {response.status_code}. Response text: {response.text}")
|
||||
|
||||
gpt_response = ''
|
||||
# function_calls = {'name': '', 'arguments': ''}
|
||||
|
||||
for line in response.iter_lines():
|
||||
@@ -283,11 +293,9 @@ def stream_gpt_completion(data, req_type):
|
||||
# return return_result({'function_calls': function_calls}, lines_printed)
|
||||
|
||||
json_line = choice['delta']
|
||||
# TODO: token healing? https://github.com/1rgs/jsonformer-claude
|
||||
# ...Is this what local_llm_function_calling.constrainer is for?
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f'Unable to decode line: {line}')
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f'Unable to decode line: {line} {e.msg}')
|
||||
continue # skip to the next line
|
||||
|
||||
# handle the streaming response
|
||||
@@ -306,16 +314,9 @@ def stream_gpt_completion(data, req_type):
|
||||
buffer += content # accumulate the data
|
||||
|
||||
# If you detect a natural breakpoint (e.g., line break or end of a response object), print & count:
|
||||
if buffer.endswith("\n"):
|
||||
if buffer.endswith('\n'):
|
||||
if expecting_json and not received_json:
|
||||
received_json = assert_json_response(buffer, lines_printed > 2)
|
||||
if received_json:
|
||||
gpt_response = ""
|
||||
# if not received_json:
|
||||
# # Don't append to gpt_response, but increment lines_printed
|
||||
# lines_printed += 1
|
||||
# buffer = ""
|
||||
# continue
|
||||
|
||||
# or some other condition that denotes a breakpoint
|
||||
lines_printed += count_lines_based_on_width(buffer, terminal_width)
|
||||
@@ -333,6 +334,7 @@ def stream_gpt_completion(data, req_type):
|
||||
logger.info(f'Response message: {gpt_response}')
|
||||
|
||||
if expecting_json:
|
||||
gpt_response = clean_json_response(gpt_response)
|
||||
assert_json_schema(gpt_response, expecting_json)
|
||||
|
||||
new_code = postprocessing(gpt_response, req_type) # TODO add type dynamically
|
||||
@@ -348,12 +350,17 @@ def assert_json_response(response: str, or_fail=True) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def clean_json_response(response: str) -> str:
|
||||
response = re.sub(r'^.*```json\s*', '', response, flags=re.DOTALL)
|
||||
return response.strip('` \n')
|
||||
|
||||
|
||||
def assert_json_schema(response: str, functions: list[FunctionType]) -> True:
|
||||
for function in functions:
|
||||
schema = function['parameters']
|
||||
parsed = json.loads(response)
|
||||
validate(parsed, schema)
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def postprocessing(gpt_response, req_type):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from .files import setup_workspace
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from const.function_calls import ARCHITECTURE, DEV_STEPS
|
||||
from const.function_calls import ARCHITECTURE
|
||||
from utils.llm_connection import clean_json_response
|
||||
from .function_calling import parse_agent_response, JsonPrompter
|
||||
|
||||
|
||||
@@ -30,6 +31,7 @@ class TestFunctionCalling:
|
||||
function_calls = {'definitions': [], 'functions': {}}
|
||||
|
||||
# When
|
||||
response['text'] = clean_json_response(response['text'])
|
||||
response = parse_agent_response(response, function_calls)
|
||||
|
||||
# Then
|
||||
@@ -41,6 +43,7 @@ class TestFunctionCalling:
|
||||
function_calls = {'definitions': [], 'functions': {}}
|
||||
|
||||
# When
|
||||
response['text'] = clean_json_response(response['text'])
|
||||
response = parse_agent_response(response, function_calls)
|
||||
|
||||
# Then
|
||||
@@ -68,7 +71,7 @@ def test_json_prompter():
|
||||
|
||||
# Then
|
||||
assert prompt == '''Help choose the appropriate function to call to answer the user's question.
|
||||
The response should contain only the JSON object, with no additional text or explanation.
|
||||
The response must contain ONLY the JSON object, with NO additional text or explanation.
|
||||
|
||||
Available functions:
|
||||
- process_technologies - Print the list of technologies that are created.
|
||||
@@ -86,7 +89,7 @@ def test_llama_json_prompter():
|
||||
# Then
|
||||
assert prompt == '''[INST] <<SYS>>
|
||||
Help choose the appropriate function to call to answer the user's question.
|
||||
The response should contain only the JSON object, with no additional text or explanation.
|
||||
The response must contain ONLY the JSON object, with NO additional text or explanation.
|
||||
|
||||
Available functions:
|
||||
- process_technologies - Print the list of technologies that are created.
|
||||
@@ -103,11 +106,11 @@ def test_json_prompter_named():
|
||||
prompt = prompter.prompt('Create a web-based chat app', ARCHITECTURE['definitions'], 'process_technologies')
|
||||
|
||||
# Then
|
||||
assert prompt == '''Define the arguments for process_technologies to answer the user's question.
|
||||
The response should contain only the JSON object, with no additional text or explanation.
|
||||
assert prompt == '''Please provide a JSON object that defines the arguments for the `process_technologies` function to answer the user's question.
|
||||
The response must contain ONLY the JSON object, with NO additional text or explanation.
|
||||
|
||||
Print the list of technologies that are created.
|
||||
The response should be a JSON object matching this schema:
|
||||
# process_technologies: Print the list of technologies that are created.
|
||||
Here is the schema for the expected JSON object:
|
||||
```json
|
||||
{
|
||||
"technologies": {
|
||||
@@ -133,11 +136,11 @@ def test_llama_json_prompter_named():
|
||||
|
||||
# Then
|
||||
assert prompt == '''[INST] <<SYS>>
|
||||
Define the arguments for process_technologies to answer the user's question.
|
||||
The response should contain only the JSON object, with no additional text or explanation.
|
||||
Please provide a JSON object that defines the arguments for the `process_technologies` function to answer the user's question.
|
||||
The response must contain ONLY the JSON object, with NO additional text or explanation.
|
||||
|
||||
Print the list of technologies that are created.
|
||||
The response should be a JSON object matching this schema:
|
||||
# process_technologies: Print the list of technologies that are created.
|
||||
Here is the schema for the expected JSON object:
|
||||
```json
|
||||
{
|
||||
"technologies": {
|
||||
|
||||
@@ -2,6 +2,7 @@ import builtins
|
||||
from json import JSONDecodeError
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from dotenv import load_dotenv
|
||||
from jsonschema import ValidationError
|
||||
|
||||
@@ -12,7 +13,8 @@ from helpers.agents.Architect import Architect
|
||||
from helpers.agents.TechLead import TechLead
|
||||
from utils.function_calling import parse_agent_response, FunctionType
|
||||
from test.test_utils import assert_non_empty_string
|
||||
from .llm_connection import create_gpt_chat_completion, assert_json_response, assert_json_schema
|
||||
from test.mock_questionary import MockQuestionary
|
||||
from utils.llm_connection import create_gpt_chat_completion, stream_gpt_completion, assert_json_response, assert_json_schema
|
||||
from main import get_custom_print
|
||||
|
||||
load_dotenv()
|
||||
@@ -98,14 +100,42 @@ class TestLlmConnection:
|
||||
def setup_method(self):
|
||||
builtins.print, ipc_client_instance = get_custom_print({})
|
||||
|
||||
@patch('utils.llm_connection.requests.post')
|
||||
def test_stream_gpt_completion(self, mock_post):
|
||||
# Given streaming JSON response
|
||||
deltas = ['{', '\\n',
|
||||
' \\"foo\\": \\"bar\\",', '\\n',
|
||||
' \\"prompt\\": \\"Hello\\",', '\\n',
|
||||
' \\"choices\\": []', '\\n',
|
||||
'}']
|
||||
lines_to_yield = [
|
||||
('{"id": "gen-123", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "' + delta + '"}}]}')
|
||||
.encode('utf-8')
|
||||
for delta in deltas
|
||||
]
|
||||
lines_to_yield.insert(1, b': OPENROUTER PROCESSING') # Simulate OpenRouter keep-alive pings
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.iter_lines.return_value = lines_to_yield
|
||||
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# When
|
||||
with patch('utils.llm_connection.requests.post', return_value=mock_response):
|
||||
response = stream_gpt_completion({}, '')
|
||||
|
||||
# Then
|
||||
assert response == {'text': '{\n "foo": "bar",\n "prompt": "Hello",\n "choices": []\n}'}
|
||||
|
||||
|
||||
@pytest.mark.uses_tokens
|
||||
@pytest.mark.parametrize("endpoint, model", [
|
||||
("OPENAI", "gpt-4"), # role: system
|
||||
("OPENROUTER", "openai/gpt-3.5-turbo"), # role: user
|
||||
("OPENROUTER", "meta-llama/codellama-34b-instruct"), # role: user, is_llama
|
||||
("OPENROUTER", "google/palm-2-chat-bison"), # role: user/system
|
||||
("OPENROUTER", "google/palm-2-codechat-bison"),
|
||||
("OPENROUTER", "anthropic/claude-2"), # role: user, is_llama
|
||||
@pytest.mark.parametrize('endpoint, model', [
|
||||
('OPENAI', 'gpt-4'), # role: system
|
||||
('OPENROUTER', 'openai/gpt-3.5-turbo'), # role: user
|
||||
('OPENROUTER', 'meta-llama/codellama-34b-instruct'), # role: user, is_llama
|
||||
('OPENROUTER', 'google/palm-2-chat-bison'), # role: user/system
|
||||
('OPENROUTER', 'google/palm-2-codechat-bison'),
|
||||
('OPENROUTER', 'anthropic/claude-2'), # role: user, is_llama
|
||||
])
|
||||
def test_chat_completion_Architect(self, endpoint, model, monkeypatch):
|
||||
# Given
|
||||
@@ -154,13 +184,13 @@ solution-oriented decision-making in areas where precise instructions were not p
|
||||
assert 'Node.js' in response
|
||||
|
||||
@pytest.mark.uses_tokens
|
||||
@pytest.mark.parametrize("endpoint, model", [
|
||||
("OPENAI", "gpt-4"), # role: system
|
||||
("OPENROUTER", "openai/gpt-3.5-turbo"), # role: user
|
||||
("OPENROUTER", "meta-llama/codellama-34b-instruct"), # role: user, is_llama
|
||||
("OPENROUTER", "google/palm-2-chat-bison"), # role: user/system
|
||||
("OPENROUTER", "google/palm-2-codechat-bison"),
|
||||
("OPENROUTER", "anthropic/claude-2"), # role: user, is_llama
|
||||
@pytest.mark.parametrize('endpoint, model', [
|
||||
('OPENAI', 'gpt-4'),
|
||||
('OPENROUTER', 'openai/gpt-3.5-turbo'),
|
||||
('OPENROUTER', 'meta-llama/codellama-34b-instruct'),
|
||||
('OPENROUTER', 'google/palm-2-chat-bison'),
|
||||
('OPENROUTER', 'google/palm-2-codechat-bison'),
|
||||
('OPENROUTER', 'anthropic/claude-2'),
|
||||
])
|
||||
def test_chat_completion_TechLead(self, endpoint, model, monkeypatch):
|
||||
# Given
|
||||
@@ -191,18 +221,22 @@ The development process will include the creation of user stories and tasks, bas
|
||||
})
|
||||
function_calls = DEVELOPMENT_PLAN
|
||||
|
||||
# Retry on bad LLM responses
|
||||
mock_questionary = MockQuestionary(['', '', 'no'])
|
||||
|
||||
# When
|
||||
response = create_gpt_chat_completion(convo.messages, '', function_calls=function_calls)
|
||||
with patch('utils.llm_connection.questionary', mock_questionary):
|
||||
response = create_gpt_chat_completion(convo.messages, '', function_calls=function_calls)
|
||||
|
||||
# Then
|
||||
assert convo.messages[0]['content'].startswith('You are a tech lead in a software development agency')
|
||||
assert convo.messages[1]['content'].startswith('You are working in a software development agency and a project manager and software architect approach you')
|
||||
# Then
|
||||
assert convo.messages[0]['content'].startswith('You are a tech lead in a software development agency')
|
||||
assert convo.messages[1]['content'].startswith('You are working in a software development agency and a project manager and software architect approach you')
|
||||
|
||||
assert response is not None
|
||||
response = parse_agent_response(response, function_calls)
|
||||
assert_non_empty_string(response[0]['description'])
|
||||
assert_non_empty_string(response[0]['programmatic_goal'])
|
||||
assert_non_empty_string(response[0]['user_review_goal'])
|
||||
assert response is not None
|
||||
response = parse_agent_response(response, function_calls)
|
||||
assert_non_empty_string(response[0]['description'])
|
||||
assert_non_empty_string(response[0]['programmatic_goal'])
|
||||
assert_non_empty_string(response[0]['user_review_goal'])
|
||||
|
||||
|
||||
# def test_break_down_development_task(self):
|
||||
|
||||
Reference in New Issue
Block a user