From e5e0f56d2e6accb2c0078cf7e75a8dc524576907 Mon Sep 17 00:00:00 2001 From: Nicholas Albion Date: Tue, 26 Sep 2023 19:38:24 +1000 Subject: [PATCH] JSON validation working --- pilot/const/function_calls.py | 2 +- pilot/utils/function_calling.py | 4 ++-- pilot/utils/llm_connection.py | 15 +++++------- pilot/utils/test_llm_connection.py | 37 +++++++++++++++++++++++++++--- requirements.txt | 1 + 5 files changed, 44 insertions(+), 15 deletions(-) diff --git a/pilot/const/function_calls.py b/pilot/const/function_calls.py index c271943..2020b7e 100644 --- a/pilot/const/function_calls.py +++ b/pilot/const/function_calls.py @@ -369,7 +369,7 @@ DEVELOPMENT_PLAN = { 'description': 'user-review goal that will determine if a task is done or not but from a user perspective since it will be reviewed by a human', } }, - 'required': ['task_description', 'programmatic_goal', 'user_review_goal'], + 'required': ['description', 'programmatic_goal', 'user_review_goal'], }, }, }, diff --git a/pilot/utils/function_calling.py b/pilot/utils/function_calling.py index 469bc53..4b738f9 100644 --- a/pilot/utils/function_calling.py +++ b/pilot/utils/function_calling.py @@ -140,7 +140,7 @@ class JsonPrompter: return "\n".join( self.function_descriptions(functions, function_to_call) + [ - "The response should be a JSON object matching this schema:", + "The response MUST be a JSON object matching this schema:", "```json", self.function_parameters(functions, function_to_call), "```", @@ -195,7 +195,7 @@ class JsonPrompter: "Help choose the appropriate function to call to answer the user's question." if function_to_call is None else f"Define the arguments for {function_to_call} to answer the user's question." - ) + "\nThe response should contain only the JSON object, with no additional text or explanation." + ) + "\nThe response must contain ONLY the JSON object, with NO additional text or explanation." data = ( self.function_data(functions, function_to_call) diff --git a/pilot/utils/llm_connection.py b/pilot/utils/llm_connection.py index 2492598..f3afcf1 100644 --- a/pilot/utils/llm_connection.py +++ b/pilot/utils/llm_connection.py @@ -7,6 +7,7 @@ import json import tiktoken import questionary +from jsonschema import validate from utils.style import red from typing import List from const.llm import MIN_TOKENS_FOR_GPT_RESPONSE, MAX_GPT_MODEL_TOKENS @@ -15,6 +16,7 @@ from helpers.exceptions.TokenLimitError import TokenLimitError from utils.utils import fix_json from utils.function_calling import add_function_calls_to_request, FunctionCallSet, FunctionType + def get_tokens_in_messages(messages: List[str]) -> int: tokenizer = tiktoken.get_encoding("cl100k_base") # GPT-4 tokenizer tokenized_messages = [tokenizer.encode(message['content']) for message in messages] @@ -347,16 +349,11 @@ def assert_json_response(response: str, or_fail=True) -> bool: def assert_json_schema(response: str, functions: list[FunctionType]) -> True: + for function in functions: + schema = function['parameters'] + parsed = json.loads(response) + validate(parsed, schema) return True - # TODO: validation always fails - # for function in functions: - # schema = function['parameters'] - # parser = parser_for_schema(schema) - # validated = parser.validate(response) - # if validated.valid and validated.end_index: - # return True - # - # raise ValueError('LLM responded with invalid JSON') def postprocessing(gpt_response, req_type): diff --git a/pilot/utils/test_llm_connection.py b/pilot/utils/test_llm_connection.py index ec55633..7f21d97 100644 --- a/pilot/utils/test_llm_connection.py +++ b/pilot/utils/test_llm_connection.py @@ -1,6 +1,9 @@ import builtins +from json import JSONDecodeError + import pytest from dotenv import load_dotenv +from jsonschema import ValidationError from const.function_calls import ARCHITECTURE, DEVELOPMENT_PLAN from helpers.AgentConvo import AgentConvo @@ -45,13 +48,13 @@ class TestSchemaValidation: def test_assert_json_schema_invalid(self): # When assert_json_schema is called with invalid JSON # Then error is raised - with pytest.raises(ValueError, match='LLM responded with invalid JSON'): + with pytest.raises(ValidationError, match="1 is not of type 'string'"): assert_json_schema('{"foo": 1}', [self.function]) def test_assert_json_schema_incomplete(self): # When assert_json_schema is called with incomplete JSON # Then error is raised - with pytest.raises(ValueError, match='LLM responded with invalid JSON'): + with pytest.raises(JSONDecodeError): assert_json_schema('{"foo": "b', [self.function]) def test_assert_json_schema_required(self): @@ -60,9 +63,37 @@ class TestSchemaValidation: self.function['parameters']['properties']['other'] = {'type': 'string'} self.function['parameters']['required'] = ['foo', 'other'] - with pytest.raises(ValueError, match='LLM responded with invalid JSON'): + with pytest.raises(ValidationError, match="'other' is a required property"): assert_json_schema('{"foo": "bar"}', [self.function]) + def test_DEVELOPMENT_PLAN(self): + assert(assert_json_schema(''' +{ + "plan": [ + { + "description": "Set up project structure including creation of necessary directories and files. Initialize Node.js and install necessary libraries such as express and socket.io.", + "programmatic_goal": "Project structure should be set up and Node.js initialized. Express and socket.io libraries should be installed and reflected in the package.json file.", + "user_review_goal": "Developer should be able to start an empty express server by running `npm start` command without any errors." + }, + { + "description": "Create a simple front-end HTML page with CSS and JavaScript that includes input for typing messages and area for displaying messages.", + "programmatic_goal": "There should be an HTML file containing an input box for typing messages and an area for displaying the messages. This HTML page should be served when user navigates to the root URL.", + "user_review_goal": "Navigating to the root URL (http://localhost:3000) should display the chat front-end with an input box and a message area." + }, + { + "description": "Set up socket.io on the back-end to handle websocket connections and broadcasting messages to the clients.", + "programmatic_goal": "Server should be able to handle websocket connections using socket.io and broadcast messages to all connected clients.", + "user_review_goal": "By using two different browsers or browser tabs, when one user sends a message from one tab, it should appear in the other user's browser tab in real-time." + }, + { + "description": "Integrate front-end with socket.io client to send messages from the input field to the server and display incoming messages in the message area.", + "programmatic_goal": "Front-end should be able to send messages to server and display incoming messages in the message area using socket.io client.", + "user_review_goal": "Typing a message in the chat input and sending it should then display the message in the chat area." + } + ] +} +'''.strip(), DEVELOPMENT_PLAN['definitions'])) + class TestLlmConnection: def setup_method(self): builtins.print, ipc_client_instance = get_custom_print({}) diff --git a/requirements.txt b/requirements.txt index 7a4eeca..fbb89d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ certifi==2023.5.7 charset-normalizer==3.2.0 distro==1.8.0 idna==3.4 +jsonschema==4.19.1 Jinja2==3.1.2 MarkupSafe==2.1.3 peewee==3.16.2