mirror of
https://github.com/OMGeeky/gpt-pilot.git
synced 2026-01-04 10:20:21 +01:00
ARCHITECTURE function_calls works on meta-llama/codellama-34b-instruct
This commit is contained in:
169
pilot/utils/function_calling.py
Normal file
169
pilot/utils/function_calling.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
# from local_llm_function_calling import Generator
|
||||
# from local_llm_function_calling.model.llama import LlamaModel
|
||||
# from local_llm_function_calling.model.huggingface import HuggingfaceModel
|
||||
from local_llm_function_calling.prompter import FunctionType, CompletionModelPrompter, InstructModelPrompter
|
||||
# from local_llm_function_calling.model.llama import LlamaInstructPrompter
|
||||
|
||||
from typing import Literal, NotRequired, Protocol, TypeVar, TypedDict, Callable
|
||||
|
||||
|
||||
class FunctionCallSet(TypedDict):
|
||||
definitions: list[FunctionType]
|
||||
functions: dict[str, Callable]
|
||||
|
||||
|
||||
def add_function_calls_to_request(gpt_data, function_calls: FunctionCallSet | None):
|
||||
if function_calls is None:
|
||||
return
|
||||
|
||||
if gpt_data['model'] == 'gpt-4':
|
||||
gpt_data['functions'] = function_calls['definitions']
|
||||
if len(function_calls['definitions']) > 1:
|
||||
gpt_data['function_call'] = 'auto'
|
||||
else:
|
||||
gpt_data['function_call'] = {'name': function_calls['definitions'][0]['name']}
|
||||
return
|
||||
|
||||
# prompter = CompletionModelPrompter()
|
||||
# prompter = InstructModelPrompter()
|
||||
prompter = LlamaInstructPrompter()
|
||||
|
||||
if len(function_calls['definitions']) > 1:
|
||||
function_call = None
|
||||
else:
|
||||
function_call = function_calls['definitions'][0]['name']
|
||||
|
||||
gpt_data['messages'].append({
|
||||
'role': 'user',
|
||||
'content': prompter.prompt('', function_calls['definitions'], function_call)
|
||||
})
|
||||
|
||||
|
||||
class LlamaInstructPrompter:
|
||||
"""
|
||||
A prompter for Llama2 instruct models.
|
||||
Adapted from local_llm_function_calling
|
||||
"""
|
||||
|
||||
def function_descriptions(
|
||||
self, functions: list[FunctionType], function_to_call: str
|
||||
) -> list[str]:
|
||||
"""Get the descriptions of the functions
|
||||
|
||||
Args:
|
||||
functions (list[FunctionType]): The functions to get the descriptions of
|
||||
function_to_call (str): The function to call
|
||||
|
||||
Returns:
|
||||
list[str]: The descriptions of the functions
|
||||
(empty if the function doesn't exist or has no description)
|
||||
"""
|
||||
return [
|
||||
"Function description: " + function["description"]
|
||||
for function in functions
|
||||
if function["name"] == function_to_call and "description" in function
|
||||
]
|
||||
|
||||
def function_parameters(
|
||||
self, functions: list[FunctionType], function_to_call: str
|
||||
) -> str:
|
||||
"""Get the parameters of the function
|
||||
|
||||
Args:
|
||||
functions (list[FunctionType]): The functions to get the parameters of
|
||||
function_to_call (str): The function to call
|
||||
|
||||
Returns:
|
||||
str: The parameters of the function as a JSON schema
|
||||
"""
|
||||
return next(
|
||||
json.dumps(function["parameters"]["properties"], indent=4)
|
||||
for function in functions
|
||||
if function["name"] == function_to_call
|
||||
)
|
||||
|
||||
def function_data(
|
||||
self, functions: list[FunctionType], function_to_call: str
|
||||
) -> str:
|
||||
"""Get the data for the function
|
||||
|
||||
Args:
|
||||
functions (list[FunctionType]): The functions to get the data for
|
||||
function_to_call (str): The function to call
|
||||
|
||||
Returns:
|
||||
str: The data necessary to generate the arguments for the function
|
||||
"""
|
||||
return "\n".join(
|
||||
self.function_descriptions(functions, function_to_call)
|
||||
+ [
|
||||
"Function parameters should follow this schema:",
|
||||
"```jsonschema",
|
||||
self.function_parameters(functions, function_to_call),
|
||||
"```",
|
||||
]
|
||||
)
|
||||
|
||||
def function_summary(self, function: FunctionType) -> str:
|
||||
"""Get a summary of a function
|
||||
|
||||
Args:
|
||||
function (FunctionType): The function to get the summary of
|
||||
|
||||
Returns:
|
||||
str: The summary of the function, as a bullet point
|
||||
"""
|
||||
return f"- {function['name']}" + (
|
||||
f" - {function['description']}" if "description" in function else ""
|
||||
)
|
||||
|
||||
def functions_summary(self, functions: list[FunctionType]) -> str:
|
||||
"""Get a summary of the functions
|
||||
|
||||
Args:
|
||||
functions (list[FunctionType]): The functions to get the summary of
|
||||
|
||||
Returns:
|
||||
str: The summary of the functions, as a bulleted list
|
||||
"""
|
||||
return "Available functions:\n" + "\n".join(
|
||||
self.function_summary(function) for function in functions
|
||||
)
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
functions: list[FunctionType],
|
||||
function_to_call: str | None = None,
|
||||
) -> str:
|
||||
"""Generate the llama prompt
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to generate the response to
|
||||
functions (list[FunctionType]): The functions to generate the response from
|
||||
function_to_call (str | None): The function to call. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[bytes | int]: The llama prompt, a function selection prompt if no
|
||||
function is specified, or a function argument prompt if a function is
|
||||
specified
|
||||
"""
|
||||
system = (
|
||||
"Help choose the appropriate function to call to answer the user's question."
|
||||
if function_to_call is None
|
||||
else f"Define the arguments for {function_to_call} to answer the user's question."
|
||||
) + "In your response you must only use JSON output and provide no notes or commentary."
|
||||
data = (
|
||||
self.function_data(functions, function_to_call)
|
||||
if function_to_call
|
||||
else self.functions_summary(functions)
|
||||
)
|
||||
response_start = (
|
||||
f"Here are the arguments for the `{function_to_call}` function: ```json\n"
|
||||
if function_to_call
|
||||
else "Here's the function the user should call: "
|
||||
)
|
||||
return f"[INST] <<SYS>>\n{system}\n\n{data}\n<</SYS>>\n\n{prompt} [/INST]"
|
||||
# {response_start}"
|
||||
|
||||
Reference in New Issue
Block a user