diff --git a/pilot/.env.example b/pilot/.env.example index d53210e..5041cce 100644 --- a/pilot/.env.example +++ b/pilot/.env.example @@ -1,4 +1,9 @@ +#OPENAI or AZURE +ENDPOINT=OPENAI OPENAI_API_KEY= +AZURE_API_KEY= +AZURE_ENDPOINT= +AZURE_MODEL_NAME= DB_NAME=gpt-pilot DB_HOST=localhost DB_PORT=5432 diff --git a/pilot/utils/llm_connection.py b/pilot/utils/llm_connection.py index c191d8e..cace0d5 100644 --- a/pilot/utils/llm_connection.py +++ b/pilot/utils/llm_connection.py @@ -46,8 +46,15 @@ def get_tokens_in_messages(messages: List[str]) -> int: tokenized_messages = [tokenizer.encode(message['content']) for message in messages] return sum(len(tokens) for tokens in tokenized_messages) - -def num_tokens_from_functions(functions, model="gpt-4"): +# Check if the ENDPOINT is AZURE +endpoint = os.getenv('ENDPOINT') +if endpoint == 'AZURE': + # If yes, get the model name from .ENV file + model = os.getenv('AZURE_MODEL_NAME') +else: + model="gpt-4" + +def num_tokens_from_functions(functions, model=model): """Return the number of tokens used by a list of functions.""" encoding = tiktoken.get_encoding("cl100k_base") @@ -94,7 +101,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO raise ValueError(f'Too many tokens in messages: {tokens_in_messages}. Please try a different test.') gpt_data = { - 'model': 'gpt-4', + 'model': model, 'n': 1, 'max_tokens': min(4096, MAX_GPT_MODEL_TOKENS - tokens_in_messages), 'temperature': 1, @@ -172,15 +179,32 @@ def stream_gpt_completion(data, req_type): # spinner = spinner_start(colored("Waiting for OpenAI API response...", 'yellow')) # print(colored("Stream response from OpenAI:", 'yellow')) api_key = os.getenv("OPENAI_API_KEY") + azure_api_key = os.getenv('AZURE_API_KEY') + headers = {'Content-Type': 'application/json', 'api-key': azure_api_key} + openai_endpoint = 'https://api.openai.com/v1/chat/completions' logger.info(f'Request data: {data}') - response = requests.post( - 'https://api.openai.com/v1/chat/completions', - headers={'Content-Type': 'application/json', 'Authorization': 'Bearer ' + api_key}, - json=data, - stream=True - ) + # Check if the ENDPOINT is AZURE + if endpoint == 'AZURE': + # If yes, get the AZURE_ENDPOINT from .ENV file + azure_endpoint = os.getenv('AZURE_ENDPOINT') + + # Send the request to the Azure endpoint + response = requests.post( + azure_endpoint + '/openai/deployments/GPT-4/chat/completions?api-version=2023-05-15', + headers=headers, + json=data, + stream=True + ) + else: + # If not, send the request to the OpenAI endpoint + response = requests.post( + openai_endpoint, + headers=headers, + json=data, + stream=True + ) # Log the response status code and message logger.info(f'Response status code: {response.status_code}')