get_prompt() moved from llm_connection to utils and works from unit tests

This commit is contained in:
Nicholas Albion
2023-09-20 22:17:37 +10:00
parent 0234c5f7e1
commit 4b7aa2df22
5 changed files with 32 additions and 43 deletions

View File

@@ -7,8 +7,8 @@ from database.database import get_saved_development_step, save_development_step,
from helpers.files import get_files_content
from const.common import IGNORE_FOLDERS
from helpers.exceptions.TokenLimitError import TokenLimitError
from utils.utils import array_of_objects_to_string
from utils.llm_connection import get_prompt, create_gpt_chat_completion
from utils.utils import array_of_objects_to_string, get_prompt
from utils.llm_connection import create_gpt_chat_completion
from utils.utils import get_sys_message, find_role_from_step, capitalize_first_word_with_underscores
from logger.logger import logger
from prompts.prompts import ask_user

View File

@@ -11,7 +11,7 @@ from logger.logger import logger
from helpers.Agent import Agent
from helpers.AgentConvo import AgentConvo
from utils.utils import should_execute_step, array_of_objects_to_string, generate_app_data
from helpers.cli import run_command_until_success, execute_command_and_check_cli_response, debug
from helpers.cli import run_command_until_success, execute_command_and_check_cli_response
from const.function_calls import FILTER_OS_TECHNOLOGIES, EXECUTE_COMMANDS, GET_TEST_TYPE, IMPLEMENT_TASK
from database.database import save_progress, get_progress_steps
from utils.utils import get_os_info

View File

@@ -2,8 +2,8 @@
from utils.style import yellow
from const import common
from const.llm import MAX_QUESTIONS, END_RESPONSE
from utils.llm_connection import create_gpt_chat_completion, get_prompt
from utils.utils import capitalize_first_word_with_underscores, get_sys_message, find_role_from_step
from utils.llm_connection import create_gpt_chat_completion
from utils.utils import capitalize_first_word_with_underscores, get_sys_message, find_role_from_step, get_prompt
from utils.questionary import styled_select, styled_text
from logger.logger import logger

View File

@@ -9,35 +9,12 @@ import questionary
from utils.style import red
from typing import List
from jinja2 import Environment, FileSystemLoader
from const.llm import MIN_TOKENS_FOR_GPT_RESPONSE, MAX_GPT_MODEL_TOKENS, MAX_QUESTIONS, END_RESPONSE
from const.llm import MIN_TOKENS_FOR_GPT_RESPONSE, MAX_GPT_MODEL_TOKENS
from logger.logger import logger
from helpers.exceptions.TokenLimitError import TokenLimitError
from utils.utils import get_prompt_components, fix_json
from utils.spinner import spinner_start, spinner_stop
from utils.utils import fix_json
def get_prompt(prompt_name, data=None):
if data is None:
data = {}
data.update(get_prompt_components())
logger.debug(f"Getting prompt for {prompt_name}") # logging here
# Create a file system loader with the directory of the templates
file_loader = FileSystemLoader('prompts')
# Create the Jinja2 environment
env = Environment(loader=file_loader)
# Load the template
template = env.get_template(prompt_name)
# Render the template with the provided data
output = template.render(data)
return output
def get_tokens_in_messages(messages: List[str]) -> int:

View File

@@ -15,6 +15,10 @@ from const.llm import MAX_QUESTIONS, END_RESPONSE
from const.common import ROLES, STEPS
from logger.logger import logger
prompts_path = os.path.join(os.path.dirname(__file__), '..', 'prompts')
file_loader = FileSystemLoader(prompts_path)
env = Environment(loader=file_loader)
def capitalize_first_word_with_underscores(s):
# Split the string into words based on underscores.
@@ -29,6 +33,23 @@ def capitalize_first_word_with_underscores(s):
return capitalized_string
def get_prompt(prompt_name, data=None):
if data is None:
data = {}
data.update(get_prompt_components())
logger.debug(f"Getting prompt for {prompt_name}") # logging here
# Load the template
template = env.get_template(prompt_name)
# Render the template with the provided data
output = template.render(data)
return output
def get_prompt_components():
# This function reads and renders all prompts inside /prompts/components and returns them in dictionary
@@ -40,7 +61,8 @@ def get_prompt_components():
}
# Create a FileSystemLoader
file_loader = FileSystemLoader('prompts/components')
prompts_path = os.path.join(os.path.dirname(__file__), '..', 'prompts/components')
file_loader = FileSystemLoader(prompts_path)
# Create the Jinja2 environment
env = Environment(loader=file_loader)
@@ -63,17 +85,7 @@ def get_prompt_components():
def get_sys_message(role):
# Create a FileSystemLoader
file_loader = FileSystemLoader('prompts/system_messages')
# Create the Jinja2 environment
env = Environment(loader=file_loader)
# Load the template
template = env.get_template(f'{role}.prompt')
# Render the template with no variables
content = template.render()
content = get_prompt(f'system_messages/{role}.prompt')
return {
"role": "system",
@@ -186,4 +198,4 @@ def json_serial(obj):
elif isinstance(obj, uuid.UUID):
return str(obj)
else:
return str(obj)
return str(obj)