diff --git a/euclid/utils/llm_connection.py b/euclid/utils/llm_connection.py index e331f5f..e4a30c6 100644 --- a/euclid/utils/llm_connection.py +++ b/euclid/utils/llm_connection.py @@ -8,7 +8,7 @@ from jinja2 import Environment, FileSystemLoader from const.llm import MIN_TOKENS_FOR_GPT_RESPONSE, MAX_GPT_MODEL_TOKENS, MAX_QUESTIONS, END_RESPONSE from logger.logger import logger from termcolor import colored -from utils.utils import get_prompt_components, fix_json_newlines +from utils.utils import get_prompt_components, fix_json from utils.spinner import spinner_start, spinner_stop @@ -90,6 +90,10 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO 'model': 'gpt-4', 'n': 1, 'max_tokens': min(4096, MAX_GPT_MODEL_TOKENS - tokens_in_messages), + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, 'messages': messages, 'stream': True } @@ -152,15 +156,17 @@ def stream_gpt_completion(data, req_type): continue try: - json_line = load_data_to_json(line) + json_line = json.loads(line) if 'error' in json_line: logger.error(f'Error in LLM response: {json_line}') raise ValueError(f'Error in LLM response: {json_line["error"]["message"]}') + if json_line['choices'][0]['finish_reason'] == 'function_call': function_calls['arguments'] = load_data_to_json(function_calls['arguments']) return return_result({'function_calls': function_calls}); json_line = json_line['choices'][0]['delta'] + except json.JSONDecodeError: logger.error(f'Unable to decode line: {line}') continue # skip to the next line @@ -195,4 +201,4 @@ def postprocessing(gpt_response, req_type): def load_data_to_json(string): - return json.loads(fix_json_newlines(string)) + return json.loads(fix_json(string)) diff --git a/euclid/utils/utils.py b/euclid/utils/utils.py index 6fea50d..771dafb 100644 --- a/euclid/utils/utils.py +++ b/euclid/utils/utils.py @@ -145,6 +145,12 @@ def replace_functions(obj): else: return obj +def fix_json(s): + s = s.replace('True', 'true') + s = s.replace('False', 'false') + # s = s.replace('`', '"') + return fix_json_newlines(s) + def fix_json_newlines(s): pattern = r'("(?:\\\\n|\\.|[^"\\])*")'