Skip to content

Commit

Permalink
Added mask visualization part to inference part and add out_file inte…
Browse files Browse the repository at this point in the history
…rface. (#403)

* Update README.md

* Update inference.py

* Update README.md

* Update inference.py

Added mask visualization part for inferring.

* Update README.md

* Update inference.py

* Update inference.py

convert all tabs to spaces

* Update inference.py
  • Loading branch information
Luodian authored and hellock committed Mar 24, 2019
1 parent a3c8ddf commit edb0393
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions mmdet/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch

from mmdet.core import get_classes
from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform
from mmdet.core import get_classes


def _prepare_data(img, img_transform, cfg, device):
Expand Down Expand Up @@ -50,18 +51,33 @@ def inference_detector(model, imgs, cfg, device='cuda:0'):
return _inference_generator(model, imgs, img_transform, cfg, device)


def show_result(img, result, dataset='coco', score_thr=0.3):
def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
img = mmcv.imread(img)
class_names = get_classes(dataset)
if isinstance(result, tuple):
bbox_result, segm_result = result
else:
bbox_result, segm_result = result, None
bboxes = np.vstack(bbox_result)
# draw segmentation masks
if segm_result is not None:
segms = mmcv.concat_list(segm_result)
inds = np.where(bboxes[:, -1] > score_thr)[0]
for i in inds:
color_mask = np.random.randint(
0, 256, (1, 3), dtype=np.uint8)
mask = maskUtils.decode(segms[i]).astype(np.bool)
img[mask] = img[mask] * 0.5 + color_mask * 0.5
# draw bounding boxes
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(result)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)
bboxes = np.vstack(result)
img = mmcv.imread(img)
mmcv.imshow_det_bboxes(
img.copy(),
bboxes,
labels,
class_names=class_names,
score_thr=score_thr)
score_thr=score_thr,
show=out_file is None)

0 comments on commit edb0393

Please sign in to comment.