Skip to content

Commit

Permalink
feat(BackgroundMattingV2): add BackgroundMattingV2 C++ (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
DefTruth committed Apr 10, 2022
1 parent 7035c96 commit 9d5684e
Show file tree
Hide file tree
Showing 12 changed files with 556 additions and 6 deletions.
1 change: 1 addition & 0 deletions examples/lite/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,5 @@ add_lite_executable(lite_insectid cv)
add_lite_executable(lite_plantid cv)
add_lite_executable(lite_modnet cv)
add_lite_executable(lite_modnet_dyn cv)
add_lite_executable(lite_backgroundmattingv2 cv)

153 changes: 153 additions & 0 deletions examples/lite/cv/test_lite_backgroundmattingv2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//
// Created by DefTruth on 2022/4/10.
//

#include "lite/lite.h"

static void test_default()
{
std::string onnx_path = "../../../hub/onnx/cv/BGMv2_mobilenetv2-512x512-full.onnx";
std::string test_src_path = "../../../examples/lite/resources/test_lite_bgmv2_src.png";
std::string test_bgr_path = "../../../examples/lite/resources/test_lite_bgmv2_bgr.png";
std::string save_fgr_path = "../../../logs/test_lite_bgmv2_fgr.jpg";
std::string save_pha_path = "../../../logs/test_lite_bgmv2_pha.jpg";
std::string save_merge_path = "../../../logs/test_lite_bgmv2_merge.jpg";

lite::cv::matting::BackgroundMattingV2 *bgmv2 =
new lite::cv::matting::BackgroundMattingV2(onnx_path, 16); // 16 threads

lite::types::MattingContent content;
cv::Mat src = cv::imread(test_src_path);
cv::Mat bgr = cv::imread(test_bgr_path);

// 1. image matting.
bgmv2->detect(src, bgr, content, true);

if (content.flag)
{
if (!content.fgr_mat.empty()) cv::imwrite(save_fgr_path, content.fgr_mat);
if (!content.pha_mat.empty()) cv::imwrite(save_pha_path, content.pha_mat * 255.);
if (!content.merge_mat.empty()) cv::imwrite(save_merge_path, content.merge_mat);
std::cout << "Default Version BackgroundMattingV2 Done!" << std::endl;
}

delete bgmv2;
}

static void test_onnxruntime()
{
#ifdef ENABLE_ONNXRUNTIME
std::string onnx_path = "../../../hub/onnx/cv/BGMv2_mobilenetv2-512x512-full.onnx";
std::string test_src_path = "../../../examples/lite/resources/test_lite_bgmv2_src.png";
std::string test_bgr_path = "../../../examples/lite/resources/test_lite_bgmv2_bgr.png";
std::string save_fgr_path = "../../../logs/test_lite_bgmv2_fgr_onnx.jpg";
std::string save_pha_path = "../../../logs/test_lite_bgmv2_pha_onnx.jpg";
std::string save_merge_path = "../../../logs/test_lite_bgmv2_merge_onnx.jpg";

lite::onnxruntime::cv::matting::BackgroundMattingV2 *bgmv2 =
new lite::onnxruntime::cv::matting::BackgroundMattingV2(onnx_path, 16); // 16 threads

lite::types::MattingContent content;
cv::Mat src = cv::imread(test_src_path);
cv::Mat bgr = cv::imread(test_bgr_path);

// 1. image matting.
bgmv2->detect(src, bgr, content, true);

if (content.flag)
{
if (!content.fgr_mat.empty()) cv::imwrite(save_fgr_path, content.fgr_mat);
if (!content.pha_mat.empty()) cv::imwrite(save_pha_path, content.pha_mat * 255.);
if (!content.merge_mat.empty()) cv::imwrite(save_merge_path, content.merge_mat);
std::cout << "ONNXRuntime Version BackgroundMattingV2 Done!" << std::endl;
}

delete bgmv2;
#endif
}

static void test_mnn()
{
#ifdef ENABLE_MNN
std::string mnn_path = "../../../hub/mnn/cv/BGMv2_mobilenetv2-512x512-full.mnn";
std::string test_src_path = "../../../examples/lite/resources/test_lite_bgmv2_src.png";
std::string test_bgr_path = "../../../examples/lite/resources/test_lite_bgmv2_bgr.png";
std::string save_fgr_path = "../../../logs/test_lite_bgmv2_fgr_mnn.jpg";
std::string save_pha_path = "../../../logs/test_lite_bgmv2_pha_mnn.jpg";
std::string save_merge_path = "../../../logs/test_lite_bgmv2_merge_mnn.jpg";

lite::mnn::cv::matting::BackgroundMattingV2 *bgmv2 =
new lite::mnn::cv::matting::BackgroundMattingV2(mnn_path, 16); // 16 threads

lite::types::MattingContent content;
cv::Mat src = cv::imread(test_src_path);
cv::Mat bgr = cv::imread(test_bgr_path);

// 1. image matting.
bgmv2->detect(src, bgr, content, true);

if (content.flag)
{
if (!content.fgr_mat.empty()) cv::imwrite(save_fgr_path, content.fgr_mat);
if (!content.pha_mat.empty()) cv::imwrite(save_pha_path, content.pha_mat * 255.);
if (!content.merge_mat.empty()) cv::imwrite(save_merge_path, content.merge_mat);
std::cout << "MNN Version MGMatting Done!" << std::endl;
}

delete bgmv2;
#endif
}

static void test_ncnn()
{
#ifdef ENABLE_NCNN
#endif
}

static void test_tnn()
{
#ifdef ENABLE_TNN
std::string proto_path = "../../../hub/tnn/cv/BGMv2_mobilenetv2-512x512-full.opt.tnnproto";
std::string model_path = "../../../hub/tnn/cv/BGMv2_mobilenetv2-512x512-full.opt.tnnmodel";
std::string test_src_path = "../../../examples/lite/resources/test_lite_bgmv2_src.png";
std::string test_bgr_path = "../../../examples/lite/resources/test_lite_bgmv2_bgr.png";
std::string save_fgr_path = "../../../logs/test_lite_bgmv2_fgr_tnn.jpg";
std::string save_pha_path = "../../../logs/test_lite_bgmv2_pha_tnn.jpg";
std::string save_merge_path = "../../../logs/test_lite_bgmv2_merge_tnn.jpg";

lite::tnn::cv::matting::BackgroundMattingV2 *bgmv2 =
new lite::tnn::cv::matting::BackgroundMattingV2(proto_path, model_path, 16); // 16 threads

lite::types::MattingContent content;
cv::Mat src = cv::imread(test_src_path);
cv::Mat bgr = cv::imread(test_bgr_path);

// 1. image matting.
bgmv2->detect(src, bgr, content, true);

if (content.flag)
{
if (!content.fgr_mat.empty()) cv::imwrite(save_fgr_path, content.fgr_mat);
if (!content.pha_mat.empty()) cv::imwrite(save_pha_path, content.pha_mat * 255.);
if (!content.merge_mat.empty()) cv::imwrite(save_merge_path, content.merge_mat);
std::cout << "TNN Version MGMatting Done!" << std::endl;
}

delete bgmv2;
#endif
}

static void test_lite()
{
test_default();
test_onnxruntime();
test_mnn();
test_ncnn();
test_tnn();
}

int main(__unused int argc, __unused char *argv[])
{
test_lite();
return 0;
}
1 change: 1 addition & 0 deletions examples/lite/resources/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test_lite_bgmv2*.png
3 changes: 2 additions & 1 deletion lite/mnn/cv/mnn_backgroundmattingv2.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace mnncv
// multi inputs.
MNN::Tensor *src_tensor = nullptr;
MNN::Tensor *bgr_tensor = nullptr;
// input size & variant_type, initialize at runtime.
// input size, initialize at runtime.
int input_height;
int input_width;
int dimension_type; // hint only
Expand All @@ -63,6 +63,7 @@ namespace mnncv

private:
void print_debug_string();

private:
void transform(const cv::Mat &mat, const cv::Mat &bgr);

Expand Down
1 change: 0 additions & 1 deletion lite/mnn/cv/mnn_modnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ void MNNMODNet::generate_matting(const std::map<std::string, MNN::Tensor *> &out
float *output_ptr = host_output_tensor.host<float>();

cv::Mat alpha_pred(out_h, out_w, CV_32FC1, output_ptr);
// post process
if (remove_noise) lite::utils::remove_small_connected_area(alpha_pred, 0.05f);
// resize alpha
if (out_h != h || out_w != w)
Expand Down
9 changes: 9 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
#include "lite/ort/cv/plantid.h"
#include "lite/ort/cv/modnet.h"
#include "lite/ort/cv/modnet_dyn.h"
#include "lite/ort/cv/backgroundmattingv2.h"

#endif

Expand Down Expand Up @@ -172,6 +173,7 @@
#include "lite/mnn/cv/mnn_insectid.h"
#include "lite/mnn/cv/mnn_plantid.h"
#include "lite/mnn/cv/mnn_modnet.h"
#include "lite/mnn/cv/mnn_backgroundmattingv2.h"

#endif

Expand Down Expand Up @@ -323,6 +325,7 @@
#include "lite/tnn/cv/tnn_insectid.h"
#include "lite/tnn/cv/tnn_plantid.h"
#include "lite/tnn/cv/tnn_modnet.h"
#include "lite/tnn/cv/tnn_backgroundmattingv2.h"

#endif

Expand Down Expand Up @@ -423,6 +426,7 @@ namespace lite
typedef ortcv::PlantID _PlantID;
typedef ortcv::MODNet _MODNet;
typedef ortcv::MODNetDyn _MODNetDyn;
typedef ortcv::BackgroundMattingV2 _BackgroundMattingV2;
#endif

// 1. classification
Expand Down Expand Up @@ -606,6 +610,7 @@ namespace lite
typedef _MGMatting MGMatting;
typedef _MODNet MODNet;
typedef _MODNetDyn MODNetDyn;
typedef _BackgroundMattingV2 BackgroundMattingV2;
#endif
}
}
Expand Down Expand Up @@ -773,6 +778,7 @@ namespace lite
typedef ortcv::PlantID _ONNXPlantID;
typedef ortcv::MODNet _ONNXMODNet;
typedef ortcv::MODNetDyn _ONNXMODNetDyn;
typedef ortcv::BackgroundMattingV2 _ONNXBackgroundMattingV2;

// 1. classification
namespace classification
Expand Down Expand Up @@ -928,6 +934,7 @@ namespace lite
typedef _ONNXMGMatting MGMatting;
typedef _ONNXMODNet MODNet;
typedef _ONNXMODNetDyn MODNetDyn;
typedef _ONNXBackgroundMattingV2 BackgroundMattingV2;
}
}

Expand Down Expand Up @@ -1056,6 +1063,7 @@ namespace lite
typedef mnncv::MNNRobustVideoMatting RobustVideoMatting;
typedef mnncv::MNNMGMatting MGMatting;
typedef mnncv::MNNMODNet MODNet;
typedef mnncv::MNNBackgroundMattingV2 BackgroundMattingV2;
}

// style transfer
Expand Down Expand Up @@ -1350,6 +1358,7 @@ namespace lite
typedef tnncv::TNNRobustVideoMatting RobustVideoMatting;
typedef tnncv::TNNMGMatting MGMatting;
typedef tnncv::TNNMODNet MODNet;
typedef tnncv::TNNBackgroundMattingV2 BackgroundMattingV2;
}
// style transfer
namespace style
Expand Down
1 change: 0 additions & 1 deletion lite/ncnn/core/ncnn_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ namespace ncnncv
class LITE_EXPORTS NCNNInsectID; // [64] * reference: https://github.com/quarrying/quarrying-insect-id
class LITE_EXPORTS NCNNPlantID; // [65] * reference: https://github.com/quarrying/quarrying-plant-id
class LITE_EXPORTS NCNNMODNet; // [66] * reference: https://github.com/ZHKKKe/MODNet
class LITE_EXPORTS NCNNBackgroundMattingV2; // [67] * reference: https://github.com/PeterL1n/BackgroundMattingV2
}

namespace ncnncv
Expand Down
7 changes: 6 additions & 1 deletion lite/ort/cv/backgroundmattingv2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ BackgroundMattingV2::BackgroundMattingV2(const std::string &_onnx_path, unsigned

// 3. type info.
num_inputs = ort_session->GetInputCount(); // 2
input_node_names.resize(num_inputs);
for (unsigned int i = 0; i < num_inputs; ++i)
input_node_names[i] = ort_session->GetInputName(i, allocator);

Ort::TypeInfo input_mat_type_info = ort_session->GetInputTypeInfo(0);
Ort::TypeInfo input_bgr_type_info = ort_session->GetInputTypeInfo(1);
auto input_mat_tensor_info = input_mat_type_info.GetTensorTypeAndShapeInfo();
Expand All @@ -55,7 +59,7 @@ BackgroundMattingV2::BackgroundMattingV2(const std::string &_onnx_path, unsigned
for (unsigned int i = 0; i < num_outputs; ++i)
{
output_node_names[i] = ort_session->GetOutputName(i, allocator);
Ort::TypeInfo type_info = ort_session->GetOutputTypeInfo(0);
Ort::TypeInfo type_info = ort_session->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
auto output_shape = tensor_info.GetShape();
output_node_dims.push_back(output_shape);
Expand Down Expand Up @@ -119,6 +123,7 @@ void BackgroundMattingV2::detect(const cv::Mat &mat, const cv::Mat &bgr,
types::MattingContent &content, bool remove_noise,
bool minimum_post_process)
{
if (mat.empty() || bgr.empty()) return;
// 1. make input tensor
auto input_tensors = this->transform(mat, bgr);
// 2. inference
Expand Down
Loading

0 comments on commit 9d5684e

Please sign in to comment.