publish_hf.py
# Prompts before pushing each variation of our models to the HuggingFace Hub.
import argparse
import os

import huggingface_hub
import yaml
from transformers import PreTrainedModel

from model.model import EHRAuditGPT2, EHRAuditRWKV, EHRAuditLlama
from model.vocab import EHRVocab

parser = argparse.ArgumentParser()
parser.add_argument(
    "--debug",
    action="store_true",
    help="Debug: push models to private repos instead of public ones.",
)
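
# Example usage (a sketch; assumes you are already authenticated to the Hugging Face
# Hub, e.g. via `huggingface-cli login`, so the pushes below can succeed):
#   python publish_hf.py           # interactively push each model publicly
#   python publish_hf.py --debug   # push to private repos instead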

if __name__ == "__main__":
    args = parser.parse_args()

    # Load the configuration.
    config_path = os.path.normpath(
        os.path.join(os.path.dirname(__file__), "config.yaml")
    )
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
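
    # For reference, a minimal sketch of the config.yaml keys this script reads
    # (the values here are illustrative placeholders, not the project's settings):
    #   path_prefix:
    #     - /mnt/data
    #     - ~/data
    #   pretrained_model_path: models/pretrained
    #   vocab_path: vocab/vocab.csv
    #   huggingface:
    #     prefix: audit-icu
    #     username: <hf-username>
    #     github_link: <repo URL>
    #     arxiv_link: <paper URL>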

    # Use the first path prefix that exists on this machine.
    path_prefix = ""
    for prefix in config["path_prefix"]:
        if os.path.exists(prefix):
            path_prefix = prefix
            break

    model_paths = os.path.normpath(
        os.path.join(path_prefix, config["pretrained_model_path"])
    )

    # Get a recursive list of model directories.
    model_list = []
    for root, dirs, files in os.walk(model_paths):
        # If a directory contains a .bin file, it's a model checkpoint.
        if any(file.endswith(".bin") for file in files):
            # Append the last three path components to the model list.
            model_list.append(os.path.join(*root.split(os.sep)[-3:]))
    if len(model_list) == 0:
        raise ValueError(f"No models found in {model_paths}")
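
    # The walk assumes each checkpoint sits three directories deep, i.e.
    # <model_paths>/<model type>/<parameter count>/<date>/*.bin, e.g.
    # (hypothetical): gpt2/25_3M/2023-01-01/pytorch_model.bin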

    # Load the vocabulary shared by all models.
    vocab = EHRVocab(
        vocab_path=os.path.normpath(os.path.join(path_prefix, config["vocab_path"]))
    )

    types = {
        "gpt2": EHRAuditGPT2,
        "rwkv": EHRAuditRWKV,
        "llama": EHRAuditLlama,
    }

    api = huggingface_hub.HfApi()

    # Iterate through and push each model to the Hub.
    for model_idx, model_name in enumerate(model_list):
        # Resolve the model's path and its type/params/date properties.
        model_path = os.path.normpath(os.path.join(model_paths, model_name))
        model_props = model_name.split(os.sep)
        model_type = model_props[0]
        model_params = model_props[1]
        model_date = model_props[2]
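
        # E.g. a checkpoint under gpt2/25_3M/<date> maps to hf_name
        # "audit-icu-gpt2-25_3M" (assuming prefix "audit-icu") and hf_repo
        # "<username>/audit-icu-gpt2-25_3M".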
        hf_name = "-".join([config["huggingface"]["prefix"], model_type, model_params])
        hf_repo = config["huggingface"]["username"] + "/" + hf_name
        github_link = config["huggingface"]["github_link"]
        arxiv_link = config["huggingface"]["arxiv_link"]

        print(f"===== {hf_name} =====")

        desc = f"""---
license: apache-2.0
tags:
- tabular-regression
- ehr
- transformer
- medical
model_name: {hf_name}
---

# {hf_name}

This repo contains the model weights for {hf_name}, a tabular language model built on the {model_type}
architecture for evaluating the cross-entropy of Epic EHR audit log event sequences. This model was
originally designed to calculate cross-entropies but can also be used for generation.

The code to train this model and perform inference with it is available [here]({github_link}).
More details about how to use this model can be found there.

# Model Details

More details can be found in the model card in Appendix B of our paper [here]({arxiv_link}).
Please cite our paper if you use this model in your work:

```
""" + \
        """
@misc{warner2023autoregressive,
      title={Autoregressive Language Models For Estimating the Entropy of Epic EHR Audit Logs},
      author={Benjamin C. Warner and Thomas Kannampallil and Seunghwan Kim},
      year={2023},
      eprint={2311.06401},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
""" + \
        """
```
"""

        should_push = (
            input(f"Push {model_name} to HF as {hf_name}? (y/n): ").lower() == "y"
        )
        if not should_push:
            continue

        # Load the model and push it to the Hub.
        model: PreTrainedModel = types[model_type].from_pretrained(
            model_path, vocab=vocab
        )
        model.push_to_hub(
            repo_id=hf_repo,
            private=args.debug,
            commit_message=f"Uploading {hf_name}",
        )
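
        # push_to_hub creates the repo if it doesn't already exist and uploads the
        # weights and model config; the vocab and README are added separately below.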

        # Add the vocab to the repo.
        api.upload_file(
            path_or_fileobj=os.path.normpath(
                os.path.join(path_prefix, config["vocab_path"])
            ),
            path_in_repo=config["vocab_path"].split(os.sep)[-1],
            repo_id=hf_repo,
            commit_message="Uploading vocab",
        )

        # Add a brief model card to the repo.
        api.upload_file(
            path_or_fileobj=desc.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=hf_repo,
            commit_message="Uploading README",
        )
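
        # Sanity check (a sketch; not executed here): the pushed model should load
        # back from the Hub with, e.g.:
        #   types[model_type].from_pretrained(hf_repo, vocab=vocab)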