Add reduction_override flag (#839)

* add reduction_override flag

* change default value of reduction_override to None

* add assertion, fix format

* delete redundant statement in util

* delete redundant comment
yhcao6 authored and hellock committed Jun 20, 2019
1 parent fc0172b commit 32df98e
Showing 9 changed files with 71 additions and 19 deletions.
2 changes: 1 addition & 1 deletion mmdet/core/bbox/samplers/ohem_sampler.py
@@ -36,7 +36,7 @@ def hard_mining(self, inds, num_expected, bboxes, labels, feats):
                 label_weights=cls_score.new_ones(cls_score.size(0)),
                 bbox_targets=None,
                 bbox_weights=None,
-                reduce=False)['loss_cls']
+                reduction_override='none')['loss_cls']
             _, topk_loss_inds = loss.topk(num_expected)
         return inds[topk_loss_inds]
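
This sampler is the consumer that motivates the new flag: OHEM needs one loss value per candidate RoI so it can rank them and keep the hardest ones. A minimal, self-contained sketch of that ranking step (toy tensors and a plain cross-entropy stand in for the head's loss; values are made up, not from this commit):

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for the sampler's candidate RoIs (hypothetical values).
cls_score = torch.randn(8, 4)       # 8 candidates, 4 classes
labels = torch.randint(0, 4, (8,))
num_expected = 3

# reduction='none' keeps one loss per candidate instead of a scalar,
# so topk can select the hardest examples.
loss = F.cross_entropy(cls_score, labels, reduction='none')
_, topk_loss_inds = loss.topk(num_expected)
print(topk_loss_inds)  # indices of the 3 hardest candidates
```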

11 changes: 8 additions & 3 deletions mmdet/models/bbox_heads/bbox_head.py
@@ -97,12 +97,16 @@ def loss(self,
              label_weights,
              bbox_targets,
              bbox_weights,
-             reduce=True):
+             reduction_override=None):
         losses = dict()
         if cls_score is not None:
             avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
             losses['loss_cls'] = self.loss_cls(
-                cls_score, labels, label_weights, avg_factor=avg_factor)
+                cls_score,
+                labels,
+                label_weights,
+                avg_factor=avg_factor,
+                reduction_override=reduction_override)
             losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
             pos_inds = labels > 0
@@ -115,7 +119,8 @@ def loss(self,
                 pos_bbox_pred,
                 bbox_targets[pos_inds],
                 bbox_weights[pos_inds],
-                avg_factor=bbox_targets.size(0))
+                avg_factor=bbox_targets.size(0),
+                reduction_override=reduction_override)
         return losses

     def get_det_bboxes(self,
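
Note that `reduction_override` travels alongside `avg_factor`, which the head still computes from the label weights (first hunk above). A quick, runnable demonstration of that normalizer with hypothetical weights:

```python
import torch

# Hypothetical label weights for 6 RoIs; zeros mark ignored samples.
label_weights = torch.tensor([1., 1., 0., 1., 0., 1.])

# Normalize by the number of contributing samples, clamped to 1
# so an all-ignored batch cannot divide by zero.
avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
print(avg_factor)  # 4.0
```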
13 changes: 11 additions & 2 deletions mmdet/models/losses/balanced_l1_loss.py
@@ -46,15 +46,24 @@ def __init__(self,
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * balanced_l1_loss(
             pred,
             target,
             weight,
             alpha=self.alpha,
             gamma=self.gamma,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
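
The three added lines, the assertion plus the fallback to `self.reduction`, are the core of this commit, and the identical pattern recurs in the cross-entropy, focal, IoU, and smooth L1 losses below. A minimal sketch of the idiom in isolation (the `DummyLoss` module is hypothetical, for illustration only):

```python
import torch
import torch.nn as nn


class DummyLoss(nn.Module):
    """Toy loss module illustrating the reduction_override idiom."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target, reduction_override=None):
        # None means "use the default configured on the module".
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = (pred - target).abs()  # element-wise L1 for the sketch
        if reduction == 'mean':
            return loss.mean()
        elif reduction == 'sum':
            return loss.sum()
        return loss  # 'none' keeps the per-element losses


loss_fn = DummyLoss()
pred, target = torch.randn(5), torch.randn(5)
print(loss_fn(pred, target).shape)                             # torch.Size([]), default 'mean'
print(loss_fn(pred, target, reduction_override='none').shape)  # torch.Size([5])
```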
12 changes: 10 additions & 2 deletions mmdet/models/losses/cross_entropy_loss.py
@@ -73,13 +73,21 @@ def __init__(self,
         else:
             self.cls_criterion = cross_entropy

-    def forward(self, cls_score, label, weight=None, avg_factor=None,
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
                 **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
             weight,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
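
For intuition on what the three reductions mean, they relate in the obvious way; here demonstrated with PyTorch's own `cross_entropy` and made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5], [0.2, 1.5]])  # made-up values
labels = torch.tensor([0, 1])

per_sample = F.cross_entropy(logits, labels, reduction='none')
mean = F.cross_entropy(logits, labels, reduction='mean')
total = F.cross_entropy(logits, labels, reduction='sum')
print(torch.allclose(per_sample.mean(), mean))  # True
print(torch.allclose(per_sample.sum(), total))  # True
```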
12 changes: 10 additions & 2 deletions mmdet/models/losses/focal_loss.py
@@ -59,15 +59,23 @@ def __init__(self,
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         if self.use_sigmoid:
             loss_cls = self.loss_weight * sigmoid_focal_loss(
                 pred,
                 target,
                 weight,
                 gamma=self.gamma,
                 alpha=self.alpha,
-                reduction=self.reduction,
+                reduction=reduction,
                 avg_factor=avg_factor)
         else:
             raise NotImplementedError
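
Only the sigmoid branch is implemented here. For reference, the standard element-wise formulation of sigmoid focal loss looks like the sketch below; this is the common textbook form, not necessarily bit-identical to the `sigmoid_focal_loss` op the wrapper calls. The key point for this commit is that the loss is computed per element and `reduction` is applied afterwards:

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss_ref(pred, target, gamma=2.0, alpha=0.25):
    """Element-wise sigmoid focal loss (reference sketch)."""
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt is the probability assigned to the wrong side, per element.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target +
                    (1 - alpha) * (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    return loss  # reduction ('none'/'mean'/'sum') is applied by the caller

pred = torch.randn(4, 3)
target = torch.randint(0, 2, (4, 3)).float()
print(sigmoid_focal_loss_ref(pred, target).shape)  # torch.Size([4, 3])
```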
3 changes: 3 additions & 0 deletions mmdet/models/losses/ghm_loss.py
@@ -15,6 +15,7 @@ def _expand_binary_labels(labels, label_weights, label_channels):
     return bin_labels, bin_label_weights


+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMC(nn.Module):
     """GHM Classification Loss.
@@ -90,6 +91,7 @@ def forward(self, pred, target, label_weight, *args, **kwargs):
         return loss * self.loss_weight


+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMR(nn.Module):
     """GHM Regression Loss.
@@ -116,6 +118,7 @@ def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0):
             self.acc_sum = torch.zeros(bins).cuda()
         self.loss_weight = loss_weight

+    # TODO: support reduction parameter
     def forward(self, pred, target, label_weight, avg_factor=None):
         """Calculate the GHM-R loss.
13 changes: 11 additions & 2 deletions mmdet/models/losses/iou_loss.py
@@ -78,15 +78,24 @@ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
         if weight is not None and not torch.any(weight > 0):
             return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss = self.loss_weight * iou_loss(
             pred,
             target,
             weight,
             eps=self.eps,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss
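
The early return above is worth a note: with an all-zero weight the loss short-circuits to `(pred * weight).sum()`, a zero that is still attached to the autograd graph, so callers can run `backward()` without special-casing. A small demonstration with made-up tensors:

```python
import torch

pred = torch.randn(3, 4, requires_grad=True)
weight = torch.zeros(3, 4)

# Zero loss, but still differentiable w.r.t. pred.
loss = (pred * weight).sum()
print(loss.item())  # 0.0
loss.backward()     # no error; gradients are simply zero
print(pred.grad.abs().sum().item())  # 0.0
```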
13 changes: 11 additions & 2 deletions mmdet/models/losses/smooth_l1_loss.py
@@ -24,13 +24,22 @@ def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * smooth_l1_loss(
             pred,
             target,
             weight,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
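
For completeness, the `beta` forwarded here controls where Smooth L1 switches from quadratic to linear. A reference element-wise formulation, the standard definition, assumed to match `smooth_l1_loss` up to implementation details:

```python
import torch

def smooth_l1_ref(pred, target, beta=1.0):
    """Element-wise Smooth L1: quadratic below beta, linear above."""
    diff = (pred - target).abs()
    return torch.where(diff < beta,
                       0.5 * diff * diff / beta,
                       diff - 0.5 * beta)

pred, target = torch.tensor([0.2, 3.0]), torch.tensor([0.0, 0.0])
print(smooth_l1_ref(pred, target))  # tensor([0.0200, 2.5000])
```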
11 changes: 6 additions & 5 deletions mmdet/models/losses/utils.py
@@ -42,12 +42,13 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     # if avg_factor is not specified, just reduce the loss
     if avg_factor is None:
         loss = reduce_loss(loss, reduction)
-    # otherwise average the loss by avg_factor
     else:
-        if reduction != 'mean':
-            raise ValueError(
-                'avg_factor can only be used with reduction="mean"')
-        loss = loss.sum() / avg_factor
+        # if reduction is mean, then average the loss by avg_factor
+        if reduction == 'mean':
+            loss = loss.sum() / avg_factor
+        # if reduction is 'none', then do nothing, otherwise raise an error
+        elif reduction != 'none':
+            raise ValueError('avg_factor can not be used with reduction="sum"')
     return loss
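
This is the behavioral change that makes the flag work end to end: `avg_factor` used to require `reduction='mean'`; now `reduction='none'` is also allowed and simply skips the normalization. A condensed, runnable copy of the new branching for experimentation (`reduce_loss` and the weighting step come from elsewhere in the file and are assumed here):

```python
import torch

def reduce_loss(loss, reduction):
    """Assumed helper mirroring mmdet's reduce_loss (not shown in this diff)."""
    if reduction == 'none':
        return loss
    return loss.mean() if reduction == 'mean' else loss.sum()

def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    if weight is not None:
        loss = loss * weight
    # no avg_factor: just reduce as requested
    if avg_factor is None:
        return reduce_loss(loss, reduction)
    # 'mean': normalize by avg_factor instead of numel
    if reduction == 'mean':
        return loss.sum() / avg_factor
    # 'none': new in this commit, return the loss unreduced
    if reduction == 'none':
        return loss
    raise ValueError('avg_factor can not be used with reduction="sum"')

loss = torch.tensor([1.0, 2.0, 3.0])
print(weight_reduce_loss(loss, avg_factor=2.0))                    # tensor(3.)
print(weight_reduce_loss(loss, reduction='none', avg_factor=2.0))  # tensor([1., 2., 3.])
# weight_reduce_loss(loss, reduction='sum', avg_factor=2.0)  -> ValueError
```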



