From ffe4fbeba9dd3d2e71853bcaec014d97ee9224bc Mon Sep 17 00:00:00 2001 From: Zvonimir Sabljic Date: Mon, 18 Sep 2023 19:18:54 -0700 Subject: [PATCH] Enabled catching of max token limit errors from OpenAI's response --- pilot/utils/llm_connection.py | 38 ++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/pilot/utils/llm_connection.py b/pilot/utils/llm_connection.py index 323c5c5..bb84a4c 100644 --- a/pilot/utils/llm_connection.py +++ b/pilot/utils/llm_connection.py @@ -103,13 +103,6 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO {'function_calls': {'name': str, arguments: {...}}} """ - tokens_in_messages = round(get_tokens_in_messages(messages) * 1.2) # add 20% to account for not 100% accuracy - if function_calls is not None: - tokens_in_messages += round( - num_tokens_from_functions(function_calls['definitions']) * 1.2) # add 20% to account for not 100% accuracy - if tokens_in_messages + min_tokens > MAX_GPT_MODEL_TOKENS: - raise TokenLimitError(tokens_in_messages + min_tokens, MAX_GPT_MODEL_TOKENS) - gpt_data = { 'model': os.getenv('MODEL_NAME', 'gpt-4'), 'n': 1, @@ -139,15 +132,11 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO try: response = stream_gpt_completion(gpt_data, req_type) return response + except TokenLimitError as e: + raise e except Exception as e: - error_message = str(e) - - # Check if the error message is related to token limit - if "context_length_exceeded" in error_message.lower(): - raise TokenLimitError(tokens_in_messages + min_tokens, MAX_GPT_MODEL_TOKENS) - else: - print('The request to OpenAI API failed. Here is the error message:') - print(e) + print('The request to OpenAI API failed. 
Here is the error message:') + print(e) def delete_last_n_lines(n): @@ -162,6 +151,23 @@ def count_lines_based_on_width(content, width): lines_required = sum(len(line) // width + 1 for line in content.split('\n')) return lines_required +def get_tokens_in_messages_from_openai_error(error_message): + """ + Extract the token count from an OpenAI error message. + + Args: + error_message (str): The error message to extract the token count from. + + Returns: + int or None: The token count if found, otherwise None. + """ + + match = re.search(r"your messages resulted in (\d+) tokens", error_message) + + if match: + return int(match.group(1)) + else: + return None def retry_on_exception(func): def wrapper(*args, **kwargs): @@ -174,7 +180,7 @@ def retry_on_exception(func): # If the specific error "context_length_exceeded" is present, simply return without retry if "context_length_exceeded" in err_str: - raise TokenLimitError(tokens_in_messages + min_tokens, MAX_GPT_MODEL_TOKENS) + raise TokenLimitError(get_tokens_in_messages_from_openai_error(err_str), MAX_GPT_MODEL_TOKENS) if "rate_limit_exceeded" in err_str: # Extracting the duration from the error string match = re.search(r"Please try again in (\d+)ms.", err_str)