Commit 8958bb7

Merge pull request huggingface#57 from pdoane/leres
Add Leres depth estimator

2 parents: d533c69 + c7236c3
27 files changed: +3146 −35 lines

README.md (+10 −6)
@@ -30,10 +30,11 @@ img = Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512))
 
 # load processor from processor_id
 # options are:
-# ["canny", "depth_midas", "depth_zoe", "lineart_anime", "lineart_coarse", "lineart_realistic",
-# "mediapipe_face", "mlsd", "normal_bae", "normal_midas", "openpose", "openpose_face", "openpose_faceonly",
-# "openpose_full", "openpose_hand", "scribble_hed", "scribble_pidinet", "shuffle", "softedge_hed",
-# "softedge_hedsafe", "softedge_pidinet", "softedge_pidsafe"]
+# ["canny", "depth_leres", "depth_leres++", "depth_midas", "depth_zoe", "lineart_anime",
+# "lineart_coarse", "lineart_realistic", "mediapipe_face", "mlsd", "normal_bae", "normal_midas",
+# "openpose", "openpose_face", "openpose_faceonly", "openpose_full", "openpose_hand",
+# "scribble_hed", "scribble_pidinet", "shuffle", "softedge_hed", "softedge_hedsafe",
+# "softedge_pidinet", "softedge_pidsafe"]
 processor_id = 'scribble_hed'
 processor = Processor(processor_id)

@@ -45,7 +46,7 @@ Each model can be loaded individually by importing and instantiating them as follows:
 from PIL import Image
 import requests
 from io import BytesIO
-from controlnet_aux import HEDdetector, MidasDetector, MLSDdetector, OpenposeDetector, PidiNetDetector, NormalBaeDetector, LineartDetector, LineartAnimeDetector, CannyDetector, ContentShuffleDetector, ZoeDetector, MediapipeFaceDetector, SamDetector
+from controlnet_aux import HEDdetector, MidasDetector, MLSDdetector, OpenposeDetector, PidiNetDetector, NormalBaeDetector, LineartDetector, LineartAnimeDetector, CannyDetector, ContentShuffleDetector, ZoeDetector, MediapipeFaceDetector, SamDetector, LeresDetector
 
 # load image
 url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png"

@@ -63,7 +64,8 @@ normal_bae = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
 lineart = LineartDetector.from_pretrained("lllyasviel/Annotators")
 lineart_anime = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
 zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
-sam = SamDetector.from_pretrained("./weight_path")
+sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
 
 # instantiate
 canny = CannyDetector()

@@ -81,6 +83,8 @@ processed_image_normal_bae = normal_bae(img)
 processed_image_lineart = lineart(img, coarse=True)
 processed_image_lineart_anime = lineart_anime(img)
 processed_image_zoe = zoe(img)
+processed_image_sam = sam(img)
+processed_image_leres = leres(img)
 
 processed_image_canny = canny(img)
 processed_image_content = content(img)
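
The two new processor ids above ("depth_leres" and "depth_leres++") can also be exercised through the high-level Processor API. A minimal sketch, assuming the same demo image as the README and that "depth_leres++" selects the boosted variant:

    from io import BytesIO

    import requests
    from PIL import Image
    from controlnet_aux.processor import Processor

    # load the same demo image used in the README
    url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png"
    img = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((512, 512))

    depth_leres = Processor("depth_leres")(img)       # plain LeReS depth map
    depth_leres_pp = Processor("depth_leres++")(img)  # boosted variant (assumption: maps to boost=True)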

src/controlnet_aux/__init__.py (+5 −4)
@@ -1,16 +1,17 @@
 __version__ = "0.0.5"
 
 from .hed import HEDdetector
+from .leres import LeresDetector
+from .lineart import LineartDetector
+from .lineart_anime import LineartAnimeDetector
 from .midas import MidasDetector
 from .mlsd import MLSDdetector
+from .normalbae import NormalBaeDetector
 from .open_pose import OpenposeDetector
 from .pidi import PidiNetDetector
-from .normalbae import NormalBaeDetector
-from .lineart import LineartDetector
-from .lineart_anime import LineartAnimeDetector
 from .zoe import ZoeDetector
 
 from .canny import CannyDetector
-from .shuffle import ContentShuffleDetector
 from .mediapipe_face import MediapipeFaceDetector
 from .segment_anything import SamDetector
+from .shuffle import ContentShuffleDetector

src/controlnet_aux/leres/__init__.py (new file, +118)
import os

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from ..util import HWC3, resize_image
from .leres.depthmap import estimateboost, estimateleres
from .leres.multi_depth_model_woauxi import RelDepthModel
from .leres.net_tools import strip_prefix_if_present
from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
from .pix2pix.options.test_options import TestOptions


class LeresDetector:
    def __init__(self, model, pix2pixmodel):
        self.model = model
        self.pix2pixmodel = pix2pixmodel

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None):
        filename = filename or "res101.pth"
        pix2pix_filename = pix2pix_filename or "latest_net_G.pth"

        # resolve the LeReS depth checkpoint (local directory or Hub repo)
        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)

        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        model = RelDepthModel(backbone='resnext101')
        model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
        del checkpoint

        # resolve the pix2pix merge-network checkpoint used by the boost path
        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, pix2pix_filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir)

        opt = TestOptions().parse()
        if not torch.cuda.is_available():
            opt.gpu_ids = []  # cpu mode
        pix2pixmodel = Pix2Pix4DepthModel(opt)
        pix2pixmodel.save_dir = os.path.dirname(model_path)
        pix2pixmodel.load_networks('latest')
        pix2pixmodel.eval()

        return cls(model, pix2pixmodel)

    def to(self, device):
        self.model.to(device)
        # TODO - refactor pix2pix implementation to support device migration
        # self.pix2pixmodel.to(device)
        return self

    def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512, output_type="pil"):
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        height, width, dim = input_image.shape

        with torch.no_grad():

            if boost:
                depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height))
            else:
                depth = estimateleres(input_image, self.model, width, height)

            numbytes = 2
            depth_min = depth.min()
            depth_max = depth.max()
            max_val = (2 ** (8 * numbytes)) - 1

            # check output before normalizing and mapping to 16 bit
            if depth_max - depth_min > np.finfo("float").eps:
                out = max_val * (depth - depth_min) / (depth_max - depth_min)
            else:
                out = np.zeros(depth.shape)

            # single channel, 16 bit image
            depth_image = out.astype("uint16")

            # convert to uint8
            depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0 / 65535.0))

            # remove near
            if thr_a != 0:
                thr_a = ((thr_a / 100) * 255)
                depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]

            # invert image
            depth_image = cv2.bitwise_not(depth_image)

            # remove bg
            if thr_b != 0:
                thr_b = ((thr_b / 100) * 255)
                depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]

        detected_map = depth_image
        detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
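
A minimal usage sketch for the detector defined above, assuming the two checkpoints (res101.pth and latest_net_G.pth) are hosted in the lllyasviel/Annotators repo as the README snippet suggests, and a hypothetical local input.png:

    import torch
    from PIL import Image
    from controlnet_aux import LeresDetector

    leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
    leres = leres.to("cuda" if torch.cuda.is_available() else "cpu")

    img = Image.open("input.png")  # hypothetical example image

    depth = leres(img)  # plain LeReS depth map as a PIL image

    # thr_a / thr_b are percentage cutoffs (0-100): thr_a zeroes the nearest
    # depths before inversion, thr_b zeroes the background after inversion
    depth_trimmed = leres(img, thr_a=10, thr_b=5)

    # boost=True routes through estimateboost and the pix2pix merge network
    # (slower, higher quality); note the pix2pix model stays on its initial
    # device per the TODO in to()
    depth_boosted = leres(img, boost=True)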
LICENSE (new file, +23)
https://github.com/thygate/stable-diffusion-webui-depthmap-script

MIT License

Copyright (c) 2023 Bob Thiry

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
