This commit is contained in:
LeonOstrez
2023-08-04 16:57:05 +02:00
2 changed files with 43 additions and 10 deletions

View File

@@ -1,4 +1,4 @@
# Minimum number of tokens that must stay available for the model's reply;
# requests that would leave less headroom than this are rejected upstream.
MIN_TOKENS_FOR_GPT_RESPONSE = 60

# Context-window size of the target model (GPT-4 8k).
# NOTE: the stale pre-change assignment (4096) was removed — it was dead code
# immediately overwritten by this line.
MAX_GPT_MODEL_TOKENS = 8192

# Maximum number of clarifying questions the assistant may ask in one round.
MAX_QUESTIONS = 3

# Sentinel string the model emits when it has no further questions.
END_RESPONSE = "EVERYTHING_CLEAR"

View File

@@ -1,8 +1,7 @@
# llm_connection_old.py
import requests
import os
import json
# from tiktoken import Tokenizer
import tiktoken
from typing import List
from jinja2 import Environment, FileSystemLoader
@@ -40,16 +39,50 @@ def get_prompt(prompt_name, data=None):
def get_tokens_in_messages(messages: List[dict]) -> int:
    """Return the total number of tokens across all chat messages.

    The stale pre-change implementation (``tokenizer = Tokenizer()``, which
    referenced an undefined name) was removed; only the tiktoken version
    remains.

    :param messages: chat messages in OpenAI format; each dict must have a
        ``'content'`` key holding the message text.  (Annotation fixed from
        ``List[str]`` — the body indexes ``message['content']``.)
    :return: sum of token counts of every message's content.
    """
    # cl100k_base is the encoding used by the GPT-4 model family.
    tokenizer = tiktoken.get_encoding("cl100k_base")  # GPT-4 tokenizer
    tokenized_messages = [tokenizer.encode(message['content']) for message in messages]
    return sum(len(tokens) for tokens in tokenized_messages)
def num_tokens_from_functions(functions, model="gpt-4"):
    """Return the number of tokens used by a list of functions.

    Estimates how many prompt tokens the OpenAI API will consume to encode
    the function-calling ``functions`` definitions.  The per-field constant
    offsets (+2, -3, +3, +11, +12) mirror the community-documented overhead
    of the API's internal function serialization format — presumably derived
    empirically; TODO confirm against current API behavior.

    The stale removed-diff fragment (the old ``create_gpt_chat_completion``
    signature and body lines) that the stripped diff spliced into the middle
    of this function was removed so the function is syntactically valid.

    :param functions: list of function-definition dicts, each with ``'name'``
        and ``'description'`` keys and an optional JSON-schema ``'parameters'``.
    :param model: model name (currently unused; kept for interface stability).
    :return: estimated token count as an int.
    """
    encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for function in functions:
        function_tokens = len(encoding.encode(function['name']))
        function_tokens += len(encoding.encode(function['description']))

        if 'parameters' in function:
            parameters = function['parameters']
            if 'properties' in parameters:
                for propertiesKey in parameters['properties']:
                    function_tokens += len(encoding.encode(propertiesKey))
                    v = parameters['properties'][propertiesKey]
                    for field in v:
                        if field == 'type':
                            function_tokens += 2
                            function_tokens += len(encoding.encode(v['type']))
                        elif field == 'description':
                            function_tokens += 2
                            function_tokens += len(encoding.encode(v['description']))
                        elif field == 'enum':
                            function_tokens -= 3
                            for o in v['enum']:
                                function_tokens += 3
                                function_tokens += len(encoding.encode(o))
                        else:
                            print(f"Warning: not supported field {field}")
                function_tokens += 11

        num_tokens += function_tokens

    num_tokens += 12
    return num_tokens
def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE, function_calls=None):
tokens_in_messages = round(get_tokens_in_messages(messages) * 1.2) # add 20% to account for not 100% accuracy
if function_calls is not None:
tokens_in_messages += round(num_tokens_from_functions(function_calls['definitions']) * 1.2) # add 20% to account for not 100% accuracy
if tokens_in_messages + min_tokens > MAX_GPT_MODEL_TOKENS:
raise ValueError(f'Too many tokens in messages: {tokens_in_messages}. Please try a different test.')