Skip to content

Commit

Permalink
feat(lite): add FaceParsingBiSeNet model ORT/MNN C++ (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
DefTruth committed Jul 2, 2022
1 parent a02d1d0 commit a6428c2
Show file tree
Hide file tree
Showing 12 changed files with 516 additions and 17 deletions.
40 changes: 40 additions & 0 deletions examples/lite/cv/test_lite_face_parsing_bisenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,52 @@ static void test_mnn()
static void test_ncnn()
{
#ifdef ENABLE_NCNN
std::string proto_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.param";
std::string bin_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.bin";
std::string test_img_path = "../../../examples/lite/resources/test_lite_face_parsing.png";
std::string save_img_path = "../../../logs/test_lite_face_parsing_bisenet_ncnn.jpg";

lite::ncnn::cv::segmentation::FaceParsingBiSeNet *face_parsing_bisenet =
new lite::ncnn::cv::segmentation::FaceParsingBiSeNet(
proto_path, bin_path, 4, 512, 512);

lite::types::FaceParsingContent content;
cv::Mat img_bgr = cv::imread(test_img_path);
face_parsing_bisenet->detect(img_bgr, content);

if (content.flag)
{
if (!content.merge.empty()) cv::imwrite(save_img_path, content.merge);
std::cout << "NCNN Version FaceParsingBiSeNet Done!" << std::endl;
}

delete face_parsing_bisenet;
#endif
}

static void test_tnn()
{
#ifdef ENABLE_TNN
std::string proto_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.tnnproto";
std::string model_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.tnnmodel";
std::string test_img_path = "../../../examples/lite/resources/test_lite_face_parsing.png";
std::string save_img_path = "../../../logs/test_lite_face_parsing_bisenet_tnn.jpg";

lite::tnn::cv::segmentation::FaceParsingBiSeNet *face_parsing_bisenet =
new lite::tnn::cv::segmentation::FaceParsingBiSeNet(
proto_path, model_path, 4);

lite::types::FaceParsingContent content;
cv::Mat img_bgr = cv::imread(test_img_path);
face_parsing_bisenet->detect(img_bgr, content);

if (content.flag)
{
if (!content.merge.empty()) cv::imwrite(save_img_path, content.merge);
std::cout << "TNN Version FaceParsingBiSeNet Done!" << std::endl;
}

delete face_parsing_bisenet;
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions examples/lite/cv/test_lite_modnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ static void test_ncnn()
{
#ifdef ENABLE_NCNN
std::string proto_path = "../../../hub/ncnn/cv/modnet_photographic_portrait_matting-512x512.opt.param";
std::string model_path = "../../../hub/ncnn/cv/modnet_photographic_portrait_matting-512x512.opt.bin";
std::string bin_path = "../../../hub/ncnn/cv/modnet_photographic_portrait_matting-512x512.opt.bin";
std::string test_img_path = "../../../examples/lite/resources/test_lite_matting_input.jpg";
std::string test_bgr_path = "../../../examples/lite/resources/test_lite_matting_bgr.jpg";
std::string save_fgr_path = "../../../logs/test_lite_modnet_fgr_ncnn.jpg";
Expand All @@ -152,7 +152,7 @@ static void test_ncnn()
std::string save_swap_path = "../../../logs/test_lite_modnet_swap_ncnn.jpg";

lite::ncnn::cv::matting::MODNet *modnet =
new lite::ncnn::cv::matting::MODNet(proto_path, model_path, 16, 512, 512); // 16 threads
new lite::ncnn::cv::matting::MODNet(proto_path, bin_path, 16, 512, 512); // 16 threads

lite::types::MattingContent content;
cv::Mat img_bgr = cv::imread(test_img_path);
Expand Down
1 change: 0 additions & 1 deletion lite/mnn/cv/mnn_face_parsing_bisenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
//

#include "mnn_face_parsing_bisenet.h"
#include "lite/utils.h"

using mnncv::MNNFaceParsingBiSeNet;

Expand Down
4 changes: 4 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@
#include "lite/ncnn/cv/ncnn_modnet.h"
#include "lite/ncnn/cv/ncnn_female_photo2cartoon.h"
#include "lite/ncnn/cv/ncnn_yolov6.h"
#include "lite/ncnn/cv/ncnn_face_parsing_bisenet.h"

#endif

Expand Down Expand Up @@ -358,6 +359,7 @@
#include "lite/tnn/cv/tnn_head_seg.h"
#include "lite/tnn/cv/tnn_female_photo2cartoon.h"
#include "lite/tnn/cv/tnn_yolov6.h"
#include "lite/tnn/cv/tnn_face_parsing_bisenet.h"

#endif

Expand Down Expand Up @@ -1324,6 +1326,7 @@ namespace lite
{
typedef ncnncv::NCNNDeepLabV3ResNet101 DeepLabV3ResNet101;
typedef ncnncv::NCNNFCNResNet101 FCNResNet101;
typedef ncnncv::NCNNFaceParsingBiSeNet FaceParsingBiSeNet;
}
// reid
namespace reid
Expand Down Expand Up @@ -1474,6 +1477,7 @@ namespace lite
typedef tnncv::TNNDeepLabV3ResNet101 DeepLabV3ResNet101;
typedef tnncv::TNNFCNResNet101 FCNResNet101;
typedef tnncv::TNNHeadSeg HeadSeg;
typedef tnncv::TNNFaceParsingBiSeNet FaceParsingBiSeNet;
}
// reid
namespace reid
Expand Down
12 changes: 6 additions & 6 deletions lite/ncnn/core/ncnn_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
namespace ncnncv
{
class LITE_EXPORTS NCNNNanoDet; // [0] * reference: https://github.com/RangiLyu/nanodet
class LITE_EXPORTS NCNNNanoDetEfficientNetLite; // [1] * reference: https://github.com/RangiLyu/nanodet
class LITE_EXPORTS NCNNNanoDetEfficientNetLite; // [1] * reference: https://github.com/RangiLyu/nanodet
class LITE_EXPORTS NCNNNanoDetDepreciated; // [2] * reference: https://github.com/RangiLyu/nanodet
class LITE_EXPORTS NCNNNanoDetEfficientNetLiteDepreciated; // [3] * reference: https://github.com/RangiLyu/nanodet
class LITE_EXPORTS NCNNNanoDetEfficientNetLiteDepreciated; // [3] * reference: https://github.com/RangiLyu/nanodet
class LITE_EXPORTS NCNNRobustVideoMatting; // [4] * reference: https://github.com/PeterL1n/RobustVideoMatting
class LITE_EXPORTS NCNNYoloX; // [5] * reference: https://github.com/Megvii-BaseDetection/YOLOX
class LITE_EXPORTS NCNNYOLOP; // [6] * reference: https://github.com/hustvl/YOLOP
Expand Down Expand Up @@ -51,11 +51,11 @@ namespace ncnncv
class LITE_EXPORTS NCNNAgeGoogleNet; // [36] * reference: https://github.com/onnx/models/tree/master/vision/body_analysis/age_gender
class LITE_EXPORTS NCNNGenderGoogleNet; // [37] * reference: https://github.com/onnx/models/tree/master/vision/body_analysis/age_gender
class LITE_EXPORTS NCNNEmotionFerPlus; // [38] * reference: https://github.com/onnx/models/blob/master/vision/body_analysis/emotion_ferplus
class LITE_EXPORTS NCNNEfficientEmotion7; // [39] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
class LITE_EXPORTS NCNNEfficientEmotion8; // [40] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
class LITE_EXPORTS NCNNEfficientEmotion7; // [39] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
class LITE_EXPORTS NCNNEfficientEmotion8; // [40] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
class LITE_EXPORTS NCNNMobileEmotion7; // [41] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
class LITE_EXPORTS NCNNEfficientNetLite4; // [42] * reference: https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4
class LITE_EXPORTS NCNNShuffleNetV2; // [43] * reference: https://github.com/onnx/models/blob/master/vision/classification/shufflenet
class LITE_EXPORTS NCNNEfficientNetLite4; // [42] * reference: https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4
class LITE_EXPORTS NCNNShuffleNetV2; // [43] * reference: https://github.com/onnx/models/blob/master/vision/classification/shufflenet
class LITE_EXPORTS NCNNDenseNet; // [44] * reference: https://pytorch.org/hub/pytorch_vision_densenet/
class LITE_EXPORTS NCNNGhostNet; // [45] * reference:https://pytorch.org/hub/pytorch_vision_ghostnet/
class LITE_EXPORTS NCNNHdrDNet; // [46] * reference: https://pytorch.org/hub/pytorch_vision_hardnet/
Expand Down
182 changes: 182 additions & 0 deletions lite/ncnn/cv/ncnn_face_parsing_bisenet.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
//
// Created by DefTruth on 2022/7/2.
//

#include "ncnn_face_parsing_bisenet.h"

using ncnncv::NCNNFaceParsingBiSeNet;

NCNNFaceParsingBiSeNet::NCNNFaceParsingBiSeNet(const std::string &_param_path,
const std::string &_bin_path,
unsigned int _num_threads,
unsigned int _input_height,
unsigned int _input_width) :
BasicNCNNHandler(_param_path, _bin_path, _num_threads),
input_height(_input_height), input_width(_input_width)
{
}

void NCNNFaceParsingBiSeNet::transform(const cv::Mat &mat, ncnn::Mat &in)
{
cv::Mat mat_rs;
cv::resize(mat, mat_rs, cv::Size(input_width, input_height));
// will do deepcopy inside ncnn
in = ncnn::Mat::from_pixels(mat_rs.data, ncnn::Mat::PIXEL_BGR2RGB, input_width, input_height);
in.substract_mean_normalize(mean_vals, norm_vals);
}

void NCNNFaceParsingBiSeNet::detect(const cv::Mat &mat, types::FaceParsingContent &content,
bool minimum_post_process)
{
if (mat.empty()) return;

// 1. make input tensor
ncnn::Mat input;
this->transform(mat, input);
// 2. inference & extract
auto extractor = net->create_extractor();
extractor.set_light_mode(false); // default
extractor.set_num_threads(num_threads);
extractor.input("input", input);
// 3. generate mask
this->generate_mask(extractor, mat, content, minimum_post_process);
}

static inline uchar __argmax_find(float *mutable_ptr, const unsigned int &step)
{
std::vector<float> logits(19, 0.f);
for (unsigned int i = 0; i < 19; ++i)
logits[i] = *(mutable_ptr + i * step);
uchar label = 0;
float max_logit = logits[0];
for (unsigned int i = 1; i < 19; ++i)
{
if (logits[i] > max_logit)
{
max_logit = logits[i];
label = (uchar) i;
}
}
return label;
}

static const uchar part_colors[20][3] = {
{255, 0, 0},
{255, 85, 0},
{255, 170, 0},
{255, 0, 85},
{255, 0, 170},
{0, 255, 0},
{85, 255, 0},
{170, 255, 0},
{0, 255, 85},
{0, 255, 170},
{0, 0, 255},
{85, 0, 255},
{170, 0, 255},
{0, 85, 255},
{0, 170, 255},
{255, 255, 0},
{255, 255, 85},
{255, 255, 170},
{255, 0, 255},
{255, 85, 255}
};

void NCNNFaceParsingBiSeNet::generate_mask(ncnn::Extractor &extractor, const cv::Mat &mat,
types::FaceParsingContent &content,
bool minimum_post_process)
{
ncnn::Mat output;
extractor.extract("out", output);
#ifdef LITENCNN_DEBUG
BasicNCNNHandler::print_shape(output, "out");
#endif
const unsigned int h = mat.rows;
const unsigned int w = mat.cols;

const unsigned int out_h = input_height;
const unsigned int out_w = input_width;
const unsigned int channel_step = out_h * out_w;

float *output_ptr = (float *) output.data;
std::vector<uchar> elements(channel_step, 0); // allocate
for (unsigned int i = 0; i < channel_step; ++i)
elements[i] = __argmax_find(output_ptr + i, channel_step);

cv::Mat label(out_h, out_w, CV_8UC1, elements.data());

if (!minimum_post_process)
{
const uchar *label_ptr = label.data;
cv::Mat color_mat(out_h, out_w, CV_8UC3, cv::Scalar(255, 255, 255));
for (unsigned int i = 0; i < color_mat.rows; ++i)
{
cv::Vec3b *p = color_mat.ptr<cv::Vec3b>(i);
for (unsigned int j = 0; j < color_mat.cols; ++j)
{
if (label_ptr[i * out_w + j] == 0) continue;
p[j][0] = part_colors[label_ptr[i * out_w + j]][0];
p[j][1] = part_colors[label_ptr[i * out_w + j]][1];
p[j][2] = part_colors[label_ptr[i * out_w + j]][2];
}
}
if (out_h != h || out_w != w)
cv::resize(color_mat, color_mat, cv::Size(w, h));
cv::addWeighted(mat, 0.4, color_mat, 0.6, 0., content.merge);
}
if (out_h != h || out_w != w) cv::resize(label, label, cv::Size(w, h));

content.label = label;
content.flag = true;
}


















































43 changes: 43 additions & 0 deletions lite/ncnn/cv/ncnn_face_parsing_bisenet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//
// Created by DefTruth on 2022/7/2.
//

#ifndef LITE_AI_TOOLKIT_NCNN_CV_NCNN_FACE_PARSING_BISENET_H
#define LITE_AI_TOOLKIT_NCNN_CV_NCNN_FACE_PARSING_BISENET_H

#include "lite/ncnn/core/ncnn_core.h"

namespace ncnncv
{
class LITE_EXPORTS NCNNFaceParsingBiSeNet : public BasicNCNNHandler
{
public:
explicit NCNNFaceParsingBiSeNet(const std::string &_param_path,
const std::string &_bin_path,
unsigned int _num_threads = 1,
unsigned int _input_height = 512,
unsigned int _input_width = 512);

~NCNNFaceParsingBiSeNet() override = default;

private:
const int input_height;
const int input_width;
const float mean_vals[3] = {0.485f * 255.f, 0.456f * 255.f, 0.406f * 255.f}; // RGB
const float norm_vals[3] = {1.f / (0.229f * 255.f), 1.f / (0.224f * 255.f), 1.f / (0.225f * 255.f)};

private:
void transform(const cv::Mat &mat, ncnn::Mat &in) override;

void generate_mask(ncnn::Extractor &extractor,
const cv::Mat &mat, types::FaceParsingContent &content,
bool minimum_post_process = false);

public:
void detect(const cv::Mat &mat, types::FaceParsingContent &content,
bool minimum_post_process = false);

};
}

#endif //LITE_AI_TOOLKIT_NCNN_CV_NCNN_FACE_PARSING_BISENET_H
3 changes: 1 addition & 2 deletions lite/ncnn/cv/ncnn_modnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ namespace ncnncv
const std::string &_bin_path,
unsigned int _num_threads = 1,
unsigned int _input_height = 512,
unsigned int _input_width = 512
);
unsigned int _input_width = 512);

~NCNNMODNet() override = default;

Expand Down
1 change: 0 additions & 1 deletion lite/ort/cv/face_parsing_bisenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "face_parsing_bisenet.h"
#include "lite/ort/core/ort_utils.h"
#include "lite/utils.h"

using ortcv::FaceParsingBiSeNet;

Expand Down
Loading

0 comments on commit a6428c2

Please sign in to comment.