Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj committed Sep 3, 2023
1 parent 351d5e3 commit a16c7d3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/controlnet_aux/zoe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ def __init__(self, model):
self.model = model

@classmethod
def from_pretrained(cls, pretrained_model_or_path, model_type="zoe", filename=None, cache_dir=None):
def from_pretrained(cls, pretrained_model_or_path, model_type="zoedepth", filename=None, cache_dir=None):
filename = filename or "ZoeD_M12_N.pt"

if os.path.isdir(pretrained_model_or_path):
model_path = os.path.join(pretrained_model_or_path, filename)
else:
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)

conf = get_config("zoedepth", "infer")
model_cls = ZoeDepth if model_type == "zoe" else ZoeDepthNK
conf = get_config(model_type, "infer")
model_cls = ZoeDepth if model_type == "zoedepth" else ZoeDepthNK
model = model_cls.build_from_config(conf)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
model.eval()
Expand Down

0 comments on commit a16c7d3

Please sign in to comment.