diff --git a/README.md b/README.md
index af33b90..e41daab 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ img = Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512))
 # "lineart_coarse", "lineart_realistic", "mediapipe_face", "mlsd", "normal_bae", "normal_midas",
 # "openpose", "openpose_face", "openpose_faceonly", "openpose_full", "openpose_hand",
 # "scribble_hed, "scribble_pidinet", "shuffle", "softedge_hed", "softedge_hedsafe",
-# "softedge_pidinet", "softedge_pidsafe"]
+# "softedge_pidinet", "softedge_pidsafe", "dwpose"]
 processor_id = 'scribble_hed'
 processor = Processor(processor_id)
@@ -47,7 +47,7 @@ Each model can be loaded individually by importing and instantiating them as fol
 from PIL import Image
 import requests
 from io import BytesIO
-from controlnet_aux import HEDdetector, MidasDetector, MLSDdetector, OpenposeDetector, PidiNetDetector, NormalBaeDetector, LineartDetector, LineartAnimeDetector, CannyDetector, ContentShuffleDetector, ZoeDetector, MediapipeFaceDetector, SamDetector, LeresDetector
+from controlnet_aux import HEDdetector, MidasDetector, MLSDdetector, OpenposeDetector, PidiNetDetector, NormalBaeDetector, LineartDetector, LineartAnimeDetector, CannyDetector, ContentShuffleDetector, ZoeDetector, MediapipeFaceDetector, SamDetector, LeresDetector, DWposeDetector
 
 # load image
 url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png"
@@ -69,6 +69,15 @@ sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkp
 mobile_sam = SamDetector.from_pretrained("dhkim2810/MobileSAM", model_type="vit_t", filename="mobile_sam.pt")
 leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
 
+# specify the DWpose configs and checkpoints (the ckpt URLs are downloaded on first use)
+det_config = "./src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py"
+det_ckpt = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth"
+pose_config = "./src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py"
+pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"
+import torch
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+dwpose = DWposeDetector(det_config, det_ckpt, pose_config, pose_ckpt, device)
+
 # instantiate
 canny = CannyDetector()
 content = ContentShuffleDetector()
@@ -91,4 +100,5 @@ processed_image_leres = leres(img)
 processed_image_canny = canny(img)
 processed_image_content = content(img)
 processed_image_mediapipe_face = face_detector(img)
+processed_image_dwpose = dwpose(img)
 ```
diff --git a/src/controlnet_aux/dwpose/__init__.py b/src/controlnet_aux/dwpose/__init__.py
new file mode 100644
index 0000000..92f91e9
--- /dev/null
+++ b/src/controlnet_aux/dwpose/__init__.py
@@ -0,0 +1,87 @@
+# Openpose
+# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
+# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
+# 3rd Edited by ControlNet
+# 4th Edited by ControlNet (added face and correct hands)
+
+import os
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+
+from . import util
+from .wholebody import Wholebody
+
+
+def draw_pose(pose, H, W):
+    bodies = pose['bodies']
+    faces = pose['faces']
+    hands = pose['hands']
+    candidate = bodies['candidate']
+    subset = bodies['subset']
+
+    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
+    canvas = util.draw_bodypose(canvas, candidate, subset)
+    canvas = util.draw_handpose(canvas, hands)
+    canvas = util.draw_facepose(canvas, faces)
+
+    return canvas
+
+class DWposeDetector:
+    def __init__(self, det_config, det_ckpt, pose_config, pose_ckpt, device):
+
+        self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device)
+
+    def __call__(self, oriImg, output_type="pil", detect_resolution=512, image_resolution=512):
+
+        oriImg = oriImg.copy()
+        input_image = cv2.cvtColor(np.array(oriImg), cv2.COLOR_RGB2BGR)
+
+        input_image = util.HWC3(input_image)
+        input_image = util.resize_image(input_image, detect_resolution)
+        H, W, C = input_image.shape
+
+        with torch.no_grad():
+            candidate, subset = self.pose_estimation(input_image)
+            nums, keys, locs = candidate.shape
+            candidate[..., 0] /= float(W)
+            candidate[..., 1] /= float(H)
+            body = candidate[:,:18].copy()
+            body = body.reshape(nums*18, locs)
+            score = subset[:,:18]
+
+            for i in range(len(score)):
+                for j in range(len(score[i])):
+                    if score[i][j] > 0.3:
+                        score[i][j] = int(18*i+j)
+                    else:
+                        score[i][j] = -1
+
+            un_visible = subset<0.3
+            candidate[un_visible] = -1
+
+            foot = candidate[:,18:24]
+
+            faces = candidate[:,24:92]
+
+            hands = candidate[:,92:113]
+            hands = np.vstack([hands, candidate[:,113:]])
+
+            bodies = dict(candidate=body, subset=score)
+            pose = dict(bodies=bodies, hands=hands, faces=faces)
+
+            detected_map = draw_pose(pose, H, W)
+            detected_map = util.HWC3(detected_map)
+
+            img = util.resize_image(input_image, image_resolution)
+            H, W, C = img.shape
+
+            detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
+
+            if output_type == "pil":
+                detected_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
+
+            return detected_map
\ No newline at end of file
diff --git a/src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py b/src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py
new file mode 100644
index 0000000..d45abe6
--- /dev/null
+++ b/src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py
@@ -0,0 +1,257 @@
+# runtime
+max_epochs = 270
+stage2_num_epochs = 30
+base_lr = 4e-3
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=21)
+
+# optimizer
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
+    paramwise_cfg=dict(
+        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+# learning rate
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1.0e-5,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        # use cosine lr from epoch 135 to 270 (max_epochs // 2 to max_epochs)
+        type='CosineAnnealingLR',
+        eta_min=base_lr * 0.05,
+        begin=max_epochs // 2,
+        end=max_epochs,
+        T_max=max_epochs // 2,
+        by_epoch=True,
+        convert_to_iter_based=True),
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=512)
+
+# codec settings
+codec = dict(
+    type='SimCCLabel',
+    input_size=(288, 384),
+    sigma=(6., 6.93),
+    simcc_split_ratio=2.0,
+    normalize=False,
+    use_dark=False)
+
+# model settings
+model = dict(
+    type='TopdownPoseEstimator',
+    data_preprocessor=dict(
+        type='PoseDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=1., + widen_factor=1., + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' + 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa + )), + head=dict( + type='RTMCCHead', + in_channels=1024, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=(9, 12), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10., + label_softmax=True), + decoder=codec), + test_cfg=dict(flip_test=True, )) + +# base dataset settings +dataset_type = 'CocoWholeBodyDataset' +data_mode = 'topdown' +data_root = '/data/' + +backend_args = dict(backend='local') +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/', +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/' +# })) + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +datasets = [] +dataset_coco=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='coco/train2017/'), + pipeline=[], +) +datasets.append(dataset_coco) + +scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', + 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', + 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] + +for i in 
range(len(scene)): + datasets.append( + dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', + data_prefix=dict(img='UBody/images/'+scene[i]+'/'), + pipeline=[], + ) + ) + +# data loaders +train_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=datasets, + pipeline=train_pipeline, + test_mode=False, + )) +val_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_val_v1.0.json', + bbox_file=f'{data_root}coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='coco/val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator diff --git a/src/controlnet_aux/dwpose/util.py b/src/controlnet_aux/dwpose/util.py new file mode 100644 index 0000000..0b290b5 --- /dev/null +++ b/src/controlnet_aux/dwpose/util.py @@ -0,0 +1,334 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + 
+def padRightDownCorner(img, stride, padValue):
+    h = img.shape[0]
+    w = img.shape[1]
+
+    pad = 4 * [None]
+    pad[0] = 0 # up
+    pad[1] = 0 # left
+    pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+    pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+    img_padded = img
+    pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+    img_padded = np.concatenate((pad_up, img_padded), axis=0)
+    pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+    img_padded = np.concatenate((pad_left, img_padded), axis=1)
+    pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+    img_padded = np.concatenate((img_padded, pad_down), axis=0)
+    pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+    img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+    return img_padded, pad
+
+
+def transfer(model, model_weights):
+    transferred_model_weights = {}
+    for weights_name in model.state_dict().keys():
+        transferred_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+    return transferred_model_weights
+
+
+def draw_bodypose(canvas, candidate, subset):
+    H, W, C = canvas.shape
+    candidate = np.array(candidate)
+    subset = np.array(subset)
+
+    stickwidth = 4
+
+    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+               [1, 16], [16, 18], [3, 17], [6, 18]]
+
+    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+    for i in range(17):
+        for n in range(len(subset)):
+            index = subset[n][np.array(limbSeq[i]) - 1]
+            if -1 in index:
+                continue
+            Y = candidate[index.astype(int), 0] * float(W)
+            X = candidate[index.astype(int), 1] * float(H)
+            mX = np.mean(X)
+            mY = np.mean(Y)
+            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            cv2.fillConvexPoly(canvas, polygon, colors[i])
+
+    canvas = (canvas * 0.6).astype(np.uint8)
+
+    for i in range(18):
+        for n in range(len(subset)):
+            index = int(subset[n][i])
+            if index == -1:
+                continue
+            x, y = candidate[index][0:2]
+            x = int(x * W)
+            y = int(y * H)
+            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+
+    return canvas
+
+
+def draw_handpose(canvas, all_hand_peaks):
+
+    H, W, C = canvas.shape
+
+    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+    # (person_number*2, 21, 2)
+    for i in range(len(all_hand_peaks)):
+        peaks = all_hand_peaks[i]
+        peaks = np.array(peaks)
+
+        for ie, e in enumerate(edges):
+
+            x1, y1 = peaks[e[0]]
+            x2, y2 = peaks[e[1]]
+
+            x1 = int(x1 * W)
+            y1 = int(y1 * H)
+            x2 = int(x2 * W)
+            y2 = int(y2 * H)
+            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
+                cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
+
+        for _, keypoint in enumerate(peaks):
+            x, y = keypoint
+
+            x = int(x * W)
+            y = int(y * H)
+            if x > eps and y > eps:
+                cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+    return canvas
+
+
+def draw_facepose(canvas, all_lmks):
+    H, W, C = canvas.shape
+    for lmks in all_lmks:
+        lmks = np.array(lmks)
+        for lmk in lmks:
+            x, y = lmk
+            x = int(x * W)
+            y = int(y * H)
+            if x > eps and y > eps:
+                cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
+    return canvas
+
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+    # right hand: wrist 4, elbow 3, shoulder 2
+    # left hand: wrist 7, elbow 6, shoulder 5
+    ratioWristElbow = 0.33
+    detect_result = []
+    image_height, image_width = oriImg.shape[0:2]
+    for person in subset.astype(int):
+        # if any of three not detected
+        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
+        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
+        if not (has_left or has_right):
+            continue
+        hands = []
+        # left hand
+        if has_left:
+            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
+            x1, y1 = candidate[left_shoulder_index][:2]
+            x2, y2 = candidate[left_elbow_index][:2]
+            x3, y3 = candidate[left_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, True])
+        # right hand
+        if has_right:
+            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
+            x1, y1 = candidate[right_shoulder_index][:2]
+            x2, y2 = candidate[right_elbow_index][:2]
+            x3, y3 = candidate[right_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, False])
+
+        for x1, y1, x2, y2, x3, y3, is_left in hands:
+            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
+            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+            x = x3 + ratioWristElbow * (x3 - x2)
+            y = y3 + ratioWristElbow * (y3 - y2)
+            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+            # x-y refers to the center --> offset to topLeft point
+            # handRectangle.x -= handRectangle.width / 2.f;
+            # handRectangle.y -= handRectangle.height / 2.f;
+            x -= width / 2
+            y -= width / 2  # width = height
+            # clip the box against the image border
+            if x < 0: x = 0
+            if y < 0: y = 0
+            width1 = width
+            width2 = width
+            if x + width > image_width: width1 = image_width - x
+            if y + width > image_height: width2 = image_height - y
+            width = min(width1, width2)
+            # discard hand boxes smaller than 20 pixels
+            if width >= 20:
+                detect_result.append([int(x), int(y), int(width), is_left])
+
+    '''
+    return value: [[x, y, w, True if left hand else False]].
+    width=height since the network requires square input.
+ x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/src/controlnet_aux/dwpose/wholebody.py b/src/controlnet_aux/dwpose/wholebody.py new file mode 100644 index 0000000..ba665b2 --- /dev/null +++ b/src/controlnet_aux/dwpose/wholebody.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from . 
import util +import cv2 +import mmcv +import torch +import matplotlib.pyplot as plt +from mmpose.apis import inference_topdown +from mmpose.apis import init_model as init_pose_estimator +from mmpose.evaluation.functional import nms +from mmpose.utils import adapt_mmdet_pipeline +from mmpose.structures import merge_data_samples + +from mmdet.apis import inference_detector, init_detector + + +class Wholebody: + def __init__(self, + det_config=None, det_ckpt=None, + pose_config=None, pose_ckpt=None, + device="cpu"): + + if det_ckpt is None: + det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' + + if pose_ckpt is None: + pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth" + + # build detector + self.detector = init_detector(det_config, det_ckpt, device=device) + self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg) + + # build pose estimator + self.pose_estimator = init_pose_estimator( + pose_config, + pose_ckpt, + device=device) + + def __call__(self, oriImg): + # predict bbox + det_result = inference_detector(self.detector, oriImg) + pred_instance = det_result.pred_instances.cpu().numpy() + bboxes = np.concatenate( + (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) + bboxes = bboxes[np.logical_and(pred_instance.labels == 0, + pred_instance.scores > 0.5)] + + # set NMS threshold + bboxes = bboxes[nms(bboxes, 0.7), :4] + + # predict keypoints + if len(bboxes) == 0: + pose_results = inference_topdown(self.pose_estimator, oriImg) + else: + pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes) + preds = merge_data_samples(pose_results) + preds = preds.pred_instances + + # preds = pose_results[0].pred_instances + keypoints = preds.get('transformed_keypoints', + preds.keypoints) + if 'keypoint_scores' in preds: + scores = preds.keypoint_scores + else: + scores = np.ones(keypoints.shape[:-1]) + + if 'keypoints_visible' in preds: + visible = preds.keypoints_visible + else: + visible = np.ones(keypoints.shape[:-1]) + keypoints_info = np.concatenate( + (keypoints, scores[..., None], visible[..., None]), + axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores, visible = keypoints_info[ + ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] + + return keypoints, scores diff --git a/src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py b/src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py new file mode 100644 index 0000000..7b4cb5a --- /dev/null +++ b/src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py @@ -0,0 +1,245 @@ +img_scale = (640, 640) # width, height + +# model settings +model = dict( + type='YOLOX', + data_preprocessor=dict( + type='DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='BatchSyncRandomResize', + 
random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + backbone=dict( + type='CSPDarknet', + deepen_factor=1.0, + widen_factor=1.0, + out_indices=(2, 3, 4), + use_depthwise=False, + spp_kernal_sizes=(5, 9, 13), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + ), + neck=dict( + type='YOLOXPAFPN', + in_channels=[256, 512, 1024], + out_channels=256, + num_csp_blocks=3, + use_depthwise=False, + upsample_cfg=dict(scale_factor=2, mode='nearest'), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish')), + bbox_head=dict( + type='YOLOXHead', + num_classes=80, + in_channels=256, + feat_channels=256, + stacked_convs=2, + strides=(8, 16, 32), + use_depthwise=False, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_bbox=dict( + type='IoULoss', + mode='square', + eps=1e-16, + reduction='sum', + loss_weight=5.0), + loss_obj=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)), + train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), + # In order to align the source code, the threshold of the val phase is + # 0.01, and the threshold of the test phase is 0.001. + test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) + +# dataset settings +data_root = 'data/coco/' +dataset_type = 'CocoDataset' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), + dict( + type='RandomAffine', + scaling_ratio_range=(0.1, 2), + # img_scale is (width, height) + border=(-img_scale[0] // 2, -img_scale[1] // 2)), + dict( + type='MixUp', + img_scale=img_scale, + ratio_range=(0.8, 1.6), + pad_val=114.0), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + # According to the official implementation, multi-scale + # training is not considered here but in the + # 'mmdet/models/detectors/yolox.py'. + # Resize and Pad are for the last 15 epochs when Mosaic, + # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook. + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + # If the image is three-channel, the pad value needs + # to be set separately for each channel. 
+ pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type='PackDetInputs') +] + +train_dataset = dict( + # use MultiImageMixDataset wrapper to support mosaic and mixup + type='MultiImageMixDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=[ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True) + ], + filter_cfg=dict(filter_empty_gt=False, min_size=32), + backend_args=backend_args), + pipeline=train_pipeline) + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=train_dataset) +val_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + backend_args=backend_args) +test_evaluator = val_evaluator + +# training settings +max_epochs = 300 +num_last_epochs = 15 +interval = 10 + +train_cfg = dict(max_epochs=max_epochs, val_interval=interval) + +# optimizer +# default 8 gpu +base_lr = 0.01 +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4, + nesterov=True), + paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) + +# learning rate +param_scheduler = [ + dict( + # use quadratic formula to warm up 5 epochs + # and lr is updated by iteration + # TODO: fix default scope in get function + type='mmdet.QuadraticWarmupLR', + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True), + dict( + # use cosine lr from 5 to 285 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=5, + T_max=max_epochs - num_last_epochs, + end=max_epochs - num_last_epochs, + by_epoch=True, + convert_to_iter_based=True), + dict( + # use fixed lr during last 15 epochs + type='ConstantLR', + by_epoch=True, + factor=1, + begin=max_epochs - num_last_epochs, + end=max_epochs, + ) +] + +default_hooks = dict( + checkpoint=dict( + interval=interval, + max_keep_ckpts=3 # only keep latest 3 checkpoints + )) + +custom_hooks = [ + dict( + type='YOLOXModeSwitchHook', + num_last_epochs=num_last_epochs, + priority=48), + dict(type='SyncNormHook', priority=48), + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0001, + update_buffers=True, + priority=49) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. 
+# base_batch_size = (8 GPUs) x (8 samples per GPU)
+auto_scale_lr = dict(base_batch_size=64)
diff --git a/tests/test_controlnet_aux.py b/tests/test_controlnet_aux.py
index 36e7239..28a389c 100644
--- a/tests/test_controlnet_aux.py
+++ b/tests/test_controlnet_aux.py
@@ -12,7 +12,7 @@
                             LineartDetector, MediapipeFaceDetector,
                             MidasDetector, MLSDdetector, NormalBaeDetector,
                             OpenposeDetector, PidiNetDetector, SamDetector,
-                            ZoeDetector)
+                            ZoeDetector, DWposeDetector)
 
 OUTPUT_DIR = "tests/outputs"
 
@@ -119,3 +119,13 @@ def test_shuffle(img):
 def test_zoe(img):
     zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
     common("zoe", zoe, img)
+
+def test_dwpose(img):
+    import torch
+    # the configs ship with this PR; the ckpts default to the released weights when None
+    det_config = "src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py"
+    pose_config = "src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py"
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dwpose = DWposeDetector(det_config, None, pose_config, None, device)
+    common("dwpose", dwpose, img)
+    return_pil("dwpose", dwpose, img)
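
Reviewer note: the keypoint bookkeeping in `wholebody.py` and `__init__.py` is the subtle part of this PR, so here is a minimal standalone sketch (toy data, not part of the diff) of what those slices do. MMPose returns 133 COCO-WholeBody keypoints per person; `Wholebody.__call__` synthesizes a neck joint as the midpoint of the two shoulders (indices 5 and 6) and inserts it at index 17, and `DWposeDetector.__call__` then slices the resulting 134 points into body, foot, face, and hand groups:

```python
import numpy as np

num_person = 2
kpts = np.random.rand(num_person, 133, 2)  # normalized (x, y) per COCO-WholeBody keypoint

# synthesize the neck as the shoulder midpoint and insert it at index 17,
# mirroring np.insert(keypoints_info, 17, neck, axis=1) in wholebody.py
neck = kpts[:, [5, 6]].mean(axis=1)
kpts = np.insert(kpts, 17, neck, axis=1)   # now (num_person, 134, 2)

# the slices used by DWposeDetector.__call__ in __init__.py
body = kpts[:, :18]      # 18 OpenPose-style body joints
feet = kpts[:, 18:24]    # 6 foot keypoints (sliced but not drawn)
face = kpts[:, 24:92]    # 68 face landmarks
hands = kpts[:, 92:134]  # 21 left-hand + 21 right-hand keypoints

assert body.shape[1] + feet.shape[1] + face.shape[1] + hands.shape[1] == 134
```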