mirror of
https://github.com/OMGeeky/gpt-pilot.git
synced 2025-12-31 00:20:03 +01:00
Added Azure OpenAI endpoint.
Tested and confirmed working.
This commit is contained in:
@@ -1,4 +1,9 @@
|
||||
# ENDPOINT selects the API provider: OPENAI or AZURE
|
||||
ENDPOINT=OPENAI
|
||||
OPENAI_API_KEY=
|
||||
AZURE_API_KEY=
|
||||
AZURE_ENDPOINT=
|
||||
AZURE_MODEL_NAME=
|
||||
DB_NAME=gpt-pilot
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
|
||||
@@ -46,8 +46,15 @@ def get_tokens_in_messages(messages: List[str]) -> int:
|
||||
tokenized_messages = [tokenizer.encode(message['content']) for message in messages]
|
||||
return sum(len(tokens) for tokens in tokenized_messages)
|
||||
|
||||
|
||||
def num_tokens_from_functions(functions, model="gpt-4"):
|
||||
# Check if the ENDPOINT is AZURE
|
||||
endpoint = os.getenv('ENDPOINT')
|
||||
if endpoint == 'AZURE':
|
||||
# If so, read the model name from the AZURE_MODEL_NAME environment variable (set in the .env file)
|
||||
model = os.getenv('AZURE_MODEL_NAME')
|
||||
else:
|
||||
model="gpt-4"
|
||||
|
||||
def num_tokens_from_functions(functions, model=model):
|
||||
"""Return the number of tokens used by a list of functions."""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
@@ -94,7 +101,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
|
||||
raise ValueError(f'Too many tokens in messages: {tokens_in_messages}. Please try a different test.')
|
||||
|
||||
gpt_data = {
|
||||
'model': 'gpt-4',
|
||||
'model': model,
|
||||
'n': 1,
|
||||
'max_tokens': min(4096, MAX_GPT_MODEL_TOKENS - tokens_in_messages),
|
||||
'temperature': 1,
|
||||
@@ -172,15 +179,32 @@ def stream_gpt_completion(data, req_type):
|
||||
# spinner = spinner_start(colored("Waiting for OpenAI API response...", 'yellow'))
|
||||
# print(colored("Stream response from OpenAI:", 'yellow'))
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
azure_api_key = os.getenv('AZURE_API_KEY')
|
||||
headers = {'Content-Type': 'application/json', 'api-key': azure_api_key}
|
||||
openai_endpoint = 'https://api.openai.com/v1/chat/completions'
|
||||
|
||||
logger.info(f'Request data: {data}')
|
||||
|
||||
response = requests.post(
|
||||
'https://api.openai.com/v1/chat/completions',
|
||||
headers={'Content-Type': 'application/json', 'Authorization': 'Bearer ' + api_key},
|
||||
json=data,
|
||||
stream=True
|
||||
)
|
||||
# Check if the ENDPOINT is AZURE
|
||||
if endpoint == 'AZURE':
|
||||
# If yes, get the AZURE_ENDPOINT from .ENV file
|
||||
azure_endpoint = os.getenv('AZURE_ENDPOINT')
|
||||
|
||||
# Send the request to the Azure endpoint
|
||||
response = requests.post(
|
||||
azure_endpoint + '/openai/deployments/GPT-4/chat/completions?api-version=2023-05-15',
|
||||
headers=headers,
|
||||
json=data,
|
||||
stream=True
|
||||
)
|
||||
else:
|
||||
# If not, send the request to the OpenAI endpoint
|
||||
response = requests.post(
|
||||
openai_endpoint,
|
||||
headers=headers,
|
||||
json=data,
|
||||
stream=True
|
||||
)
|
||||
|
||||
# Log the response status code and message
|
||||
logger.info(f'Response status code: {response.status_code}')
|
||||
|
||||
Reference in New Issue
Block a user