Skip to content

Commit

Permalink
download from huggingfacehub
Browse files Browse the repository at this point in the history
  • Loading branch information
jinwonkim93 committed Mar 8, 2023
1 parent ae759fb commit 68aa765
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDet
open_pose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
mlsd = MLSDdetector.from_pretrained("lllyasviel/ControlNet")
hed = HEDdetector.from_pretrained("lllyasviel/ControlNet")
midas = MidasDetector.from_pretrained("lllyasviel/ControlNet")
canny = CannyDetector()
midas = MidasDetector()

```
12 changes: 10 additions & 2 deletions src/controlnet_aux/midas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@
from PIL import Image
import torch

from huggingface_hub import hf_hub_download
from einops import rearrange
from .api import MiDaSInference


class MidasDetector:
def __init__(self):
self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
def __init__(self, model_type="dpt_hybrid", model_path=None):
self.model = MiDaSInference(model_type=model_type, model_path=model_path).cuda()


@classmethod
def from_pretrained(cls, pretrained_model_or_path, model_type="dpt_hybrid", filename=None):
filename = filename or "annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
model_path = hf_hub_download(pretrained_model_or_path, filename)
return cls(model_type=model_type, model_path=model_path)

def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):

input_type = "np"
Expand Down
8 changes: 4 additions & 4 deletions src/controlnet_aux/midas/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def load_midas_transform(model_type):
return transform


def load_model(model_type):
def load_model(model_type, model_path=None):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load network
model_path = ISL_PATHS[model_type]
model_path = model_path or ISL_PATHS[model_type]
if model_type == "dpt_large": # DPT-Large
model = DPTDepthModel(
path=model_path,
Expand Down Expand Up @@ -155,10 +155,10 @@ class MiDaSInference(nn.Module):
"midas_v21_small",
]

def __init__(self, model_type):
def __init__(self, model_type, model_path):
super().__init__()
assert (model_type in self.MODEL_TYPES_ISL)
model, _ = load_model(model_type)
model, _ = load_model(model_type, model_path)
self.model = model
self.model.train = disabled_train

Expand Down

0 comments on commit 68aa765

Please sign in to comment.