Commit
Fix bug of ce loss when reduction != mean (#848)
* fix bug of ce loss when reduction != mean

* change function order

* modify comment

* minor fix
yhcao6 authored and hellock committed Jun 21, 2019
1 parent f724f9a commit 4a0d7ad
Showing 1 changed file with 13 additions and 2 deletions.
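For context, the pre-fix code built the loss as cross_entropy = weighted_loss(F.cross_entropy). A minimal sketch of the failure mode, assuming the weighted_loss decorator (in mmdet/models/losses/utils.py) calls the wrapped function first and applies weighting/reduction afterwards: F.cross_entropy defaults to reduction='mean', so the decorator received a single already-averaged scalar instead of element-wise losses, and any reduction other than 'mean' then operated on the wrong tensor.

import torch
import torch.nn.functional as F

pred = torch.randn(4, 10)            # 4 samples, 10 classes
label = torch.randint(0, 10, (4,))

# F.cross_entropy defaults to reduction='mean', so calling it without extra
# arguments (as the decorator effectively did) yields an already-reduced scalar:
already_reduced = F.cross_entropy(pred, label)                # shape []
elementwise = F.cross_entropy(pred, label, reduction='none')  # shape [4]

# Reducing the scalar again is a no-op, so a 'sum' reduction came out wrong:
print(already_reduced.sum())   # equals the mean, not the sum
print(elementwise.sum())       # the value a 'sum' reduction should produce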
15 changes: 13 additions & 2 deletions mmdet/models/losses/cross_entropy_loss.py
@@ -2,10 +2,21 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .utils import weight_reduce_loss, weighted_loss
+from .utils import weight_reduce_loss
 from ..registry import LOSSES
 
-cross_entropy = weighted_loss(F.cross_entropy)
+
+def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
+    # element-wise losses
+    loss = F.cross_entropy(pred, label, reduction='none')
+
+    # apply weights and do the reduction
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
 
 
 def _expand_binary_labels(labels, label_weights, label_channels):
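With the fix in place, a short usage sketch (assuming the import path follows the file path in the diff above, and that weight_reduce_loss multiplies the element-wise loss by weight when given, then reduces, dividing by avg_factor instead of the element count when one is supplied):

import torch
from mmdet.models.losses.cross_entropy_loss import cross_entropy

pred = torch.randn(4, 10)
label = torch.randint(0, 10, (4,))

# Element-wise losses are computed first, so every reduction mode is consistent:
loss_mean = cross_entropy(pred, label)                      # scalar mean
loss_sum = cross_entropy(pred, label, reduction='sum')      # scalar sum
loss_none = cross_entropy(pred, label, reduction='none')    # per-sample, shape [4]

# Per-sample weights and a custom normalizer (assumed semantics above):
w = torch.tensor([1.0, 0.0, 1.0, 1.0])   # ignore the second sample
loss_w = cross_entropy(pred, label, weight=w, avg_factor=w.sum())

Computing the raw per-element loss once and delegating all weighting and reduction to weight_reduce_loss is what makes 'sum' and 'none' behave correctly; it also keeps this loss on the same code path as the other losses in the package.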
