- import numpy as np
- import cv2
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
+ # Please use this implementation in your products
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
+ # and in this way it works better for gradio's RGB protocol
+
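+ # Example usage (illustrative sketch; the package import path and repository id
+ # below are assumptions, not fixed by this file):
+ #     from controlnet_aux import HEDdetector
+ #     hed = HEDdetector.from_pretrained("lllyasviel/Annotators").to("cuda")
+ #     edge_image = hed(input_pil_image)  # PIL.Image in, PIL.Image out by default
+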
import os
+
+ import cv2
+ import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
- from ..open_pose.util import HWC3, resize_image
- from ..util import safe_step

- class Network(torch.nn.Module):
-     def __init__(self, model_path):
-         super().__init__()
+ from ..util import HWC3, nms, resize_image, safe_step

-         self.netVggOne = torch.nn.Sequential(
-             torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False)
-         )
-
-         self.netVggTwo = torch.nn.Sequential(
-             torch.nn.MaxPool2d(kernel_size=2, stride=2),
-             torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False)
-         )
-
-         self.netVggThr = torch.nn.Sequential(
-             torch.nn.MaxPool2d(kernel_size=2, stride=2),
-             torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False)
-         )
-
-         self.netVggFou = torch.nn.Sequential(
-             torch.nn.MaxPool2d(kernel_size=2, stride=2),
-             torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False)
-         )
-
-         self.netVggFiv = torch.nn.Sequential(
-             torch.nn.MaxPool2d(kernel_size=2, stride=2),
-             torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False),
-             torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
-             torch.nn.ReLU(inplace=False)
-         )
-
-         self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
-         self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
-         self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
-         self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
-         self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
-
-         self.netCombine = torch.nn.Sequential(
-             torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
-             torch.nn.Sigmoid()
-         )
-
-         self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
-
-     def forward(self, tenInput):
-         tenInput = tenInput * 255.0
-         tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
-
-         tenVggOne = self.netVggOne(tenInput)
-         tenVggTwo = self.netVggTwo(tenVggOne)
-         tenVggThr = self.netVggThr(tenVggTwo)
-         tenVggFou = self.netVggFou(tenVggThr)
-         tenVggFiv = self.netVggFiv(tenVggFou)
-
-         tenScoreOne = self.netScoreOne(tenVggOne)
-         tenScoreTwo = self.netScoreTwo(tenVggTwo)
-         tenScoreThr = self.netScoreThr(tenVggThr)
-         tenScoreFou = self.netScoreFou(tenVggFou)
-         tenScoreFiv = self.netScoreFiv(tenVggFiv)
-
-         tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
-         tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
-         tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
-         tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
-         tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
-
-         return self.netCombine(torch.cat([tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv], 1))
-
+
+ class DoubleConvBlock(torch.nn.Module):
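+     # One VGG-style stage: `layer_number` 3x3 convolutions (each followed by a
+     # ReLU in the forward pass) plus a 1x1 projection that turns the stage's
+     # features into a single-channel side-output (edge logit) map.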
+     def __init__(self, input_channel, output_channel, layer_number):
+         super().__init__()
+         self.convs = torch.nn.Sequential()
+         self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+         for i in range(1, layer_number):
+             self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+         self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
+
+     def __call__(self, x, down_sampling=False):
+         h = x
+         if down_sampling:
+             h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
+         for conv in self.convs:
+             h = conv(h)
+             h = torch.nn.functional.relu(h)
+         return h, self.projection(h)
+
+
+ class ControlNetHED_Apache2(torch.nn.Module):
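+     # Five cascaded DoubleConvBlocks mirroring the VGG16 stages. `norm` is a
+     # learned per-channel offset subtracted from the RGB input, and the five
+     # single-channel projections returned below are multi-scale side outputs.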
+     def __init__(self):
+         super().__init__()
+         self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
+         self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
+         self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
+         self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
+         self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
+         self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
+
+     def __call__(self, x):
+         h = x - self.norm
+         h, projection1 = self.block1(h)
+         h, projection2 = self.block2(h, down_sampling=True)
+         h, projection3 = self.block3(h, down_sampling=True)
+         h, projection4 = self.block4(h, down_sampling=True)
+         h, projection5 = self.block5(h, down_sampling=True)
+         return projection1, projection2, projection3, projection4, projection5

class HEDdetector:
    def __init__(self, netNetwork):
-         self.netNetwork = netNetwork.eval()
+         self.netNetwork = netNetwork

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
-         if pretrained_model_or_path == "lllyasviel/ControlNet":
-             filename = filename or "annotator/ckpts/network-bsds500.pth"
-         else:
-             filename = filename or "network-bsds500.pth"
+         filename = filename or "ControlNetHED.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)

-         netNetwork = Network(model_path)
+         netNetwork = ControlNetHED_Apache2()
+         netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
+         netNetwork.float().eval()
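+         # The checkpoint is loaded on the CPU; call the detector's .to(device)
+         # method (defined below) to move the network to a GPU.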

        return cls(netNetwork)
-
+
+     def to(self, device):
+         self.netNetwork.to(device)
+         return self
+
    def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, return_pil=True, scribble=False):
        device = next(iter(self.netNetwork.parameters())).device
        if not isinstance(input_image, np.ndarray):
@@ -124,19 +87,20 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
-         input_image = input_image[:, :, ::-1].copy()
+         H, W, C = input_image.shape
        with torch.no_grad():
-             image_hed = torch.from_numpy(input_image).float()
-             image_hed = image_hed.to(device)
-             image_hed = image_hed / 255.0
+             image_hed = torch.from_numpy(input_image.copy()).float().to(device)
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
-             edge = self.netNetwork(image_hed)[0]
-             edge = edge.cpu().numpy()
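+             # The network returns five side-output logit maps at decreasing
+             # resolutions; resize each to the input size (W, H), average them,
+             # and apply a sigmoid to get the fused edge probability map.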
+             edges = self.netNetwork(image_hed)
+             edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+             edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+             edges = np.stack(edges, axis=2)
+             edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
        if safe:
            edge = safe_step(edge)
        edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

-         detected_map = edge[0]
+         detected_map = edge

        detected_map = HWC3(detected_map)
@@ -155,20 +119,3 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
            detected_map = Image.fromarray(detected_map)

        return detected_map
-
- def nms(x, t, s):
-     x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
-
-     f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
-     f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
-     f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
-     f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
-
-     y = np.zeros_like(x)
-
-     for f in [f1, f2, f3, f4]:
-         np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
-
-     z = np.zeros_like(y, dtype=np.uint8)
-     z[y > t] = 255
-     return z