Skip to content

Commit

Permalink
Update to take HuggingFace path
Browse files Browse the repository at this point in the history
  • Loading branch information
Theodore Zhao committed Nov 27, 2024
1 parent 6f731d0 commit 6fa21ec
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
opt = init_distributed(opt)

# Load model from pretrained weights
pretrained_pth = 'pretrained/biomed_parse.pt'
#pretrained_pth = 'pretrained/biomed_parse.pt'
pretrained_pth = 'hf_hub:microsoft/BiomedParse'

model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
with torch.no_grad():
Expand Down
1 change: 1 addition & 0 deletions example_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Load model from pretrained weights
pretrained_pth = 'pretrained/biomed_parse.pt'
pretrained_pth = 'hf_hub:microsoft/BiomedParse'

model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
with torch.no_grad():
Expand Down
18 changes: 17 additions & 1 deletion modeling/BaseModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

from utilities.model import align_and_update_state_dicts

from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files

import huggingface_hub

logger = logging.getLogger(__name__)


Expand All @@ -22,7 +27,18 @@ def forward(self, *inputs, **kwargs):
def save_pretrained(self, save_dir):
torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))

def from_pretrained(self, load_dir):
def from_pretrained(self, pretrained,
local_dir: str = "./pretrained", config_dir: str = "./configs"):
if pretrained.startswith("hf_hub:"):
hub_name = pretrained.split(":")[1]
huggingface_hub.hf_hub_download(hub_name, filename="biomedparse_v1.pt",
local_dir=local_dir)
huggingface_hub.hf_hub_download(hub_name, filename="config.yaml",
local_dir=config_dir)
load_dir = os.path.join(local_dir, "biomedparse_v1.pt")
else:
load_dir = pretrained

state_dict = torch.load(load_dir, map_location=self.opt['device'])
state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
self.model.load_state_dict(state_dict, strict=False)
Expand Down

0 comments on commit 6fa21ec

Please sign in to comment.