Enabled catching of max token limit errors from OpenAI's response

This commit is contained in:
Zvonimir Sabljic
2023-09-18 19:18:54 -07:00
parent 3e20f52b8a
commit ffe4fbeba9

View File

@@ -103,13 +103,6 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
{'function_calls': {'name': str, arguments: {...}}}
"""
tokens_in_messages = round(get_tokens_in_messages(messages) * 1.2) # add 20% to account for not 100% accuracy
if function_calls is not None:
tokens_in_messages += round(
num_tokens_from_functions(function_calls['definitions']) * 1.2) # add 20% to account for not 100% accuracy
if tokens_in_messages + min_tokens > MAX_GPT_MODEL_TOKENS:
raise TokenLimitError(tokens_in_messages + min_tokens, MAX_GPT_MODEL_TOKENS)
gpt_data = {
'model': os.getenv('MODEL_NAME', 'gpt-4'),
'n': 1,
@@ -139,15 +132,11 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
try:
response = stream_gpt_completion(gpt_data, req_type)
return response
except TokenLimitError as e:
raise e
except Exception as e:
error_message = str(e)
# Check if the error message is related to token limit
if "context_length_exceeded" in error_message.lower():
raise TokenLimitError(tokens_in_messages + min_tokens, MAX_GPT_MODEL_TOKENS)
else:
print('The request to OpenAI API failed. Here is the error message:')
print(e)
print('The request to OpenAI API failed. Here is the error message:')
print(e)
def delete_last_n_lines(n):
@@ -162,6 +151,23 @@ def count_lines_based_on_width(content, width):
lines_required = sum(len(line) // width + 1 for line in content.split('\n'))
return lines_required
def get_tokens_in_messages_from_openai_error(error_message):
    """
    Extract the token count reported in an OpenAI context-length error message.

    Args:
        error_message (str): The error text to search. The count is taken from
            a phrase of the form "your messages resulted in <N> tokens".

    Returns:
        int or None: The token count if the phrase is found, otherwise None
            (including when `error_message` is not a string).
    """
    # A missing or non-text error carries no token count; report "not found"
    # rather than letting re.search raise TypeError.
    if not isinstance(error_message, str):
        return None

    match = re.search(r"your messages resulted in (\d+) tokens", error_message)
    return int(match.group(1)) if match else None
def retry_on_exception(func):
def wrapper(*args, **kwargs):
@@ -174,7 +180,7 @@ def retry_on_exception(func):
# If the specific error "context_length_exceeded" is present, simply return without retry
if "context_length_exceeded" in err_str:
raise TokenLimitError(tokens_in_messages + min_tokens, MAX_GPT_MODEL_TOKENS)
raise TokenLimitError(get_tokens_in_messages_from_openai_error(err_str), MAX_GPT_MODEL_TOKENS)
if "rate_limit_exceeded" in err_str:
# Extracting the duration from the error string
match = re.search(r"Please try again in (\d+)ms.", err_str)