Added Azure OpenAI endpoint.

Tested and confirmed working.
Sander Hilven
2023-09-01 09:53:17 +02:00
parent 4638a209b2
commit 984379fe71
2 changed files with 38 additions and 9 deletions


@@ -1,4 +1,9 @@
#OPENAI or AZURE
ENDPOINT=OPENAI
OPENAI_API_KEY=
AZURE_API_KEY=
AZURE_ENDPOINT=
AZURE_MODEL_NAME=
DB_NAME=gpt-pilot
DB_HOST=localhost
DB_PORT=5432
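
For illustration, a hypothetical .env configured for the Azure route might look like the following (the key, resource URL, and deployment name are placeholders, not values from this commit); note that the code below appends /openai/deployments/... to AZURE_ENDPOINT, so it should be the bare resource URL without a trailing path:

#OPENAI or AZURE
ENDPOINT=AZURE
OPENAI_API_KEY=
AZURE_API_KEY=<your-azure-openai-api-key>
AZURE_ENDPOINT=https://<your-resource>.openai.azure.com
AZURE_MODEL_NAME=gpt-4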


@@ -46,8 +46,15 @@ def get_tokens_in_messages(messages: List[str]) -> int:
    tokenized_messages = [tokenizer.encode(message['content']) for message in messages]
    return sum(len(tokens) for tokens in tokenized_messages)
def num_tokens_from_functions(functions, model="gpt-4"):
# Check if the ENDPOINT is AZURE
endpoint = os.getenv('ENDPOINT')
if endpoint == 'AZURE':
    # If yes, get the model name from .ENV file
    model = os.getenv('AZURE_MODEL_NAME')
else:
    model="gpt-4"
def num_tokens_from_functions(functions, model=model):
    """Return the number of tokens used by a list of functions."""
    encoding = tiktoken.get_encoding("cl100k_base")
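
The added block resolves the model name once, at module level, from the environment. A minimal standalone sketch of that selection (the helper name resolve_model_name is ours, not part of the commit):

import os

def resolve_model_name(default="gpt-4"):
    # When ENDPOINT is AZURE, the deployment/model name from AZURE_MODEL_NAME wins;
    # otherwise fall back to the stock OpenAI model name.
    if os.getenv('ENDPOINT') == 'AZURE':
        return os.getenv('AZURE_MODEL_NAME', default)
    return default

With ENDPOINT=AZURE and AZURE_MODEL_NAME=gpt-4 in .env this returns 'gpt-4' from the environment; with ENDPOINT=OPENAI it returns the default.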
@@ -94,7 +101,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
        raise ValueError(f'Too many tokens in messages: {tokens_in_messages}. Please try a different test.')
    gpt_data = {
        'model': 'gpt-4',
        'model': model,
        'n': 1,
        'max_tokens': min(4096, MAX_GPT_MODEL_TOKENS - tokens_in_messages),
        'temperature': 1,
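
To make the max_tokens clamp concrete: assuming MAX_GPT_MODEL_TOKENS is 8192 (its value is not shown in this hunk), a prompt with 5000 tokens in the messages yields max_tokens = min(4096, 8192 - 5000) = 3192, while a 1000-token prompt is capped at 4096.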
@@ -172,15 +179,32 @@ def stream_gpt_completion(data, req_type):
    # spinner = spinner_start(colored("Waiting for OpenAI API response...", 'yellow'))
    # print(colored("Stream response from OpenAI:", 'yellow'))
    api_key = os.getenv("OPENAI_API_KEY")
    azure_api_key = os.getenv('AZURE_API_KEY')
    headers = {'Content-Type': 'application/json', 'api-key': azure_api_key}
    openai_endpoint = 'https://api.openai.com/v1/chat/completions'
    logger.info(f'Request data: {data}')
    response = requests.post(
        'https://api.openai.com/v1/chat/completions',
        headers={'Content-Type': 'application/json', 'Authorization': 'Bearer ' + api_key},
        json=data,
        stream=True
    )
    # Check if the ENDPOINT is AZURE
    if endpoint == 'AZURE':
        # If yes, get the AZURE_ENDPOINT from .ENV file
        azure_endpoint = os.getenv('AZURE_ENDPOINT')
        # Send the request to the Azure endpoint
        response = requests.post(
            azure_endpoint + '/openai/deployments/GPT-4/chat/completions?api-version=2023-05-15',
            headers=headers,
            json=data,
            stream=True
        )
    else:
        # If not, send the request to the OpenAI endpoint
        response = requests.post(
            openai_endpoint,
            headers=headers,
            json=data,
            stream=True
        )
    # Log the response status code and message
    logger.info(f'Response status code: {response.status_code}')
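
For context, a condensed sketch of the request routing this hunk implements, using the auth header each API conventionally expects (Azure OpenAI takes an api-key header, the OpenAI API an Authorization Bearer token, as in the removed lines above); the helper name post_chat_completion is ours, not part of the commit:

import os
import requests

def post_chat_completion(data):
    # Route the streamed chat-completion request to Azure or OpenAI,
    # mirroring the branch added in stream_gpt_completion.
    if os.getenv('ENDPOINT') == 'AZURE':
        url = os.getenv('AZURE_ENDPOINT') + '/openai/deployments/GPT-4/chat/completions?api-version=2023-05-15'
        headers = {'Content-Type': 'application/json', 'api-key': os.getenv('AZURE_API_KEY')}
    else:
        url = 'https://api.openai.com/v1/chat/completions'
        headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + os.getenv('OPENAI_API_KEY')}
    return requests.post(url, headers=headers, json=data, stream=True)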