Skip to content

Commit

Permalink
feat: make the config more robust to provider names
Browse files Browse the repository at this point in the history
  • Loading branch information
Omer-ler committed Jul 1, 2024
1 parent 0a7fcce commit 663208e
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,44 +33,44 @@ def get_llm(config: dict):
model_kwargs = config['model_kwargs']
else:
model_kwargs = {}
if config['type'] == 'OpenAI':
if LLM_ENV['openai']['OPENAI_ORGANIZATION'] == '':
return ChatOpenAI(temperature=temperature, model_name=config['name'],
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
openai_api_base=config.get('openai_api_base', LLM_ENV['openai']['OPENAI_API_BASE']),
model_kwargs=model_kwargs)
else:
return ChatOpenAI(temperature=temperature, model_name=config['name'],
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
openai_api_base=config.get('openai_api_base', 'https://api.openai.com/v1'),
openai_organization=config.get('openai_organization', LLM_ENV['openai']['OPENAI_ORGANIZATION']),
model_kwargs=model_kwargs)
elif config['type'] == 'Azure':
return AzureChatOpenAI(temperature=temperature, azure_deployment=config['name'],
openai_api_key=config.get('openai_api_key', LLM_ENV['azure']['AZURE_OPENAI_API_KEY']),
azure_endpoint=config.get('azure_endpoint', LLM_ENV['azure']['AZURE_OPENAI_ENDPOINT']),
openai_api_version=config.get('openai_api_version', LLM_ENV['azure']['OPENAI_API_VERSION']))

elif config['type'] == 'Google':
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(temperature=temperature, model=config['name'],
google_api_key=LLM_ENV['google']['GOOGLE_API_KEY'],
model_kwargs=model_kwargs)


elif config['type'] == 'HuggingFacePipeline':
device = config.get('gpu_device', -1)
device_map = config.get('device_map', None)

return HuggingFacePipeline.from_model_id(
model_id=config['name'],
task="text-generation",
pipeline_kwargs={"max_new_tokens": config['max_new_tokens']},
device=device,
device_map=device_map
)
else:
raise NotImplementedError("LLM not implemented")
match config['type'].lower():
case 'openai':
if LLM_ENV['openai']['OPENAI_ORGANIZATION'] == '':
return ChatOpenAI(temperature=temperature, model_name=config['name'],
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
openai_api_base=config.get('openai_api_base', LLM_ENV['openai']['OPENAI_API_BASE']),
model_kwargs=model_kwargs)
else:
return ChatOpenAI(temperature=temperature, model_name=config['name'],
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
openai_api_base=config.get('openai_api_base', 'https://api.openai.com/v1'),
openai_organization=config.get('openai_organization', LLM_ENV['openai']['OPENAI_ORGANIZATION']),
model_kwargs=model_kwargs)
case 'azure':
return AzureChatOpenAI(temperature=temperature, azure_deployment=config['name'],
openai_api_key=config.get('openai_api_key', LLM_ENV['azure']['AZURE_OPENAI_API_KEY']),
azure_endpoint=config.get('azure_endpoint', LLM_ENV['azure']['AZURE_OPENAI_ENDPOINT']),
openai_api_version=config.get('openai_api_version', LLM_ENV['azure']['OPENAI_API_VERSION']))
case 'google':
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(temperature=temperature, model=config['name'],
google_api_key=LLM_ENV['google']['GOOGLE_API_KEY'],
model_kwargs=model_kwargs)

case 'huggingfacepipeline':

device = config.get('gpu_device', -1)
device_map = config.get('device_map', None)

return HuggingFacePipeline.from_model_id(
model_id=config['name'],
task="text-generation",
pipeline_kwargs={"max_new_tokens": config['max_new_tokens']},
device=device,
device_map=device_map
)
case _:
raise NotImplementedError("LLM not implemented")


def load_yaml(yaml_path: str, as_edict: bool = True) -> edict:
Expand Down

0 comments on commit 663208e

Please sign in to comment.