Skip to content

Commit 114147d

Browse files
author
zhouzhengguang
committed
update
1 parent bfc9224 commit 114147d

File tree

4 files changed

+18
-43
lines changed

4 files changed

+18
-43
lines changed

src/controlnet_aux/dwpose/__init__.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313
from PIL import Image
1414

15+
from ..util import HWC3, resize_image
1516
from . import util
1617
from .wholebody import Wholebody
1718

@@ -31,17 +32,16 @@ def draw_pose(pose, H, W):
3132
return canvas
3233

3334
class DWposeDetector:
34-
def __init__(self, det_config, det_ckpt, pose_config, pose_ckpt, device):
35+
def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"):
3536

3637
self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device)
3738

38-
def __call__(self, oriImg, output_type="pil", detect_resolution=512, image_resolution=512):
39+
def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
3940

40-
oriImg = oriImg.copy()
41-
input_image = cv2.cvtColor(np.array(oriImg), cv2.COLOR_RGB2BGR)
41+
input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
4242

43-
input_image = util.HWC3(input_image)
44-
input_image = util.resize_image(input_image, detect_resolution)
43+
input_image = HWC3(input_image)
44+
input_image = resize_image(input_image, detect_resolution)
4545
H, W, C = input_image.shape
4646

4747
with torch.no_grad():
@@ -74,9 +74,9 @@ def __call__(self, oriImg, output_type="pil", detect_resolution=512, image_resol
7474
pose = dict(bodies=bodies, hands=hands, faces=faces)
7575

7676
detected_map = draw_pose(pose, H, W)
77-
detected_map = util.HWC3(detected_map)
77+
detected_map = HWC3(detected_map)
7878

79-
img = util.resize_image(input_image, image_resolution)
79+
img = resize_image(input_image, image_resolution)
8080
H, W, C = img.shape
8181

8282
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/dwpose/util.py

-31
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,6 @@
77
eps = 0.01
88

99

10-
def HWC3(x):
11-
assert x.dtype == np.uint8
12-
if x.ndim == 2:
13-
x = x[:, :, None]
14-
assert x.ndim == 3
15-
H, W, C = x.shape
16-
assert C == 1 or C == 3 or C == 4
17-
if C == 3:
18-
return x
19-
if C == 1:
20-
return np.concatenate([x, x, x], axis=2)
21-
if C == 4:
22-
color = x[:, :, 0:3].astype(np.float32)
23-
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
24-
y = color * alpha + 255.0 * (1.0 - alpha)
25-
y = y.clip(0, 255).astype(np.uint8)
26-
return y
27-
28-
def resize_image(input_image, resolution):
29-
H, W, C = input_image.shape
30-
H = float(H)
31-
W = float(W)
32-
k = float(resolution) / min(H, W)
33-
H *= k
34-
W *= k
35-
H = int(np.round(H / 64.0)) * 64
36-
W = int(np.round(W / 64.0)) * 64
37-
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
38-
return img
39-
40-
4110
def smart_resize(x, s):
4211
Ht, Wt = s
4312
if x.ndim == 2:

src/controlnet_aux/dwpose/wholebody.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
23
import numpy as np
34
from . import util
45
import cv2
@@ -19,7 +20,13 @@ def __init__(self,
1920
det_config=None, det_ckpt=None,
2021
pose_config=None, pose_ckpt=None,
2122
device="cpu"):
22-
23+
24+
if det_config is None:
25+
det_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py")
26+
27+
if pose_config is None:
28+
pose_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py")
29+
2330
if det_ckpt is None:
2431
det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
2532

tests/test_controlnet_aux.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,7 @@ def test_zoe(img):
120120
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
121121
common("zoe", zoe, img)
122122

123-
def test_dwpose(img, det_config, det_ckpt, pose_config, pose_ckpt):
124-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
125-
dwpose = DWposeDetector(det_config, det_ckpt, pose_config, pose_ckpt, device)
123+
def test_dwpose(img):
124+
dwpose = DWposeDetector()
126125
common("dwpose", dwpose, img)
127126
return_pil("dwpose", dwpose, img)

0 commit comments

Comments
 (0)