Skip to content

Commit 98dc704

Browse files
Merge pull request huggingface#69 from haofanwang/master
Add DWposeDetector
2 parents 1e6bdc0 + 130da6f commit 98dc704

File tree

9 files changed

+1054
-5
lines changed

9 files changed

+1054
-5
lines changed

README.md

+20-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ All credit & copyright goes to https://github.com/lllyasviel .
1212
pip install controlnet-aux==0.0.6
1313
```
1414

15+
To support DWPose which is dependent on MMDetection, MMCV and MMPose
16+
```
17+
pip install -U openmim
18+
mim install mmengine
19+
mim install "mmcv>=2.0.1"
20+
mim install "mmdet>=3.1.0"
21+
mim install "mmpose>=1.1.0"
22+
```
1523
## Usage
1624

1725

@@ -35,7 +43,7 @@ img = Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512))
3543
# "lineart_coarse", "lineart_realistic", "mediapipe_face", "mlsd", "normal_bae", "normal_midas",
3644
# "openpose", "openpose_face", "openpose_faceonly", "openpose_full", "openpose_hand",
3745
# "scribble_hed, "scribble_pidinet", "shuffle", "softedge_hed", "softedge_hedsafe",
38-
# "softedge_pidinet", "softedge_pidsafe"]
46+
# "softedge_pidinet", "softedge_pidsafe", "dwpose"]
3947
processor_id = 'scribble_hed'
4048
processor = Processor(processor_id)
4149

@@ -47,7 +55,7 @@ Each model can be loaded individually by importing and instantiating them as fol
4755
from PIL import Image
4856
import requests
4957
from io import BytesIO
50-
from controlnet_aux import HEDdetector, MidasDetector, MLSDdetector, OpenposeDetector, PidiNetDetector, NormalBaeDetector, LineartDetector, LineartAnimeDetector, CannyDetector, ContentShuffleDetector, ZoeDetector, MediapipeFaceDetector, SamDetector, LeresDetector
58+
from controlnet_aux import HEDdetector, MidasDetector, MLSDdetector, OpenposeDetector, PidiNetDetector, NormalBaeDetector, LineartDetector, LineartAnimeDetector, CannyDetector, ContentShuffleDetector, ZoeDetector, MediapipeFaceDetector, SamDetector, LeresDetector, DWposeDetector
5159

5260
# load image
5361
url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png"
@@ -69,6 +77,15 @@ sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkp
6977
mobile_sam = SamDetector.from_pretrained("dhkim2810/MobileSAM", model_type="vit_t", filename="mobile_sam.pt")
7078
leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
7179

80+
# specify configs, ckpts and device, or it will be downloaded automatically and use cpu by default
81+
# det_config: ./src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py
82+
# det_ckpt: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth
83+
# pose_config: ./src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py
84+
# pose_ckpt: https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth
85+
import torch
86+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
87+
dwpose = DWposeDetector(det_config=det_config, det_ckpt=det_ckpt, pose_config=pose_config, pose_ckpt=pose_ckpt, device=device)
88+
7289
# instantiate
7390
canny = CannyDetector()
7491
content = ContentShuffleDetector()
@@ -91,4 +108,5 @@ processed_image_leres = leres(img)
91108
processed_image_canny = canny(img)
92109
processed_image_content = content(img)
93110
processed_image_mediapipe_face = face_detector(img)
111+
processed_image_dwpose = dwpose(img)
94112
```

src/controlnet_aux/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
from .canny import CannyDetector
1515
from .mediapipe_face import MediapipeFaceDetector
1616
from .segment_anything import SamDetector
17-
from .shuffle import ContentShuffleDetector
17+
from .shuffle import ContentShuffleDetector
18+
from .dwpose import DWposeDetector

src/controlnet_aux/dwpose/__init__.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Openpose
2+
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3+
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4+
# 3rd Edited by ControlNet
5+
# 4th Edited by ControlNet (added face and correct hands)
6+
7+
import os
8+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
9+
10+
import cv2
11+
import torch
12+
import numpy as np
13+
from PIL import Image
14+
15+
from ..util import HWC3, resize_image
16+
from . import util
17+
from .wholebody import Wholebody
18+
19+
20+
def draw_pose(pose, H, W):
21+
bodies = pose['bodies']
22+
faces = pose['faces']
23+
hands = pose['hands']
24+
candidate = bodies['candidate']
25+
subset = bodies['subset']
26+
27+
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
28+
canvas = util.draw_bodypose(canvas, candidate, subset)
29+
canvas = util.draw_handpose(canvas, hands)
30+
canvas = util.draw_facepose(canvas, faces)
31+
32+
return canvas
33+
34+
class DWposeDetector:
35+
def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"):
36+
37+
self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device)
38+
39+
def to(self, device):
40+
self.pose_estimation.to(device)
41+
return self
42+
43+
def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
44+
45+
input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
46+
47+
input_image = HWC3(input_image)
48+
input_image = resize_image(input_image, detect_resolution)
49+
H, W, C = input_image.shape
50+
51+
with torch.no_grad():
52+
candidate, subset = self.pose_estimation(input_image)
53+
nums, keys, locs = candidate.shape
54+
candidate[..., 0] /= float(W)
55+
candidate[..., 1] /= float(H)
56+
body = candidate[:,:18].copy()
57+
body = body.reshape(nums*18, locs)
58+
score = subset[:,:18]
59+
60+
for i in range(len(score)):
61+
for j in range(len(score[i])):
62+
if score[i][j] > 0.3:
63+
score[i][j] = int(18*i+j)
64+
else:
65+
score[i][j] = -1
66+
67+
un_visible = subset<0.3
68+
candidate[un_visible] = -1
69+
70+
foot = candidate[:,18:24]
71+
72+
faces = candidate[:,24:92]
73+
74+
hands = candidate[:,92:113]
75+
hands = np.vstack([hands, candidate[:,113:]])
76+
77+
bodies = dict(candidate=body, subset=score)
78+
pose = dict(bodies=bodies, hands=hands, faces=faces)
79+
80+
detected_map = draw_pose(pose, H, W)
81+
detected_map = HWC3(detected_map)
82+
83+
img = resize_image(input_image, image_resolution)
84+
H, W, C = img.shape
85+
86+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
87+
88+
if output_type == "pil":
89+
detected_map = Image.fromarray(detected_map)
90+
91+
return detected_map

0 commit comments

Comments
 (0)