mirror of https://github.com/OMGeeky/gpt-pilot.git (synced 2026-02-23 15:49:50 +01:00)
ARCHITECTURE function_calls works on meta-llama/codellama-34b-instruct
pilot/utils/function_calling.py (new file, 169 lines)
@@ -0,0 +1,169 @@
import json

# from local_llm_function_calling import Generator
# from local_llm_function_calling.model.llama import LlamaModel
# from local_llm_function_calling.model.huggingface import HuggingfaceModel
from local_llm_function_calling.prompter import FunctionType, CompletionModelPrompter, InstructModelPrompter
# from local_llm_function_calling.model.llama import LlamaInstructPrompter

from typing import Literal, NotRequired, Protocol, TypeVar, TypedDict, Callable


class FunctionCallSet(TypedDict):
    definitions: list[FunctionType]
    functions: dict[str, Callable]


def add_function_calls_to_request(gpt_data, function_calls: FunctionCallSet | None):
    if function_calls is None:
        return

    if gpt_data['model'] == 'gpt-4':
        # GPT-4 supports function calling natively via the API request fields.
        gpt_data['functions'] = function_calls['definitions']
        if len(function_calls['definitions']) > 1:
            gpt_data['function_call'] = 'auto'
        else:
            gpt_data['function_call'] = {'name': function_calls['definitions'][0]['name']}
        return

    # Other models get the function schema embedded in the prompt instead.
    # prompter = CompletionModelPrompter()
    # prompter = InstructModelPrompter()
    prompter = LlamaInstructPrompter()  # defined below

    if len(function_calls['definitions']) > 1:
        function_call = None
    else:
        function_call = function_calls['definitions'][0]['name']

    gpt_data['messages'].append({
        'role': 'user',
        'content': prompter.prompt('', function_calls['definitions'], function_call)
    })


class LlamaInstructPrompter:
    """
    A prompter for Llama2 instruct models.
    Adapted from local_llm_function_calling.
    """

    def function_descriptions(
        self, functions: list[FunctionType], function_to_call: str
    ) -> list[str]:
        """Get the descriptions of the functions

        Args:
            functions (list[FunctionType]): The functions to get the descriptions of
            function_to_call (str): The function to call

        Returns:
            list[str]: The descriptions of the functions
                (empty if the function doesn't exist or has no description)
        """
        return [
            "Function description: " + function["description"]
            for function in functions
            if function["name"] == function_to_call and "description" in function
        ]

    def function_parameters(
        self, functions: list[FunctionType], function_to_call: str
    ) -> str:
        """Get the parameters of the function

        Args:
            functions (list[FunctionType]): The functions to get the parameters of
            function_to_call (str): The function to call

        Returns:
            str: The parameters of the function as a JSON schema
        """
        return next(
            json.dumps(function["parameters"]["properties"], indent=4)
            for function in functions
            if function["name"] == function_to_call
        )

    def function_data(
        self, functions: list[FunctionType], function_to_call: str
    ) -> str:
        """Get the data for the function

        Args:
            functions (list[FunctionType]): The functions to get the data for
            function_to_call (str): The function to call

        Returns:
            str: The data necessary to generate the arguments for the function
        """
        return "\n".join(
            self.function_descriptions(functions, function_to_call)
            + [
                "Function parameters should follow this schema:",
                "```jsonschema",
                self.function_parameters(functions, function_to_call),
                "```",
            ]
        )

    def function_summary(self, function: FunctionType) -> str:
        """Get a summary of a function

        Args:
            function (FunctionType): The function to get the summary of

        Returns:
            str: The summary of the function, as a bullet point
        """
        return f"- {function['name']}" + (
            f" - {function['description']}" if "description" in function else ""
        )

    def functions_summary(self, functions: list[FunctionType]) -> str:
        """Get a summary of the functions

        Args:
            functions (list[FunctionType]): The functions to get the summary of

        Returns:
            str: The summary of the functions, as a bulleted list
        """
        return "Available functions:\n" + "\n".join(
            self.function_summary(function) for function in functions
        )

    def prompt(
        self,
        prompt: str,
        functions: list[FunctionType],
        function_to_call: str | None = None,
    ) -> str:
        """Generate the llama prompt

        Args:
            prompt (str): The prompt to generate the response to
            functions (list[FunctionType]): The functions to generate the response from
            function_to_call (str | None): The function to call. Defaults to None.

        Returns:
            str: The llama prompt: a function selection prompt if no function is
                specified, or a function argument prompt if a function is specified
        """
        system = (
            "Help choose the appropriate function to call to answer the user's question."
            if function_to_call is None
            else f"Define the arguments for {function_to_call} to answer the user's question."
        ) + " In your response you must only use JSON output and provide no notes or commentary."
        data = (
            self.function_data(functions, function_to_call)
            if function_to_call
            else self.functions_summary(functions)
        )
        response_start = (
            f"Here are the arguments for the `{function_to_call}` function: ```json\n"
            if function_to_call
            else "Here's the function the user should call: "
        )
        return f"[INST] <<SYS>>\n{system}\n\n{data}\n<</SYS>>\n\n{prompt} [/INST]"
        # {response_start}"
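A minimal usage sketch of the helper above, not part of the commit: it assumes the module is importable as utils.function_calling (as the import in llm_connection.py below does), and the example definition mirrors the ARCHITECTURE process_technologies function used in the tests at the end of this commit.

# Sketch only: shows how the request is shaped per model.
from utils.function_calling import add_function_calls_to_request

definitions = [{
    'name': 'process_technologies',
    'description': 'Print the list of technologies that are created.',
    'parameters': {'type': 'object', 'properties': {
        'technologies': {'type': 'array', 'items': {'type': 'string'}}
    }}
}]
function_calls = {'definitions': definitions, 'functions': {}}

gpt_data = {'model': 'gpt-4', 'messages': []}
add_function_calls_to_request(gpt_data, function_calls)
# GPT-4 gets the native API fields:
assert gpt_data['function_call'] == {'name': 'process_technologies'}

gpt_data = {'model': 'meta-llama/codellama-34b-instruct', 'messages': []}
add_function_calls_to_request(gpt_data, function_calls)
# Every other model gets the schema embedded in an extra user message:
assert gpt_data['messages'][-1]['role'] == 'user'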
pilot/utils/llm_connection.py (path inferred from the test module's import of .llm_connection below)
@@ -7,16 +7,14 @@ import json
 import tiktoken
 import questionary
 
 from utils.style import red
 from typing import List
 from const.llm import MIN_TOKENS_FOR_GPT_RESPONSE, MAX_GPT_MODEL_TOKENS
 from logger.logger import logger
 from helpers.exceptions.TokenLimitError import TokenLimitError
 from utils.utils import fix_json
-
-model = os.getenv('MODEL_NAME')
-endpoint = os.getenv('ENDPOINT')
+from utils.function_calling import add_function_calls_to_request
 
 
 def get_tokens_in_messages(messages: List[str]) -> int:
     tokenizer = tiktoken.get_encoding("cl100k_base")  # GPT-4 tokenizer
@@ -24,7 +22,7 @@ def get_tokens_in_messages(messages: List[str]) -> int:
     return sum(len(tokens) for tokens in tokenized_messages)
 
 
-def num_tokens_from_functions(functions, model=model):
+def num_tokens_from_functions(functions):
     """Return the number of tokens used by a list of functions."""
     encoding = tiktoken.get_encoding("cl100k_base")
 
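The model parameter is dropped here because its default was the module-level model global removed above; the count always uses the cl100k_base encoding. As a rough sketch of the counting idea (approx_function_tokens is a hypothetical stand-in, not this function's exact per-field accounting):

import json
import tiktoken

def approx_function_tokens(functions: list) -> int:
    # Serialize each definition and sum the encoded lengths (an upper bound).
    encoding = tiktoken.get_encoding("cl100k_base")
    return sum(len(encoding.encode(json.dumps(f))) for f in functions)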
@@ -96,13 +94,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
         if key in gpt_data:
             del gpt_data[key]
 
-    if function_calls is not None:
-        # Advise the LLM of the JSON response schema we are expecting
-        gpt_data['functions'] = function_calls['definitions']
-        if len(function_calls['definitions']) > 1:
-            gpt_data['function_call'] = 'auto'
-        else:
-            gpt_data['function_call'] = {'name': function_calls['definitions'][0]['name']}
+    add_function_calls_to_request(gpt_data, function_calls)
 
     try:
         response = stream_gpt_completion(gpt_data, req_type)
@@ -110,7 +102,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
     except TokenLimitError as e:
         raise e
     except Exception as e:
-        print('The request to OpenAI API failed. Here is the error message:')
+        print(f'The request to {os.getenv("ENDPOINT")} API failed. Here is the error message:')
         print(e)
 
@@ -126,6 +118,7 @@ def count_lines_based_on_width(content, width):
     lines_required = sum(len(line) // width + 1 for line in content.split('\n'))
     return lines_required
 
+
 def get_tokens_in_messages_from_openai_error(error_message):
     """
     Extract the token count from a message.
@@ -208,7 +201,10 @@ def stream_gpt_completion(data, req_type):
 
     logger.info(f'Request data: {data}')
 
-    # Check if the ENDPOINT is AZURE
+    # Configure for the selected ENDPOINT
+    model = os.getenv('MODEL_NAME')
+    endpoint = os.getenv('ENDPOINT')
+
     if endpoint == 'AZURE':
         # If yes, get the AZURE_ENDPOINT from .ENV file
         endpoint_url = os.getenv('AZURE_ENDPOINT') + '/openai/deployments/' + model + '/chat/completions?api-version=2023-05-15'
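Reading MODEL_NAME and ENDPOINT inside stream_gpt_completion, rather than at import time as before, is what lets the test below redirect a request with monkeypatch.setenv. A standalone illustration of the difference:

import os

os.environ['ENDPOINT'] = 'OPENAI'
frozen = os.getenv('ENDPOINT')          # module-scope read: fixed at import time

def current_endpoint():
    return os.getenv('ENDPOINT')        # per-call read: sees later changes

os.environ['ENDPOINT'] = 'OPENROUTER'   # what monkeypatch.setenv effectively does
assert frozen == 'OPENAI'
assert current_endpoint() == 'OPENROUTER'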
@@ -239,10 +235,9 @@ def stream_gpt_completion(data, req_type):
     gpt_response = ''
     function_calls = {'name': '', 'arguments': ''}
 
-
     for line in response.iter_lines():
         # Ignore keep-alive new lines
-        if line:
+        if line and line != b': OPENROUTER PROCESSING':
             line = line.decode("utf-8")  # decode the bytes to string
 
             if line.startswith('data: '):
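OpenRouter keeps its SSE stream alive with `: OPENROUTER PROCESSING` comment lines, which the old `if line:` guard would have passed through to the JSON parser. The filter in isolation (the sample bytes are illustrative, not captured output):

sample = [b'', b': OPENROUTER PROCESSING',
          b'data: {"choices": [{"delta": {"content": "Hi"}}]}']

for line in sample:
    # Skip keep-alive newlines and OpenRouter's processing comments.
    if line and line != b': OPENROUTER PROCESSING':
        decoded = line.decode('utf-8')
        if decoded.startswith('data: '):
            print(decoded[len('data: '):])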
@@ -262,11 +257,13 @@ def stream_gpt_completion(data, req_type):
                     logger.error(f'Error in LLM response: {json_line}')
                     raise ValueError(f'Error in LLM response: {json_line["error"]["message"]}')
 
-                if json_line['choices'][0]['finish_reason'] == 'function_call':
+                choice = json_line['choices'][0]
+
+                if 'finish_reason' in choice and choice['finish_reason'] == 'function_call':
                     function_calls['arguments'] = load_data_to_json(function_calls['arguments'])
                     return return_result({'function_calls': function_calls}, lines_printed)
 
-                json_line = json_line['choices'][0]['delta']
+                json_line = choice['delta']
 
             except json.JSONDecodeError:
                 logger.error(f'Unable to decode line: {line}')
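The new `'finish_reason' in choice` guard matters because streamed deltas may omit the key entirely until the final chunk, where the old direct lookup would raise KeyError. A sketch of the two chunk shapes the loop now handles (hypothetical payloads):

mid_stream = {'choices': [{'delta': {'function_call': {'arguments': '{"tech'}}}]}
final_chunk = {'choices': [{'delta': {}, 'finish_reason': 'function_call'}]}

for json_line in (mid_stream, final_chunk):
    choice = json_line['choices'][0]
    if 'finish_reason' in choice and choice['finish_reason'] == 'function_call':
        print('function call complete')   # the old code would KeyError on mid_stream
    else:
        print('still streaming:', choice['delta'])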
pilot/utils/test_llm_connection.py (path inferred from the relative import of .llm_connection)
@@ -1,9 +1,14 @@
 import builtins
 import os
 from dotenv import load_dotenv
-from const.function_calls import ARCHITECTURE
 from unittest.mock import patch
+from local_llm_function_calling.prompter import CompletionModelPrompter, InstructModelPrompter
 
+from const.function_calls import ARCHITECTURE, DEV_STEPS
 from helpers.AgentConvo import AgentConvo
 from helpers.Project import Project
 from helpers.agents.Architect import Architect
+from helpers.agents.Developer import Developer
 from .llm_connection import create_gpt_chat_completion
 from main import get_custom_print
@@ -16,7 +21,31 @@ class TestLlmConnection:
     def setup_method(self):
         builtins.print, ipc_client_instance = get_custom_print({})
 
-    def test_chat_completion_Architect(self):
+    # def test_break_down_development_task(self):
+    #     # Given
+    #     agent = Developer(project)
+    #     convo = AgentConvo(agent)
+    #     # convo.construct_and_add_message_from_prompt('architecture/technologies.prompt',
+    #     #     {
+    #     #         'name': 'Test App',
+    #     #         'prompt': '''
+    #
+    #     messages = convo.messages
+    #     function_calls = DEV_STEPS
+    #
+    #     # When
+    #     # response = create_gpt_chat_completion(messages, '', function_calls=function_calls)
+    #     response = {'function_calls': {
+    #         'name': 'break_down_development_task',
+    #         'arguments': {'tasks': [{'type': 'command', 'description': 'Run the app'}]}
+    #     }}
+    #     response = convo.postprocess_response(response, function_calls)
+    #
+    #     # Then
+    #     # assert len(convo.messages) == 2
+    #     assert response == ([{'type': 'command', 'description': 'Run the app'}], 'more_tasks')
+
+    def test_chat_completion_Architect(self, monkeypatch):
         """Test the chat completion method."""
         # Given
         agent = Architect(project)
@@ -49,19 +78,80 @@ class TestLlmConnection:
         })
 
         messages = convo.messages
         function_calls = ARCHITECTURE
+        endpoint = 'OPENROUTER'
+        # monkeypatch.setattr('utils.llm_connection.endpoint', endpoint)
+        monkeypatch.setenv('ENDPOINT', endpoint)
+        monkeypatch.setenv('MODEL_NAME', 'meta-llama/codellama-34b-instruct')
 
+        # with patch('.llm_connection.endpoint', endpoint):
         # When
-        response = create_gpt_chat_completion(messages, '', function_calls=ARCHITECTURE)
+        response = create_gpt_chat_completion(messages, '', function_calls=function_calls)
 
         # Then
         assert len(convo.messages) == 2
         assert convo.messages[0]['content'].startswith('You are an experienced software architect')
         assert convo.messages[1]['content'].startswith('You are working in a software development agency')
-        assert response is not None
-        assert len(response) > 0
-        technologies: list[str] = response['function_calls']['arguments']['technologies']
-        assert 'Node.js' in technologies
+
+        assert response is not None
+        response = convo.postprocess_response(response, function_calls)
+        # response = response['function_calls']['arguments']['technologies']
+        assert 'Node.js' in response
+
+    def test_completion_function_prompt(self):
+        # Given
+        prompter = CompletionModelPrompter()
+
+        # When
+        prompt = prompter.prompt('Create a web-based chat app', ARCHITECTURE['definitions'])  # , 'process_technologies')
+
+        # Then
+        assert prompt == '''Create a web-based chat app
+
+Available functions:
+process_technologies - Print the list of technologies that are created.
+```jsonschema
+{
+    "technologies": {
+        "type": "array",
+        "description": "List of technologies that are created in a list.",
+        "items": {
+            "type": "string",
+            "description": "technology"
+        }
+    }
+}
+```
+
+Function call:
+
+Function call: '''
+
+    def test_instruct_function_prompter(self):
+        # Given
+        prompter = InstructModelPrompter()
+
+        # When
+        prompt = prompter.prompt('Create a web-based chat app', ARCHITECTURE['definitions'])  # , 'process_technologies')
+
+        # Then
+        assert prompt == '''Your task is to call a function when needed. You will be provided with a list of functions. Available functions:
+process_technologies - Print the list of technologies that are created.
+```jsonschema
+{
+    "technologies": {
+        "type": "array",
+        "description": "List of technologies that are created in a list.",
+        "items": {
+            "type": "string",
+            "description": "technology"
+        }
+    }
+}
+```
+
+Create a web-based chat app
+
+Function call: '''
+
     def _create_convo(self, agent):
         convo = AgentConvo(agent)
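The suite exercises the library's CompletionModelPrompter and InstructModelPrompter but not the new LlamaInstructPrompter. A hedged sketch of a companion test, based only on the [INST] <<SYS>> template in function_calling.py above (the test name and asserts are hypothetical, not part of this commit):

    def test_llama_instruct_function_prompter(self):
        from utils.function_calling import LlamaInstructPrompter

        # Given
        prompter = LlamaInstructPrompter()

        # When: no function_to_call, so the prompt asks the model to pick one
        prompt = prompter.prompt('Create a web-based chat app', ARCHITECTURE['definitions'])

        # Then
        assert prompt.startswith('[INST] <<SYS>>')
        assert 'Available functions:' in prompt
        assert prompt.endswith('Create a web-based chat app [/INST]')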