Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: load model card json #2246

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
abstract load-JSON method
  • Loading branch information
Minamiyama committed Sep 6, 2024
commit 18ba812683f34f0d30df7d8cd3b9d4b4853bf3c1
108 changes: 32 additions & 76 deletions xinference/model/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,73 +195,52 @@ def _install():
SUPPORTED_ENGINES["MLX"] = MLX_CLASSES
SUPPORTED_ENGINES["LMDEPLOY"] = LMDEPLOY_CLASSES

json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
)
for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
model_spec = LLMFamilyV1.parse_obj(json_obj)
BUILTIN_LLM_FAMILIES.append(model_spec)
_load_from_json("llm_family.json", BUILTIN_LLM_FAMILIES)
_load_from_json("llm_family_modelscope.json", BUILTIN_MODELSCOPE_LLM_FAMILIES)
_load_from_json("llm_family_csghub.json", BUILTIN_CSGHUB_LLM_FAMILIES)

# register chat_template
if "chat" in model_spec.model_ability and isinstance(
model_spec.chat_template, str
):
# note that the key is the model name,
# since there are multiple representations of the same prompt style name in json.
BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
"chat_template": model_spec.chat_template,
"stop_token_ids": model_spec.stop_token_ids,
"stop": model_spec.stop,
}
# register model family
if "chat" in model_spec.model_ability:
BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
else:
BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
if "tools" in model_spec.model_ability:
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
for llm_specs in [
BUILTIN_LLM_FAMILIES,
BUILTIN_MODELSCOPE_LLM_FAMILIES,
BUILTIN_CSGHUB_LLM_FAMILIES,
]:
for llm_spec in llm_specs:
if llm_spec.model_name not in LLM_MODEL_DESCRIPTIONS:
LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(llm_spec))

modelscope_json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json"
)
for json_obj in json.load(codecs.open(modelscope_json_path, "r", encoding="utf-8")):
model_spec = LLMFamilyV1.parse_obj(json_obj)
BUILTIN_MODELSCOPE_LLM_FAMILIES.append(model_spec)
# traverse all families and add engine parameters corresponding to the model name
for families in [
BUILTIN_LLM_FAMILIES,
BUILTIN_MODELSCOPE_LLM_FAMILIES,
BUILTIN_CSGHUB_LLM_FAMILIES,
]:
for family in families:
generate_engine_config_by_model_family(family)

# register prompt style, in case that we have something missed
# if duplicated with huggingface json, keep it as the huggingface style
if (
"chat" in model_spec.model_ability
and isinstance(model_spec.chat_template, str)
and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
):
BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
"chat_template": model_spec.chat_template,
"stop_token_ids": model_spec.stop_token_ids,
"stop": model_spec.stop,
}
# register model family
if "chat" in model_spec.model_ability:
BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
else:
BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
if "tools" in model_spec.model_ability:
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
register_custom_model()

# register model description
for ud_llm in get_user_defined_llm_families():
LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(ud_llm))

csghub_json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family_csghub.json"
def _load_from_json(file_name: str, LLM_FAMILIES: list[LLMFamilyV1]):
json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), file_name
)
for json_obj in json.load(codecs.open(csghub_json_path, "r", encoding="utf-8")):
for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
model_spec = LLMFamilyV1.parse_obj(json_obj)
BUILTIN_CSGHUB_LLM_FAMILIES.append(model_spec)
LLM_FAMILIES.append(model_spec)

# register chat_template
# register prompt style, in case that we have something missed
# if duplicated with huggingface json, keep it as the huggingface style
if (
"chat" in model_spec.model_ability
and isinstance(model_spec.chat_template, str)
and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
):
# note that the key is the model name,
# since there are multiple representations of the same prompt style name in json.
BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
"chat_template": model_spec.chat_template,
"stop_token_ids": model_spec.stop_token_ids,
Expand All @@ -275,26 +254,3 @@ def _install():
if "tools" in model_spec.model_ability:
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)

for llm_specs in [
BUILTIN_LLM_FAMILIES,
BUILTIN_MODELSCOPE_LLM_FAMILIES,
BUILTIN_CSGHUB_LLM_FAMILIES,
]:
for llm_spec in llm_specs:
if llm_spec.model_name not in LLM_MODEL_DESCRIPTIONS:
LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(llm_spec))

# traverse all families and add engine parameters corresponding to the model name
for families in [
BUILTIN_LLM_FAMILIES,
BUILTIN_MODELSCOPE_LLM_FAMILIES,
BUILTIN_CSGHUB_LLM_FAMILIES,
]:
for family in families:
generate_engine_config_by_model_family(family)

register_custom_model()

# register model description
for ud_llm in get_user_defined_llm_families():
LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(ud_llm))