From 32df98e970c8848d3d9fd72492aba8bc030377d8 Mon Sep 17 00:00:00 2001 From: Cao Yuhang Date: Thu, 20 Jun 2019 22:25:42 +0800 Subject: [PATCH] Add reduction_override flag (#839) * add reduction_override flag * change default value of reduction_override as None * add assertion, fix format * delete redudant statement in util * delete redudant comment --- mmdet/core/bbox/samplers/ohem_sampler.py | 2 +- mmdet/models/bbox_heads/bbox_head.py | 11 ++++++++--- mmdet/models/losses/balanced_l1_loss.py | 13 +++++++++++-- mmdet/models/losses/cross_entropy_loss.py | 12 ++++++++++-- mmdet/models/losses/focal_loss.py | 12 ++++++++++-- mmdet/models/losses/ghm_loss.py | 3 +++ mmdet/models/losses/iou_loss.py | 13 +++++++++++-- mmdet/models/losses/smooth_l1_loss.py | 13 +++++++++++-- mmdet/models/losses/utils.py | 11 ++++++----- 9 files changed, 71 insertions(+), 19 deletions(-) diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py index 800a1c2..0711d97 100644 --- a/mmdet/core/bbox/samplers/ohem_sampler.py +++ b/mmdet/core/bbox/samplers/ohem_sampler.py @@ -36,7 +36,7 @@ def hard_mining(self, inds, num_expected, bboxes, labels, feats): label_weights=cls_score.new_ones(cls_score.size(0)), bbox_targets=None, bbox_weights=None, - reduce=False)['loss_cls'] + reduction_override='none')['loss_cls'] _, topk_loss_inds = loss.topk(num_expected) return inds[topk_loss_inds] diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index c67ea8a..436592c 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -97,12 +97,16 @@ def loss(self, label_weights, bbox_targets, bbox_weights, - reduce=True): + reduction_override=None): losses = dict() if cls_score is not None: avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.) losses['loss_cls'] = self.loss_cls( - cls_score, labels, label_weights, avg_factor=avg_factor) + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) losses['acc'] = accuracy(cls_score, labels) if bbox_pred is not None: pos_inds = labels > 0 @@ -115,7 +119,8 @@ def loss(self, pos_bbox_pred, bbox_targets[pos_inds], bbox_weights[pos_inds], - avg_factor=bbox_targets.size(0)) + avg_factor=bbox_targets.size(0), + reduction_override=reduction_override) return losses def get_det_bboxes(self, diff --git a/mmdet/models/losses/balanced_l1_loss.py b/mmdet/models/losses/balanced_l1_loss.py index 2dee674..8593396 100644 --- a/mmdet/models/losses/balanced_l1_loss.py +++ b/mmdet/models/losses/balanced_l1_loss.py @@ -46,7 +46,16 @@ def __init__(self, self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) loss_bbox = self.loss_weight * balanced_l1_loss( pred, target, @@ -54,7 +63,7 @@ def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): alpha=self.alpha, gamma=self.gamma, beta=self.beta, - reduction=self.reduction, + reduction=reduction, avg_factor=avg_factor, **kwargs) return loss_bbox diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py index 1921978..2f2ce69 100644 --- a/mmdet/models/losses/cross_entropy_loss.py +++ b/mmdet/models/losses/cross_entropy_loss.py @@ -73,13 +73,21 @@ def __init__(self, else: self.cls_criterion = cross_entropy - def forward(self, cls_score, label, weight=None, avg_factor=None, + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, **kwargs): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) loss_cls = self.loss_weight * self.cls_criterion( cls_score, label, weight, - reduction=self.reduction, + reduction=reduction, avg_factor=avg_factor, **kwargs) return loss_cls diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py index b8ccfa0..7a46356 100644 --- a/mmdet/models/losses/focal_loss.py +++ b/mmdet/models/losses/focal_loss.py @@ -59,7 +59,15 @@ def __init__(self, self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight=None, avg_factor=None): + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) if self.use_sigmoid: loss_cls = self.loss_weight * sigmoid_focal_loss( pred, @@ -67,7 +75,7 @@ def forward(self, pred, target, weight=None, avg_factor=None): weight, gamma=self.gamma, alpha=self.alpha, - reduction=self.reduction, + reduction=reduction, avg_factor=avg_factor) else: raise NotImplementedError diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py index 7beeb47..95656a2 100644 --- a/mmdet/models/losses/ghm_loss.py +++ b/mmdet/models/losses/ghm_loss.py @@ -15,6 +15,7 @@ def _expand_binary_labels(labels, label_weights, label_channels): return bin_labels, bin_label_weights +# TODO: code refactoring to make it consistent with other losses @LOSSES.register_module class GHMC(nn.Module): """GHM Classification Loss. @@ -90,6 +91,7 @@ def forward(self, pred, target, label_weight, *args, **kwargs): return loss * self.loss_weight +# TODO: code refactoring to make it consistent with other losses @LOSSES.register_module class GHMR(nn.Module): """GHM Regression Loss. @@ -116,6 +118,7 @@ def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0): self.acc_sum = torch.zeros(bins).cuda() self.loss_weight = loss_weight + # TODO: support reduction parameter def forward(self, pred, target, label_weight, avg_factor=None): """Calculate the GHM-R loss. diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py index 7c235cd..967e576 100644 --- a/mmdet/models/losses/iou_loss.py +++ b/mmdet/models/losses/iou_loss.py @@ -78,15 +78,24 @@ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0): self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): if weight is not None and not torch.any(weight > 0): return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) loss = self.loss_weight * iou_loss( pred, target, weight, eps=self.eps, - reduction=self.reduction, + reduction=reduction, avg_factor=avg_factor, **kwargs) return loss diff --git a/mmdet/models/losses/smooth_l1_loss.py b/mmdet/models/losses/smooth_l1_loss.py index 6a098fc..75d71e8 100644 --- a/mmdet/models/losses/smooth_l1_loss.py +++ b/mmdet/models/losses/smooth_l1_loss.py @@ -24,13 +24,22 @@ def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0): self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) loss_bbox = self.loss_weight * smooth_l1_loss( pred, target, weight, beta=self.beta, - reduction=self.reduction, + reduction=reduction, avg_factor=avg_factor, **kwargs) return loss_bbox diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py index b902c64..5c16e06 100644 --- a/mmdet/models/losses/utils.py +++ b/mmdet/models/losses/utils.py @@ -42,12 +42,13 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): # if avg_factor is not specified, just reduce the loss if avg_factor is None: loss = reduce_loss(loss, reduction) - # otherwise average the loss by avg_factor else: - if reduction != 'mean': - raise ValueError( - 'avg_factor can only be used with reduction="mean"') - loss = loss.sum() / avg_factor + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') return loss