Skip to content

Commit

Permalink
Merge pull request Eladlev#55 from vincilee2/main
Browse files Browse the repository at this point in the history
fix issues with Azure Openai endpoint
  • Loading branch information
Eladlev authored Apr 1, 2024
2 parents cdddccf + c4bdd60 commit 7f373f2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pandas==1.5.3
tqdm==4.66.1
prodict==0.8.18
langchain==0.1.9
openai==0.28.0
openai==1.1.0
tiktoken==0.5.1
easydict==1.11
wandb==0.16.0
Expand Down
16 changes: 9 additions & 7 deletions utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,20 @@ def get_llm(config: dict):
if config['type'] == 'OpenAI':
if LLM_ENV['openai']['OPENAI_ORGANIZATION'] == '':
return ChatOpenAI(temperature=temperature, model_name=config['name'],
openai_api_key=LLM_ENV['openai']['OPENAI_API_KEY'],
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
openai_api_base=config.get('openai_api_base', 'https://api.openai.com/v1'),
model_kwargs=model_kwargs)
else:
return ChatOpenAI(temperature=temperature, model_name=config['name'],
openai_api_key=LLM_ENV['openai']['OPENAI_API_KEY'],
openai_organization=LLM_ENV['openai']['OPENAI_ORGANIZATION'],
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
openai_api_base=config.get('openai_api_base', 'https://api.openai.com/v1'),
openai_organization=config.get('openai_organization', LLM_ENV['openai']['OPENAI_ORGANIZATION']),
model_kwargs=model_kwargs)
elif config['type'] == 'Azure':
return AzureChatOpenAI(temperature=temperature, deployment_name=config['name'],
openai_api_key=LLM_ENV['azure']['AZURE_OPENAI_API_KEY'],
azure_endpoint=LLM_ENV['azure']['AZURE_OPENAI_ENDPOINT'],
openai_api_version=LLM_ENV['azure']['OPENAI_API_VERSION'])
return AzureChatOpenAI(temperature=temperature, azure_deployment=config['name'],
openai_api_key=config.get('openai_api_key', LLM_ENV['azure']['AZURE_OPENAI_API_KEY']),
azure_endpoint=config.get('azure_endpoint', LLM_ENV['azure']['AZURE_OPENAI_ENDPOINT']),
openai_api_version=config.get('openai_api_version', LLM_ENV['azure']['OPENAI_API_VERSION']))

elif config['type'] == 'Google':
from langchain_google_genai import ChatGoogleGenerativeAI
Expand Down
2 changes: 1 addition & 1 deletion utils/llm_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def build_chain(self):
"""
Build the chain according to the LLM type
"""
if self.llm_config.type == 'OpenAI' and self.json_schema is not None:
if (self.llm_config.type == 'OpenAI' or self.llm_config.type == 'Azure') and self.json_schema is not None:
self.chain = create_structured_output_runnable(self.json_schema, self.llm, self.prompt)
else:
self.chain = LLMChain(llm=self.llm, prompt=self.prompt)
Expand Down

0 comments on commit 7f373f2

Please sign in to comment.