Skip to content

Commit

Permalink
feat(MODNet): add MODNet for trimap-free human matting (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
DefTruth committed Mar 30, 2022
1 parent c646304 commit 6be4f5e
Show file tree
Hide file tree
Showing 7 changed files with 642 additions and 1 deletion.
177 changes: 177 additions & 0 deletions lite/mnn/cv/mnn_modnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,180 @@
// Created by DefTruth on 2022/3/27.
//

#include "mnn_modnet.h"

using mnncv::MNNMODNet;

MNNMODNet::MNNMODNet(const std::string &_mnn_path, unsigned int _num_threads)
: BasicMNNHandler(_mnn_path, _num_threads)
{
initialize_pretreat();
}

inline void MNNMODNet::initialize_pretreat()
{
pretreat = std::shared_ptr<MNN::CV::ImageProcess>(
MNN::CV::ImageProcess::create(
MNN::CV::BGR,
MNN::CV::RGB,
mean_vals, 3,
norm_vals, 3
)
);
}

void MNNMODNet::transform(const cv::Mat &mat)
{
cv::Mat canvas;
cv::resize(mat, canvas, cv::Size(input_width, input_height));
// (1,3,512,512) deepcopy inside
pretreat->convert(canvas.data, input_width, input_height, canvas.step[0], input_tensor);
}

void MNNMODNet::detect(const cv::Mat &mat, types::MattingContent &content, bool remove_noise)
{
if (mat.empty()) return;
// 1. make input tensor
this->transform(mat);
// 2. inference
mnn_interpreter->runSession(mnn_session);
auto output_tensors = mnn_interpreter->getSessionOutputAll(mnn_session);
// 3. generate matting
this->generate_matting(output_tensors, mat, content, remove_noise);
}

// user-friendly method for background swap.
void MNNMODNet::swap_background(const cv::Mat &fg_mat,
const cv::Mat &pha_mat,
const cv::Mat &bg_mat,
cv::Mat &out_mat)
{
if (fg_mat.empty() || pha_mat.empty() || bg_mat.empty()) return;
const unsigned int fg_h = fg_mat.rows;
const unsigned int fg_w = fg_mat.cols;
const unsigned int bg_h = bg_mat.rows;
const unsigned int bg_w = bg_mat.cols;
const unsigned int ph_h = pha_mat.rows;
const unsigned int ph_w = pha_mat.cols;
cv::Mat bg_mat_copy, pha_mat_copy, fg_mat_copy;
if (bg_h != fg_h || bg_w != fg_w)
cv::resize(bg_mat, bg_mat_copy, cv::Size(fg_w, fg_h));
else bg_mat_copy = bg_mat; // ref only.
if (ph_h != fg_h || ph_w != fg_w)
cv::resize(pha_mat, pha_mat_copy, cv::Size(fg_w, fg_h));
else pha_mat_copy = pha_mat; // ref only.
// convert pha_mat_copy to 3 channels.
if (pha_mat_copy.channels() == 1)
cv::cvtColor(pha_mat_copy, pha_mat_copy, cv::COLOR_GRAY2BGR); // 0.~1.
// convert mats to float32 points.
fg_mat.convertTo(fg_mat_copy, CV_32FC3); // 0.~255.
bg_mat_copy.convertTo(bg_mat_copy, CV_32FC3); // 0.~255.
pha_mat_copy.convertTo(pha_mat_copy, CV_32FC3); // 0.~1. assert pha_mat_copy is float32.
// element wise operations.
cv::Mat rest = 1.f - pha_mat_copy;
// out_mat = fg_mat_copy.mul(pha_mat_copy) + bg_mat_copy.mul(rest);
cv::add(fg_mat_copy.mul(pha_mat_copy), bg_mat_copy.mul(rest), out_mat);
// check with 'out_mat.empty()' and CV_8UC3.
if (!out_mat.empty()) out_mat.convertTo(out_mat, CV_8UC3);
}

void MNNMODNet::generate_matting(const std::map<std::string, MNN::Tensor *> &output_tensors,
const cv::Mat &mat, types::MattingContent &content,
bool remove_noise)
{
auto device_output_ptr = output_tensors.at("output"); // e.g (1,1,256,256)
MNN::Tensor host_output_tensor(device_output_ptr, device_output_ptr->getDimensionType());
device_output_ptr->copyToHostTensor(&host_output_tensor);
const unsigned int h = mat.rows;
const unsigned int w = mat.cols;

auto output_dims = host_output_tensor.shape();
const unsigned int out_h = output_dims.at(2);
const unsigned int out_w = output_dims.at(3);

float *output_ptr = host_output_tensor.host<float>();

cv::Mat alpha_pred(out_h, out_w, CV_32FC1, output_ptr);
// post process
if (remove_noise) this->remove_small_connected_area(alpha_pred);
// resize alpha
if (out_h != h || out_w != w)
{
cv::resize(alpha_pred, alpha_pred, cv::Size(w, h));
}
cv::Mat mat_copy;
mat.convertTo(mat_copy, CV_32FC3);
cv::Mat pmat = alpha_pred; // ref

std::vector<cv::Mat> mat_channels;
cv::split(mat_copy, mat_channels);
cv::Mat bmat = mat_channels.at(0);
cv::Mat gmat = mat_channels.at(1);
cv::Mat rmat = mat_channels.at(2); // ref only, zero-copy.
bmat = bmat.mul(pmat);
gmat = gmat.mul(pmat);
rmat = rmat.mul(pmat);
cv::Mat rest = 1.f - pmat;
cv::Mat mbmat = bmat.mul(pmat) + rest * 153.f;
cv::Mat mgmat = gmat.mul(pmat) + rest * 255.f;
cv::Mat mrmat = rmat.mul(pmat) + rest * 120.f;
std::vector<cv::Mat> fgr_channel_mats, merge_channel_mats;
fgr_channel_mats.push_back(bmat);
fgr_channel_mats.push_back(gmat);
fgr_channel_mats.push_back(rmat);
merge_channel_mats.push_back(mbmat);
merge_channel_mats.push_back(mgmat);
merge_channel_mats.push_back(mrmat);

content.pha_mat = pmat;
cv::merge(fgr_channel_mats, content.fgr_mat);
cv::merge(merge_channel_mats, content.merge_mat);

content.fgr_mat.convertTo(content.fgr_mat, CV_8UC3);
content.merge_mat.convertTo(content.merge_mat, CV_8UC3);

content.flag = true;
}

// https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L208
void MNNMODNet::remove_small_connected_area(cv::Mat &alpha_pred)
{
cv::Mat gray, binary;
alpha_pred.convertTo(gray, CV_8UC1, 255.f);
// 255 * 0.05 ~ 13
// https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L209
cv::threshold(gray, binary, 13, 255, cv::THRESH_BINARY);
// morphologyEx with OPEN operation to remove noise first.
auto kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(3, 3), cv::Point(-1, -1));
cv::morphologyEx(binary, binary, cv::MORPH_OPEN, kernel);
// Computationally connected domain
cv::Mat labels = cv::Mat::zeros(alpha_pred.size(), CV_32S);
cv::Mat stats, centroids;
int num_labels = cv::connectedComponentsWithStats(binary, labels, stats, centroids, 8, 4);
if (num_labels <= 1) return; // no noise, skip.
// find max connected area, 0 is background
int max_connected_id = 1; // 1,2,...
int max_connected_area = stats.at<int>(max_connected_id, cv::CC_STAT_AREA);
for (int i = 1; i < num_labels; ++i)
{
int tmp_connected_area = stats.at<int>(i, cv::CC_STAT_AREA);
if (tmp_connected_area > max_connected_area)
{
max_connected_area = tmp_connected_area;
max_connected_id = i;
}
}
const int h = alpha_pred.rows;
const int w = alpha_pred.cols;
// remove small connected area.
for (int i = 0; i < h; ++i)
{
int *label_row_ptr = labels.ptr<int>(i);
float *alpha_row_ptr = alpha_pred.ptr<float>(i);
for (int j = 0; j < w; ++j)
{
if (label_row_ptr[j] != max_connected_id)
alpha_row_ptr[j] = 0.f;
}
}
}
11 changes: 11 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
#include "lite/ort/cv/insectdet.h"
#include "lite/ort/cv/insectid.h"
#include "lite/ort/cv/plantid.h"
#include "lite/ort/cv/modnet.h"

#endif

Expand Down Expand Up @@ -169,6 +170,7 @@
#include "lite/mnn/cv/mnn_insectdet.h"
#include "lite/mnn/cv/mnn_insectid.h"
#include "lite/mnn/cv/mnn_plantid.h"
#include "lite/mnn/cv/mnn_modnet.h"

#endif

Expand Down Expand Up @@ -243,6 +245,7 @@
#include "lite/ncnn/cv/ncnn_pipnet19.h"
#include "lite/ncnn/cv/ncnn_insectid.h"
#include "lite/ncnn/cv/ncnn_plantid.h"
#include "lite/ncnn/cv/ncnn_modnet.h"

#endif

Expand Down Expand Up @@ -318,6 +321,7 @@
#include "lite/tnn/cv/tnn_insectdet.h"
#include "lite/tnn/cv/tnn_insectid.h"
#include "lite/tnn/cv/tnn_plantid.h"
#include "lite/tnn/cv/tnn_modnet.h"

#endif

Expand Down Expand Up @@ -416,6 +420,7 @@ namespace lite
typedef ortcv::InsectDet _InsectDet;
typedef ortcv::InsectID _InsectID;
typedef ortcv::PlantID _PlantID;
typedef ortcv::MODNet _MODNet;
#endif

// 1. classification
Expand Down Expand Up @@ -597,6 +602,7 @@ namespace lite
#ifdef BACKEND_ONNXRUNTIME
typedef _RobustVideoMatting RobustVideoMatting;
typedef _MGMatting MGMatting;
typedef _MODNet MODNet;
#endif
}
}
Expand Down Expand Up @@ -762,6 +768,7 @@ namespace lite
typedef ortcv::InsectDet _ONNXInsectDet;
typedef ortcv::InsectID _ONNXInsectID;
typedef ortcv::PlantID _ONNXPlantID;
typedef ortcv::MODNet _ONNXMODNet;

// 1. classification
namespace classification
Expand Down Expand Up @@ -915,6 +922,7 @@ namespace lite
{
typedef _ONNXRobustVideoMatting RobustVideoMatting;
typedef _ONNXMGMatting MGMatting;
typedef _ONNXMODNet MODNet;
}
}

Expand Down Expand Up @@ -1042,6 +1050,7 @@ namespace lite
{
typedef mnncv::MNNRobustVideoMatting RobustVideoMatting;
typedef mnncv::MNNMGMatting MGMatting;
typedef mnncv::MNNMODNet MODNet;
}

// style transfer
Expand Down Expand Up @@ -1190,6 +1199,7 @@ namespace lite
namespace matting
{
typedef ncnncv::NCNNRobustVideoMatting RobustVideoMatting;
typedef ncnncv::NCNNMODNet MODNet;
}
// style transfer
namespace style
Expand Down Expand Up @@ -1334,6 +1344,7 @@ namespace lite
{
typedef tnncv::TNNRobustVideoMatting RobustVideoMatting;
typedef tnncv::TNNMGMatting MGMatting;
typedef tnncv::TNNMODNet MODNet;
}
// style transfer
namespace style
Expand Down
Loading

0 comments on commit 6be4f5e

Please sign in to comment.