Skip to content
Snippets Groups Projects
Unverified Commit de7917d8 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Add cls weight in CrossEntropyLoss (#2797)

* Add cls weight in CrossEntropyLoss

* update unittes

* resolve comments

* fix unittest bug
parent 58c55a98
No related branches found
No related tags found
No related merge requests found
......@@ -6,9 +6,14 @@ from ..builder import LOSSES
from .utils import weight_reduce_loss
def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
def cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None):
# element-wise losses
loss = F.cross_entropy(pred, label, reduction='none')
loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')
# apply weights and do the reduction
if weight is not None:
......@@ -39,7 +44,8 @@ def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None):
avg_factor=None,
class_weight=None):
if pred.dim() != label.dim():
label, weight = _expand_binary_labels(label, weight, pred.size(-1))
......@@ -47,21 +53,27 @@ def binary_cross_entropy(pred,
if weight is not None:
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), weight, reduction='none')
pred, label.float(), weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None):
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None):
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, reduction='mean')[None]
pred_slice, target, weight=class_weight, reduction='mean')[None]
@LOSSES.register_module()
......@@ -71,6 +83,7 @@ class CrossEntropyLoss(nn.Module):
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
loss_weight=1.0):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
......@@ -78,6 +91,7 @@ class CrossEntropyLoss(nn.Module):
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
......@@ -96,10 +110,15 @@ class CrossEntropyLoss(nn.Module):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
......
import pytest
import torch
def test_ce_loss():
from mmdet.models import build_loss
# use_mask and use_sigmoid cannot be true at the same time
with pytest.raises(AssertionError):
loss_cfg = dict(
type='CrossEntropyLoss',
use_mask=True,
use_sigmoid=True,
loss_weight=1.0)
build_loss(loss_cfg)
# test loss with class weights
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=[0.8, 0.2],
loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long()
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
loss_cls_cfg = dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment