diff --git a/euclid/utils/llm_connection.py b/euclid/utils/llm_connection.py index d229ba7..771f978 100644 --- a/euclid/utils/llm_connection.py +++ b/euclid/utils/llm_connection.py @@ -8,7 +8,7 @@ from jinja2 import Environment, FileSystemLoader from const.llm import MIN_TOKENS_FOR_GPT_RESPONSE, MAX_GPT_MODEL_TOKENS, MAX_QUESTIONS, END_RESPONSE from logger.logger import logger from termcolor import colored -from utils.utils import get_prompt_components, escape_json_special_chars +from utils.utils import get_prompt_components, fix_json_newlines from utils.spinner import spinner_start, spinner_stop @@ -152,9 +152,9 @@ def stream_gpt_completion(data, req_type): continue try: - json_line = json_loads_with_escape(line) + json_line = load_data_to_json(line) if json_line['choices'][0]['finish_reason'] == 'function_call': - function_calls['arguments'] = json_loads_with_escape(function_calls['arguments']) + function_calls['arguments'] = load_data_to_json(function_calls['arguments']) return return_result({'function_calls': function_calls}); json_line = json_line['choices'][0]['delta'] @@ -178,7 +178,7 @@ def stream_gpt_completion(data, req_type): print('\n') if function_calls['arguments'] != '': logger.info(f'Response via function call: {function_calls["arguments"]}') - function_calls['arguments'] = json_loads_with_escape(function_calls['arguments']) + function_calls['arguments'] = load_data_to_json(function_calls['arguments']) return return_result({'function_calls': function_calls}); logger.info(f'Response message: {gpt_response}') new_code = postprocessing(gpt_response, req_type) # TODO add type dynamically @@ -189,6 +189,5 @@ def postprocessing(gpt_response, req_type): return gpt_response -def json_loads_with_escape(str): - # return json.loads(escape_json_special_chars(str)) - return json.loads(str) +def load_data_to_json(string): + return json.loads(fix_json_newlines(string)) diff --git a/euclid/utils/utils.py b/euclid/utils/utils.py index 5f4afcb..74cfa5e 100644 --- a/euclid/utils/utils.py +++ b/euclid/utils/utils.py @@ -182,21 +182,13 @@ def replace_functions(obj): else: return obj -def escape_json_special_chars(s): - replacements = { - '"': '\\"', - '\\': '\\\\', - '\n': '\\n', - '\r': '\\r', - '\t': '\\t', - '\b': '\\b', - '\f': '\\f' - } +def fix_json_newlines(s): + pattern = r'("(?:\\.|[^"\\])*")' - for char, replacement in replacements.items(): - s = s.replace(char, replacement) + def replace_newlines(match): + return match.group(1).replace('\n', '\\n') - return s + return re.sub(pattern, replace_newlines, s) def clean_filename(filename): # Remove invalid characters