mirror of
https://github.com/OMGeeky/gpt-pilot.git
synced 2026-01-22 01:58:10 +01:00
get_prompt() moved from llm_connection to utils and works from unit tests
This commit is contained in:
@@ -15,6 +15,10 @@ from const.llm import MAX_QUESTIONS, END_RESPONSE
|
||||
from const.common import ROLES, STEPS
|
||||
from logger.logger import logger
|
||||
|
||||
prompts_path = os.path.join(os.path.dirname(__file__), '..', 'prompts')
|
||||
file_loader = FileSystemLoader(prompts_path)
|
||||
env = Environment(loader=file_loader)
|
||||
|
||||
|
||||
def capitalize_first_word_with_underscores(s):
|
||||
# Split the string into words based on underscores.
|
||||
@@ -29,6 +33,23 @@ def capitalize_first_word_with_underscores(s):
|
||||
return capitalized_string
|
||||
|
||||
|
||||
def get_prompt(prompt_name, data=None):
|
||||
if data is None:
|
||||
data = {}
|
||||
|
||||
data.update(get_prompt_components())
|
||||
|
||||
logger.debug(f"Getting prompt for {prompt_name}") # logging here
|
||||
|
||||
# Load the template
|
||||
template = env.get_template(prompt_name)
|
||||
|
||||
# Render the template with the provided data
|
||||
output = template.render(data)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_prompt_components():
|
||||
# This function reads and renders all prompts inside /prompts/components and returns them in dictionary
|
||||
|
||||
@@ -40,7 +61,8 @@ def get_prompt_components():
|
||||
}
|
||||
|
||||
# Create a FileSystemLoader
|
||||
file_loader = FileSystemLoader('prompts/components')
|
||||
prompts_path = os.path.join(os.path.dirname(__file__), '..', 'prompts/components')
|
||||
file_loader = FileSystemLoader(prompts_path)
|
||||
|
||||
# Create the Jinja2 environment
|
||||
env = Environment(loader=file_loader)
|
||||
@@ -63,17 +85,7 @@ def get_prompt_components():
|
||||
|
||||
|
||||
def get_sys_message(role):
|
||||
# Create a FileSystemLoader
|
||||
file_loader = FileSystemLoader('prompts/system_messages')
|
||||
|
||||
# Create the Jinja2 environment
|
||||
env = Environment(loader=file_loader)
|
||||
|
||||
# Load the template
|
||||
template = env.get_template(f'{role}.prompt')
|
||||
|
||||
# Render the template with no variables
|
||||
content = template.render()
|
||||
content = get_prompt(f'system_messages/{role}.prompt')
|
||||
|
||||
return {
|
||||
"role": "system",
|
||||
@@ -186,4 +198,4 @@ def json_serial(obj):
|
||||
elif isinstance(obj, uuid.UUID):
|
||||
return str(obj)
|
||||
else:
|
||||
return str(obj)
|
||||
return str(obj)
|
||||
|
||||
Reference in New Issue
Block a user