Skip to content

Commit 845161d

Browse files
Merge pull request huggingface#48 from pdoane/annotator-updates
Update annotators
2 parents 9afb532 + 9fd18ba commit 845161d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1268
-1662
lines changed
Binary file not shown.

src/controlnet_aux/hed/__init__.py

+66-119
Original file line numberDiff line numberDiff line change
@@ -1,120 +1,83 @@
1-
import numpy as np
2-
import cv2
1+
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2+
# Please use this implementation in your products
3+
# This implementation may produce slightly different results from Saining Xie's official implementations,
4+
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5+
# Unlike the official models and other implementations, this is an RGB-input model (rather than BGR)
6+
# and in this way it works better for gradio's RGB protocol
7+
38
import os
9+
10+
import cv2
11+
import numpy as np
412
import torch
513
from einops import rearrange
614
from huggingface_hub import hf_hub_download
715
from PIL import Image
8-
from ..open_pose.util import HWC3, resize_image
9-
from ..util import safe_step
1016

11-
class Network(torch.nn.Module):
12-
def __init__(self, model_path):
13-
super().__init__()
17+
from ..util import HWC3, nms, resize_image, safe_step
1418

15-
self.netVggOne = torch.nn.Sequential(
16-
torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
17-
torch.nn.ReLU(inplace=False),
18-
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
19-
torch.nn.ReLU(inplace=False)
20-
)
21-
22-
self.netVggTwo = torch.nn.Sequential(
23-
torch.nn.MaxPool2d(kernel_size=2, stride=2),
24-
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
25-
torch.nn.ReLU(inplace=False),
26-
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
27-
torch.nn.ReLU(inplace=False)
28-
)
29-
30-
self.netVggThr = torch.nn.Sequential(
31-
torch.nn.MaxPool2d(kernel_size=2, stride=2),
32-
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
33-
torch.nn.ReLU(inplace=False),
34-
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
35-
torch.nn.ReLU(inplace=False),
36-
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
37-
torch.nn.ReLU(inplace=False)
38-
)
39-
40-
self.netVggFou = torch.nn.Sequential(
41-
torch.nn.MaxPool2d(kernel_size=2, stride=2),
42-
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
43-
torch.nn.ReLU(inplace=False),
44-
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
45-
torch.nn.ReLU(inplace=False),
46-
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
47-
torch.nn.ReLU(inplace=False)
48-
)
49-
50-
self.netVggFiv = torch.nn.Sequential(
51-
torch.nn.MaxPool2d(kernel_size=2, stride=2),
52-
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
53-
torch.nn.ReLU(inplace=False),
54-
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
55-
torch.nn.ReLU(inplace=False),
56-
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
57-
torch.nn.ReLU(inplace=False)
58-
)
59-
60-
self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
61-
self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
62-
self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
63-
self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
64-
self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
65-
66-
self.netCombine = torch.nn.Sequential(
67-
torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
68-
torch.nn.Sigmoid()
69-
)
70-
71-
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
72-
73-
def forward(self, tenInput):
74-
tenInput = tenInput * 255.0
75-
tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
76-
77-
tenVggOne = self.netVggOne(tenInput)
78-
tenVggTwo = self.netVggTwo(tenVggOne)
79-
tenVggThr = self.netVggThr(tenVggTwo)
80-
tenVggFou = self.netVggFou(tenVggThr)
81-
tenVggFiv = self.netVggFiv(tenVggFou)
82-
83-
tenScoreOne = self.netScoreOne(tenVggOne)
84-
tenScoreTwo = self.netScoreTwo(tenVggTwo)
85-
tenScoreThr = self.netScoreThr(tenVggThr)
86-
tenScoreFou = self.netScoreFou(tenVggFou)
87-
tenScoreFiv = self.netScoreFiv(tenVggFiv)
88-
89-
tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
90-
tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
91-
tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
92-
tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
93-
tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
94-
95-
return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
9619

20+
class DoubleConvBlock(torch.nn.Module):
    """Stack of 3x3 convolutions (each followed by ReLU) with a 1x1 projection head.

    Args:
        input_channel: Channel count of the incoming feature map.
        output_channel: Channel count produced by every 3x3 convolution.
        layer_number: Number of 3x3 convolutions to stack. The first maps
            ``input_channel`` to ``output_channel``; it is always created,
            even when ``layer_number`` is less than 1.
    """

    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        layers = [
            torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)
        ]
        layers.extend(
            torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)
            for _ in range(1, layer_number)
        )
        # Keep the container an nn.Sequential named ``convs`` so parameter
        # names (convs.<index>.*, projection.*) match existing checkpoints.
        self.convs = torch.nn.Sequential(*layers)
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def forward(self, x, down_sampling=False):
        """Run the block and return ``(features, projection)``.

        Implemented as ``forward`` rather than overriding ``__call__`` so that
        ``torch.nn.Module.__call__`` keeps dispatching registered hooks;
        ``block(x, down_sampling=...)`` continues to work unchanged.

        Args:
            x: Input feature map accepted by ``Conv2d`` with ``input_channel`` channels.
            down_sampling: When true, apply a 2x2 max-pool with stride 2 before
                the convolutions.

        Returns:
            A tuple of the activated feature map and its single-channel 1x1
            projection.
        """
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        # The Sequential is iterated manually because a ReLU follows every conv.
        for conv in self.convs:
            h = torch.nn.functional.relu(conv(h))
        return h, self.projection(h)
37+
38+
39+
class ControlNetHED_Apache2(torch.nn.Module):
    """Five-stage HED-style network that returns one side output per stage.

    The stages are ``DoubleConvBlock`` instances with 64, 128, 256, 512 and
    512 output channels (2, 2, 3, 3 and 3 convolutions respectively). Stages
    2-5 are invoked with ``down_sampling=True``.
    """

    def __init__(self):
        super().__init__()
        # Per-channel offset subtracted from the input; initialised to zero
        # here (presumably populated from a checkpoint -- verify against caller).
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def forward(self, x):
        """Return the five per-stage projections for ``x``.

        Implemented as ``forward`` rather than overriding ``__call__`` so that
        ``torch.nn.Module.__call__`` keeps dispatching registered hooks;
        ``model(x)`` continues to work unchanged.

        Args:
            x: Input tensor with 3 channels, broadcast-compatible with the
                ``(1, 3, 1, 1)`` ``norm`` parameter.

        Returns:
            A 5-tuple ``(projection1, ..., projection5)`` of the projection
            outputs of ``block1`` through ``block5``, in order.
        """
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        _, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5
9757

9858
class HEDdetector:
9959
def __init__(self, netNetwork):
100-
self.netNetwork = netNetwork.eval()
60+
self.netNetwork = netNetwork
10161

10262
@classmethod
10363
def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
104-
if pretrained_model_or_path == "lllyasviel/ControlNet":
105-
filename = filename or "annotator/ckpts/network-bsds500.pth"
106-
else:
107-
filename = filename or "network-bsds500.pth"
64+
filename = filename or "ControlNetHED.pth"
10865

10966
if os.path.isdir(pretrained_model_or_path):
11067
model_path = os.path.join(pretrained_model_or_path, filename)
11168
else:
11269
model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
11370

114-
netNetwork = Network(model_path)
71+
netNetwork = ControlNetHED_Apache2()
72+
netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
73+
netNetwork.float().eval()
11574

11675
return cls(netNetwork)
117-
76+
77+
def to(self, device):
78+
self.netNetwork.to(device)
79+
return self
80+
11881
def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, return_pil=True, scribble=False):
11982
device = next(iter(self.netNetwork.parameters())).device
12083
if not isinstance(input_image, np.ndarray):
@@ -124,19 +87,20 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
12487
input_image = resize_image(input_image, detect_resolution)
12588

12689
assert input_image.ndim == 3
127-
input_image = input_image[:, :, ::-1].copy()
90+
H, W, C = input_image.shape
12891
with torch.no_grad():
129-
image_hed = torch.from_numpy(input_image).float()
130-
image_hed = image_hed.to(device)
131-
image_hed = image_hed / 255.0
92+
image_hed = torch.from_numpy(input_image.copy()).float().to(device)
13293
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
133-
edge = self.netNetwork(image_hed)[0]
134-
edge = edge.cpu().numpy()
94+
edges = self.netNetwork(image_hed)
95+
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
96+
edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
97+
edges = np.stack(edges, axis=2)
98+
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
13599
if safe:
136100
edge = safe_step(edge)
137101
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
138102

139-
detected_map = edge[0]
103+
detected_map = edge
140104

141105
detected_map = HWC3(detected_map)
142106

@@ -155,20 +119,3 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
155119
detected_map = Image.fromarray(detected_map)
156120

157121
return detected_map
158-
159-
def nms(x, t, s):
160-
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
161-
162-
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
163-
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
164-
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
165-
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
166-
167-
y = np.zeros_like(x)
168-
169-
for f in [f1, f2, f3, f4]:
170-
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
171-
172-
z = np.zeros_like(y, dtype=np.uint8)
173-
z[y > t] = 255
174-
return z
Binary file not shown.

src/controlnet_aux/lineart/LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 Caroline Chan
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

src/controlnet_aux/lineart/__init__.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import os
2+
23
import cv2
3-
import torch
44
import numpy as np
5-
6-
from huggingface_hub import hf_hub_download
5+
import torch
76
import torch.nn as nn
87
from einops import rearrange
8+
from huggingface_hub import hf_hub_download
99
from PIL import Image
10-
from ..open_pose.util import HWC3, resize_image
1110

11+
from ..util import HWC3, resize_image
1212

1313
norm_layer = nn.InstanceNorm2d
1414

@@ -92,12 +92,8 @@ def forward(self, x, cond=None):
9292

9393
class LineartDetector:
9494
def __init__(self, model, coarse_model):
95-
self.model = model.eval()
96-
self.model_coarse = coarse_model.eval()
97-
98-
if torch.cuda.is_available():
99-
self.model.cuda()
100-
self.model_coarse.cuda()
95+
self.model = model
96+
self.model_coarse = coarse_model
10197

10298
@classmethod
10399
def from_pretrained(cls, pretrained_model_or_path, filename=None, coarse_filename=None, cache_dir=None):
@@ -113,12 +109,19 @@ def from_pretrained(cls, pretrained_model_or_path, filename=None, coarse_filenam
113109

114110
model = Generator(3, 1, 3)
115111
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
112+
model.eval()
116113

117114
coarse_model = Generator(3, 1, 3)
118115
coarse_model.load_state_dict(torch.load(coarse_model_path, map_location=torch.device('cpu')))
116+
coarse_model.eval()
119117

120118
return cls(model, coarse_model)
121-
119+
120+
def to(self, device):
121+
self.model.to(device)
122+
self.model_coarse.to(device)
123+
return self
124+
122125
def __call__(self, input_image, coarse=False, detect_resolution=512, image_resolution=512, return_pil=True):
123126
device = next(iter(self.model.parameters())).device
124127
if not isinstance(input_image, np.ndarray):
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 Caroline Chan
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

src/controlnet_aux/lineart_anime/__init__.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import os
77
import cv2
88
from einops import rearrange
9-
from PIL import Image
10-
from ..open_pose.util import HWC3, resize_image
119
from huggingface_hub import hf_hub_download
10+
from PIL import Image
11+
12+
from ..util import HWC3, resize_image
1213

1314

1415
class UnetGenerator(nn.Module):
@@ -113,10 +114,7 @@ def forward(self, x):
113114

114115
class LineartAnimeDetector:
115116
def __init__(self, model):
116-
self.model = model.eval()
117-
118-
if torch.cuda.is_available():
119-
self.model.cuda()
117+
self.model = model
120118

121119
@classmethod
122120
def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
@@ -135,9 +133,14 @@ def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None
135133
ckpt[key.replace('module.', '')] = ckpt[key]
136134
del ckpt[key]
137135
net.load_state_dict(ckpt)
136+
net.eval()
138137

139138
return cls(net)
140139

140+
def to(self, device):
141+
self.model.to(device)
142+
return self
143+
141144
def __call__(self, input_image, detect_resolution=512, image_resolution=512, return_pil=True):
142145
device = next(iter(self.model.parameters())).device
143146
if not isinstance(input_image, np.ndarray):

src/controlnet_aux/midas/LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

0 commit comments

Comments
 (0)