get_prompt() moved from llm_connection to utils and works from unit tests

This commit is contained in:
Nicholas Albion
2023-09-20 22:17:37 +10:00
parent 0234c5f7e1
commit 4b7aa2df22
5 changed files with 32 additions and 43 deletions

View File

@@ -15,6 +15,10 @@ from const.llm import MAX_QUESTIONS, END_RESPONSE
from const.common import ROLES, STEPS
from logger.logger import logger
prompts_path = os.path.join(os.path.dirname(__file__), '..', 'prompts')
file_loader = FileSystemLoader(prompts_path)
env = Environment(loader=file_loader)
def capitalize_first_word_with_underscores(s):
# Split the string into words based on underscores.
@@ -29,6 +33,23 @@ def capitalize_first_word_with_underscores(s):
return capitalized_string
def get_prompt(prompt_name, data=None):
if data is None:
data = {}
data.update(get_prompt_components())
logger.debug(f"Getting prompt for {prompt_name}") # logging here
# Load the template
template = env.get_template(prompt_name)
# Render the template with the provided data
output = template.render(data)
return output
def get_prompt_components():
# This function reads and renders all prompts inside /prompts/components and returns them in dictionary
@@ -40,7 +61,8 @@ def get_prompt_components():
}
# Create a FileSystemLoader
file_loader = FileSystemLoader('prompts/components')
prompts_path = os.path.join(os.path.dirname(__file__), '..', 'prompts/components')
file_loader = FileSystemLoader(prompts_path)
# Create the Jinja2 environment
env = Environment(loader=file_loader)
@@ -63,17 +85,7 @@ def get_prompt_components():
def get_sys_message(role):
# Create a FileSystemLoader
file_loader = FileSystemLoader('prompts/system_messages')
# Create the Jinja2 environment
env = Environment(loader=file_loader)
# Load the template
template = env.get_template(f'{role}.prompt')
# Render the template with no variables
content = template.render()
content = get_prompt(f'system_messages/{role}.prompt')
return {
"role": "system",
@@ -186,4 +198,4 @@ def json_serial(obj):
elif isinstance(obj, uuid.UUID):
return str(obj)
else:
return str(obj)
return str(obj)