Skip to content

Commit de9334a

Browse files
Merge pull request huggingface#30 from CrazyBoyM/master
feat: support for load weights from local dir
2 parents bff7822 + 9931ffe commit de9334a

File tree

9 files changed

+47
-12
lines changed

9 files changed

+47
-12
lines changed

src/controlnet_aux/hed/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None
106106
else:
107107
filename = filename or "network-bsds500.pth"
108108

109-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
109+
if os.path.isdir(pretrained_model_or_path):
110+
model_path = os.path.join(pretrained_model_or_path, filename)
111+
else:
112+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
110113

111114
netNetwork = Network(model_path)
112115

src/controlnet_aux/lineart/__init__.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,12 @@ def from_pretrained(cls, pretrained_model_or_path, filename=None, coarse_filenam
104104
filename = filename or "sk_model.pth"
105105
coarse_filename = coarse_filename or "sk_model2.pth"
106106

107-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
108-
coarse_model_path = hf_hub_download(pretrained_model_or_path, coarse_filename, cache_dir=cache_dir)
107+
if os.path.isdir(pretrained_model_or_path):
108+
model_path = os.path.join(pretrained_model_or_path, filename)
109+
coarse_model_path = os.path.join(pretrained_model_or_path, coarse_filename)
110+
else:
111+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
112+
coarse_model_path = hf_hub_download(pretrained_model_or_path, coarse_filename, cache_dir=cache_dir)
109113

110114
model = Generator(3, 1, 3)
111115
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

src/controlnet_aux/lineart_anime/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ def __init__(self, model):
122122
def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
123123
filename = filename or "netG.pth"
124124

125-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
125+
if os.path.isdir(pretrained_model_or_path):
126+
model_path = os.path.join(pretrained_model_or_path, filename)
127+
else:
128+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
126129

127130
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
128131
net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)

src/controlnet_aux/midas/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from PIL import Image
44
import torch
5+
import os
56

67
from huggingface_hub import hf_hub_download
78
from einops import rearrange
@@ -22,7 +23,11 @@ def from_pretrained(cls, pretrained_model_or_path, model_type="dpt_hybrid", file
2223
else:
2324
filename = filename or "dpt_hybrid-midas-501f0c75.pt"
2425

25-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
26+
if os.path.isdir(pretrained_model_or_path):
27+
model_path = os.path.join(pretrained_model_or_path, filename)
28+
else:
29+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
30+
2631
return cls(model_type=model_type, model_path=model_path)
2732

2833
def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False):

src/controlnet_aux/mlsd/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None
2222
else:
2323
filename = filename or "mlsd_large_512_fp32.pth"
2424

25-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
25+
if os.path.isdir(pretrained_model_or_path):
26+
model_path = os.path.join(pretrained_model_or_path, filename)
27+
else:
28+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
2629

2730
model = MobileV2_MLSD_Large()
2831
model.load_state_dict(torch.load(model_path), strict=True)

src/controlnet_aux/normalbae/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ def __init__(self, model):
2424
@classmethod
2525
def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
2626
filename = filename or "scannet.pt"
27-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
27+
28+
if os.path.isdir(pretrained_model_or_path):
29+
model_path = os.path.join(pretrained_model_or_path, filename)
30+
else:
31+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
2832

2933
args = types.SimpleNamespace()
3034
args.mode = 'client'

src/controlnet_aux/open_pose/__init__.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,14 @@ def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=
6060

6161
face_pretrained_model_or_path = pretrained_model_or_path
6262

63-
body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
64-
hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir)
65-
face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir)
63+
if os.path.isdir(pretrained_model_or_path):
64+
body_model_path = os.path.join(pretrained_model_or_path, filename)
65+
hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
66+
face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
67+
else:
68+
body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
69+
hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir)
70+
face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir)
6671

6772
body_estimation = Body(body_model_path)
6873
hand_estimation = Hand(hand_model_path)

src/controlnet_aux/pidi/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ def __init__(self, netNetwork):
1717
@classmethod
1818
def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
1919
filename = filename or "table5_pidinet.pth"
20-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
20+
21+
if os.path.isdir(pretrained_model_or_path):
22+
model_path = os.path.join(pretrained_model_or_path, filename)
23+
else:
24+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
2125

2226
netNetwork = pidinet()
2327
netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()})

src/controlnet_aux/zoe/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ def __init__(self, model):
2121
def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
2222
filename = filename or "ZoeD_M12_N.pt"
2323

24-
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
24+
if os.path.isdir(pretrained_model_or_path):
25+
model_path = os.path.join(pretrained_model_or_path, filename)
26+
else:
27+
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
28+
2529
conf = get_config("zoedepth", "infer")
2630
model = ZoeDepth.build_from_config(conf)
2731
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])

0 commit comments

Comments
 (0)