Skip to content

Commit d533c69

Browse files
Merge pull request huggingface#49 from pdoane/annotator-params
Consistent resizing parameters and behavior
2 parents 14bd3a4 + 3540e93 commit d533c69

File tree

16 files changed

+410
-134
lines changed

16 files changed

+410
-134
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ tests/fixtures/cached_*_text.txt
1313
logs/
1414
lightning_logs/
1515
lang_code_data/
16+
tests/outputs
1617

1718
# Distribution / packaging
1819
.Python

src/controlnet_aux/canny/__init__.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,36 @@
1+
import warnings
12
import cv2
23
import numpy as np
34
from PIL import Image
4-
from ..util import HWC3
5+
from ..util import HWC3, resize_image
56

67
class CannyDetector:
7-
def __call__(self, img, low_threshold=100, high_threshold=200):
8+
def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs):
9+
if "img" in kwargs:
10+
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
11+
input_image = kwargs.pop("img")
812

9-
input_type = "np"
10-
if isinstance(img, Image.Image):
11-
img = np.array(img)
12-
input_type = "pil"
13+
if input_image is None:
14+
raise ValueError("input_image must be defined.")
15+
16+
if not isinstance(input_image, np.ndarray):
17+
input_image = np.array(input_image, dtype=np.uint8)
18+
output_type = output_type or "pil"
19+
else:
20+
output_type = output_type or "np"
1321

14-
img = HWC3(img)
15-
img = cv2.Canny(img, low_threshold, high_threshold)
22+
input_image = HWC3(input_image)
23+
input_image = resize_image(input_image, detect_resolution)
24+
25+
detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
26+
detected_map = HWC3(detected_map)
27+
28+
img = resize_image(input_image, image_resolution)
29+
H, W, C = img.shape
30+
31+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
1632

17-
if input_type == "pil":
18-
img = Image.fromarray(img)
19-
img = img.convert("RGB")
33+
if output_type == "pil":
34+
detected_map = Image.fromarray(detected_map)
2035

21-
return img
36+
return detected_map

src/controlnet_aux/hed/__init__.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# and in this way it works better for gradio's RGB protocol
77

88
import os
9+
import warnings
910

1011
import cv2
1112
import numpy as np
@@ -78,7 +79,15 @@ def to(self, device):
7879
self.netNetwork.to(device)
7980
return self
8081

81-
def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, return_pil=True, scribble=False):
82+
def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, **kwargs):
83+
if "return_pil" in kwargs:
84+
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
85+
output_type = "pil" if kwargs["return_pil"] else "np"
86+
if type(output_type) is bool:
87+
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
88+
if output_type:
89+
output_type = "pil"
90+
8291
device = next(iter(self.netNetwork.parameters())).device
8392
if not isinstance(input_image, np.ndarray):
8493
input_image = np.array(input_image, dtype=np.uint8)
@@ -101,7 +110,6 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
101110
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
102111

103112
detected_map = edge
104-
105113
detected_map = HWC3(detected_map)
106114

107115
img = resize_image(input_image, image_resolution)
@@ -115,7 +123,7 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
115123
detected_map[detected_map > 4] = 255
116124
detected_map[detected_map < 255] = 0
117125

118-
if return_pil:
126+
if output_type == "pil":
119127
detected_map = Image.fromarray(detected_map)
120128

121129
return detected_map

src/controlnet_aux/lineart/__init__.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23

34
import cv2
45
import numpy as np
@@ -122,7 +123,15 @@ def to(self, device):
122123
self.model_coarse.to(device)
123124
return self
124125

125-
def __call__(self, input_image, coarse=False, detect_resolution=512, image_resolution=512, return_pil=True):
126+
def __call__(self, input_image, coarse=False, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
127+
if "return_pil" in kwargs:
128+
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
129+
output_type = "pil" if kwargs["return_pil"] else "np"
130+
if type(output_type) is bool:
131+
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
132+
if output_type:
133+
output_type = "pil"
134+
126135
device = next(iter(self.model.parameters())).device
127136
if not isinstance(input_image, np.ndarray):
128137
input_image = np.array(input_image, dtype=np.uint8)
@@ -152,7 +161,7 @@ def __call__(self, input_image, coarse=False, detect_resolution=512, image_resol
152161
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
153162
detected_map = 255 - detected_map
154163

155-
if return_pil:
164+
if output_type == "pil":
156165
detected_map = Image.fromarray(detected_map)
157166

158167
return detected_map

src/controlnet_aux/lineart_anime/__init__.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import numpy as np
2-
import torch
3-
import torch.nn as nn
41
import functools
5-
62
import os
3+
import warnings
4+
75
import cv2
6+
import numpy as np
7+
import torch
8+
import torch.nn as nn
89
from einops import rearrange
910
from huggingface_hub import hf_hub_download
1011
from PIL import Image
@@ -141,7 +142,15 @@ def to(self, device):
141142
self.model.to(device)
142143
return self
143144

144-
def __call__(self, input_image, detect_resolution=512, image_resolution=512, return_pil=True):
145+
def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
146+
if "return_pil" in kwargs:
147+
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
148+
output_type = "pil" if kwargs["return_pil"] else "np"
149+
if type(output_type) is bool:
150+
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
151+
if output_type:
152+
output_type = "pil"
153+
145154
device = next(iter(self.model.parameters())).device
146155
if not isinstance(input_image, np.ndarray):
147156
input_image = np.array(input_image, dtype=np.uint8)
@@ -174,7 +183,7 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, ret
174183
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
175184
detected_map = 255 - detected_map
176185

177-
if return_pil:
186+
if output_type == "pil":
178187
detected_map = Image.fromarray(detected_map)
179188

180189
return detected_map

src/controlnet_aux/mediapipe_face/__init__.py

+42-12
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,53 @@
1+
import warnings
12
from typing import Union
23

3-
from .mediapipe_face_common import generate_annotation
4-
from PIL import Image
4+
import cv2
55
import numpy as np
6+
from PIL import Image
7+
8+
from ..util import HWC3, resize_image
9+
from .mediapipe_face_common import generate_annotation
610

711

812
class MediapipeFaceDetector:
913
def __call__(self,
10-
image: Union[np.ndarray, Image.Image],
14+
input_image: Union[np.ndarray, Image.Image] = None,
1115
max_faces: int = 1,
1216
min_confidence: float = 0.5,
13-
return_pil: bool = True):
14-
15-
if isinstance(image, Image.Image) is True:
16-
image = np.array(image)
17-
18-
face = generate_annotation(image, max_faces, min_confidence)
17+
output_type: str = "pil",
18+
detect_resolution: int = 512,
19+
image_resolution: int = 512,
20+
**kwargs):
21+
22+
if "image" in kwargs:
23+
warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning)
24+
input_image = kwargs.pop("image")
25+
if input_image is None:
26+
raise ValueError("input_image must be defined.")
27+
28+
if "return_pil" in kwargs:
29+
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
30+
output_type = "pil" if kwargs["return_pil"] else "np"
31+
if type(output_type) is bool:
32+
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
33+
if output_type:
34+
output_type = "pil"
35+
36+
if not isinstance(input_image, np.ndarray):
37+
input_image = np.array(input_image, dtype=np.uint8)
38+
39+
input_image = HWC3(input_image)
40+
input_image = resize_image(input_image, detect_resolution)
41+
42+
detected_map = generate_annotation(input_image, max_faces, min_confidence)
43+
detected_map = HWC3(detected_map)
44+
45+
img = resize_image(input_image, image_resolution)
46+
H, W, C = img.shape
47+
48+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
1949

20-
if return_pil is True:
21-
face = Image.fromarray(face)
50+
if output_type == "pil":
51+
detected_map = Image.fromarray(detected_map)
2252

23-
return face
53+
return detected_map

src/controlnet_aux/midas/__init__.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from huggingface_hub import hf_hub_download
88
from PIL import Image
99

10-
from ..util import HWC3
10+
from ..util import HWC3, resize_image
1111
from .api import MiDaSInference
1212

1313

@@ -36,14 +36,17 @@ def to(self, device):
3636
self.model.to(device)
3737
return self
3838

39-
def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False):
39+
def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, image_resolution=512, output_type=None):
4040
device = next(iter(self.model.parameters())).device
41-
input_type = "np"
42-
if isinstance(input_image, Image.Image):
43-
input_image = np.array(input_image)
44-
input_type = "pil"
45-
41+
if not isinstance(input_image, np.ndarray):
42+
input_image = np.array(input_image, dtype=np.uint8)
43+
output_type = output_type or "pil"
44+
else:
45+
output_type = output_type or "np"
46+
4647
input_image = HWC3(input_image)
48+
input_image = resize_image(input_image, detect_resolution)
49+
4750
assert input_image.ndim == 3
4851
image_depth = input_image
4952
with torch.no_grad():
@@ -70,9 +73,19 @@ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False
7073
normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
7174
normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1]
7275

73-
if input_type == "pil":
76+
depth_image = HWC3(depth_image)
77+
if depth_and_normal:
78+
normal_image = HWC3(normal_image)
79+
80+
img = resize_image(input_image, image_resolution)
81+
H, W, C = img.shape
82+
83+
depth_image = cv2.resize(depth_image, (W, H), interpolation=cv2.INTER_LINEAR)
84+
if depth_and_normal:
85+
normal_image = cv2.resize(normal_image, (W, H), interpolation=cv2.INTER_LINEAR)
86+
87+
if output_type == "pil":
7488
depth_image = Image.fromarray(depth_image)
75-
depth_image = depth_image.convert("RGB")
7689
if depth_and_normal:
7790
normal_image = Image.fromarray(normal_image)
7891

src/controlnet_aux/mlsd/__init__.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
2+
import warnings
23

34
import cv2
45
import numpy as np
56
import torch
6-
from einops import rearrange
77
from huggingface_hub import hf_hub_download
88
from PIL import Image
99

@@ -38,7 +38,15 @@ def to(self, device):
3838
self.model.to(device)
3939
return self
4040

41-
def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, image_resolution=512, return_pil=True):
41+
def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
42+
if "return_pil" in kwargs:
43+
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
44+
output_type = "pil" if kwargs["return_pil"] else "np"
45+
if type(output_type) is bool:
46+
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
47+
if output_type:
48+
output_type = "pil"
49+
4250
if not isinstance(input_image, np.ndarray):
4351
input_image = np.array(input_image, dtype=np.uint8)
4452

@@ -58,14 +66,14 @@ def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, ima
5866
pass
5967

6068
detected_map = img_output[:, :, 0]
61-
6269
detected_map = HWC3(detected_map)
70+
6371
img = resize_image(input_image, image_resolution)
6472
H, W, C = img.shape
6573

66-
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
74+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
6775

68-
if return_pil:
76+
if output_type == "pil":
6977
detected_map = Image.fromarray(detected_map)
7078

7179
return detected_map

src/controlnet_aux/normalbae/__init__.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
import types
3+
import warnings
34

5+
import cv2
46
import numpy as np
57
import torch
68
import torchvision.transforms as transforms
@@ -58,7 +60,15 @@ def to(self, device):
5860
return self
5961

6062

61-
def __call__(self, input_image, detect_resolution=512, image_resolution=512, return_pil=True):
63+
def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
64+
if "return_pil" in kwargs:
65+
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
66+
output_type = "pil" if kwargs["return_pil"] else "np"
67+
if type(output_type) is bool:
68+
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
69+
if output_type:
70+
output_type = "pil"
71+
6272
device = next(iter(self.model.parameters())).device
6373
if not isinstance(input_image, np.ndarray):
6474
input_image = np.array(input_image, dtype=np.uint8)
@@ -84,10 +94,16 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, ret
8494
normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
8595
normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
8696

87-
img = resize_image(normal_image, image_resolution)
97+
detected_map = normal_image
98+
detected_map = HWC3(detected_map)
99+
100+
img = resize_image(input_image, image_resolution)
101+
H, W, C = img.shape
102+
103+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
88104

89-
if return_pil:
90-
img = Image.fromarray(img)
105+
if output_type == "pil":
106+
detected_map = Image.fromarray(detected_map)
91107

92-
return img
108+
return detected_map
93109

0 commit comments

Comments
 (0)