diff --git a/configs/fcos/fcos_mstrain_640_800_r101_caffe_fpn_gn_2x_4gpu.py b/configs/fcos/fcos_mstrain_640_800_r101_caffe_fpn_gn_2x_4gpu.py
index 41297fc0aa4dd0939839d7c9e6798d36f07d4b11..ac21fada1667ec4c8dc7153798ec8d01f02b1dc3 100644
--- a/configs/fcos/fcos_mstrain_640_800_r101_caffe_fpn_gn_2x_4gpu.py
+++ b/configs/fcos/fcos_mstrain_640_800_r101_caffe_fpn_gn_2x_4gpu.py
@@ -25,7 +25,16 @@ model = dict(
         in_channels=256,
         stacked_convs=4,
         feat_channels=256,
-        strides=[8, 16, 32, 64, 128]))
+        strides=[8, 16, 32, 64, 128],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
 # training and testing settings
 train_cfg = dict(
     assigner=dict(
@@ -34,9 +43,6 @@ train_cfg = dict(
         neg_iou_thr=0.4,
         min_pos_iou=0,
         ignore_iof_thr=-1),
-    smoothl1_beta=0.11,
-    gamma=2.0,
-    alpha=0.25,
     allowed_border=-1,
     pos_weight=-1,
     debug=False)
diff --git a/configs/fcos/fcos_mstrain_640_800_x101_64x4d_fpn_gn_2x.py b/configs/fcos/fcos_mstrain_640_800_x101_64x4d_fpn_gn_2x.py
index 4f9352c2c9df4514f5b43074ef5d956b2d5a309c..d932bcfe2aabfa0c0808f2fe867b2981ae2677b4 100644
--- a/configs/fcos/fcos_mstrain_640_800_x101_64x4d_fpn_gn_2x.py
+++ b/configs/fcos/fcos_mstrain_640_800_x101_64x4d_fpn_gn_2x.py
@@ -26,7 +26,16 @@ model = dict(
         in_channels=256,
         stacked_convs=4,
         feat_channels=256,
-        strides=[8, 16, 32, 64, 128]))
+        strides=[8, 16, 32, 64, 128],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
 # training and testing settings
 train_cfg = dict(
     assigner=dict(
@@ -35,9 +44,6 @@ train_cfg = dict(
         neg_iou_thr=0.4,
         min_pos_iou=0,
         ignore_iof_thr=-1),
-    smoothl1_beta=0.11,
-    gamma=2.0,
-    alpha=0.25,
    allowed_border=-1,
     pos_weight=-1,
     debug=False)
diff --git a/configs/fcos/fcos_r50_caffe_fpn_gn_1x_4gpu.py b/configs/fcos/fcos_r50_caffe_fpn_gn_1x_4gpu.py
index dd63ccfb2296d7d077aa8a35548f382eba71a560..6243c3645909bfe377525f4e4417538eae80f0dd 100644
--- a/configs/fcos/fcos_r50_caffe_fpn_gn_1x_4gpu.py
+++ b/configs/fcos/fcos_r50_caffe_fpn_gn_1x_4gpu.py
@@ -25,7 +25,16 @@ model = dict(
         in_channels=256,
         stacked_convs=4,
         feat_channels=256,
-        strides=[8, 16, 32, 64, 128]))
+        strides=[8, 16, 32, 64, 128],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
 # training and testing settings
 train_cfg = dict(
     assigner=dict(
@@ -34,9 +43,6 @@ train_cfg = dict(
         neg_iou_thr=0.4,
         min_pos_iou=0,
         ignore_iof_thr=-1),
-    smoothl1_beta=0.11,
-    gamma=2.0,
-    alpha=0.25,
     allowed_border=-1,
     pos_weight=-1,
     debug=False)
diff --git a/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x.py b/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x.py
index 0b9f7254ed8bb229696a82830c50a318d606be1b..f78e1c1812ebf8843490aaeb05db68bcf464fda4 100644
--- a/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x.py
+++ b/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x.py
@@ -36,8 +36,7 @@ model = dict(
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_shape=dict(
-            type='IoULoss', style='bounded', beta=0.2, loss_weight=1.0),
+        loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
diff --git a/configs/guided_anchoring/ga_faster_x101_32x4d_fpn_1x.py b/configs/guided_anchoring/ga_faster_x101_32x4d_fpn_1x.py
index dabdf6c9864e1a9dad1e759165d76472618e78b0..61e7b99e667822bf5fd9be46152febd907d220fb 100644
--- a/configs/guided_anchoring/ga_faster_x101_32x4d_fpn_1x.py
+++ b/configs/guided_anchoring/ga_faster_x101_32x4d_fpn_1x.py
@@ -36,8 +36,7 @@ model = dict(
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_shape=dict(
-            type='IoULoss', style='bounded', beta=0.2, loss_weight=1.0),
+        loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
diff --git a/configs/guided_anchoring/ga_retinanet_r50_caffe_fpn_1x.py b/configs/guided_anchoring/ga_retinanet_r50_caffe_fpn_1x.py
index 63ba9e743f2921cffacf7331609e89eccc4bc8cc..ae6a18a9d452709ac3d700ed1c956e66fdc82828 100644
--- a/configs/guided_anchoring/ga_retinanet_r50_caffe_fpn_1x.py
+++ b/configs/guided_anchoring/ga_retinanet_r50_caffe_fpn_1x.py
@@ -40,8 +40,7 @@ model = dict(
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_shape=dict(
-            type='IoULoss', style='bounded', beta=0.2, loss_weight=1.0),
+        loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
         loss_cls=dict(
             type='FocalLoss',
             use_sigmoid=True,
diff --git a/configs/guided_anchoring/ga_retinanet_x101_32x4d_fpn_1x.py b/configs/guided_anchoring/ga_retinanet_x101_32x4d_fpn_1x.py
index bd39bf12ccd2684b90b12afdaa202a1664bf9b8c..32f2bd620c044b61fe186e04cd9184478bf312c3 100644
--- a/configs/guided_anchoring/ga_retinanet_x101_32x4d_fpn_1x.py
+++ b/configs/guided_anchoring/ga_retinanet_x101_32x4d_fpn_1x.py
@@ -40,8 +40,7 @@ model = dict(
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_shape=dict(
-            type='IoULoss', style='bounded', beta=0.2, loss_weight=1.0),
+        loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
         loss_cls=dict(
             type='FocalLoss',
             use_sigmoid=True,
diff --git a/configs/guided_anchoring/ga_rpn_r101_caffe_rpn_1x.py b/configs/guided_anchoring/ga_rpn_r101_caffe_rpn_1x.py
index d3acf87eacbb73c0e175af65ce138dbf27ca586c..c3d3b65654c08451708fb882f5b259bbca6b2802 100644
--- a/configs/guided_anchoring/ga_rpn_r101_caffe_rpn_1x.py
+++ b/configs/guided_anchoring/ga_rpn_r101_caffe_rpn_1x.py
@@ -36,8 +36,7 @@ model = dict(
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_shape=dict(
-            type='IoULoss', style='bounded', beta=0.2, loss_weight=1.0),
+        loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
diff --git a/configs/guided_anchoring/ga_rpn_r50_caffe_fpn_1x.py b/configs/guided_anchoring/ga_rpn_r50_caffe_fpn_1x.py
index cea9b76d4321d23685c3dc9aebc1ccedacbfaf07..a4b6b6d624c663ce1b662e87440e5d9fa002c668 100644
--- a/configs/guided_anchoring/ga_rpn_r50_caffe_fpn_1x.py
+++ b/configs/guided_anchoring/ga_rpn_r50_caffe_fpn_1x.py
@@ -36,8 +36,7 @@ model = dict(
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_shape=dict(
-            type='IoULoss', style='bounded', beta=0.2, loss_weight=1.0),
+        loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
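Note on the config changes above: loss hyperparameters (focal `gamma`/`alpha`, `smoothl1_beta`) move out of `train_cfg` and into per-head `loss_cls`/`loss_bbox`/`loss_centerness`/`loss_shape` dicts, so a head no longer hard-codes its criterion. A minimal, registry-free sketch of how such a dict becomes a module — the real path goes through `mmdet.models.builder.build_loss` and the `LOSSES` registry; `DummyIoULoss` and `_LOSSES` below are illustrative stand-ins only:

```python
import torch.nn as nn


class DummyIoULoss(nn.Module):  # stand-in for mmdet's registered IoULoss
    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight


_LOSSES = {'IoULoss': DummyIoULoss}  # the registry maps type names to classes


def build_loss_from_cfg(cfg):
    cfg = cfg.copy()
    loss_type = cfg.pop('type')       # e.g. 'IoULoss'
    return _LOSSES[loss_type](**cfg)  # remaining keys become constructor kwargs


loss_bbox = build_loss_from_cfg(dict(type='IoULoss', loss_weight=1.0))
print(type(loss_bbox).__name__)  # DummyIoULoss
```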
diff --git a/configs/guided_anchoring/ga_rpn_x101_32x4d_fpn_1x.py b/configs/guided_anchoring/ga_rpn_x101_32x4d_fpn_1x.py
index c0372544bfed860a4a2a8578445ffef126af21a0..9eb1a69cc648388b0723dc1c33cb4380d3c4bc0f 100644
--- a/configs/guided_anchoring/ga_rpn_x101_32x4d_fpn_1x.py
+++ b/configs/guided_anchoring/ga_rpn_x101_32x4d_fpn_1x.py
@@ -36,8 +36,7 @@ model = dict(
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_shape=dict(
-            type='IoULoss', style='bounded', beta=0.2, loss_weight=1.0),
+        loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py
index 645d5be29c039aeb2173525163b681675741d7ea..d118b14c2b0a364c3d3a5cce3c5ff060fcfc52f7 100644
--- a/mmdet/core/__init__.py
+++ b/mmdet/core/__init__.py
@@ -1,7 +1,6 @@
 from .anchor import *  # noqa: F401, F403
 from .bbox import *  # noqa: F401, F403
 from .mask import *  # noqa: F401, F403
-from .loss import *  # noqa: F401, F403
 from .evaluation import *  # noqa: F401, F403
 from .post_processing import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
diff --git a/mmdet/core/loss/__init__.py b/mmdet/core/loss/__init__.py
deleted file mode 100644
index ad7b21f92f423f67e61b296252be2842cd9b5f40..0000000000000000000000000000000000000000
--- a/mmdet/core/loss/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from .losses import (weighted_nll_loss, weighted_cross_entropy,
-                     weighted_binary_cross_entropy, sigmoid_focal_loss,
-                     py_sigmoid_focal_loss, weighted_sigmoid_focal_loss,
-                     mask_cross_entropy, smooth_l1_loss, weighted_smoothl1,
-                     balanced_l1_loss, weighted_balanced_l1_loss, iou_loss,
-                     bounded_iou_loss, weighted_iou_loss, accuracy)
-
-__all__ = [
-    'weighted_nll_loss', 'weighted_cross_entropy',
-    'weighted_binary_cross_entropy', 'sigmoid_focal_loss',
-    'py_sigmoid_focal_loss', 'weighted_sigmoid_focal_loss',
-    'mask_cross_entropy', 'smooth_l1_loss', 'weighted_smoothl1',
-    'balanced_l1_loss', 'weighted_balanced_l1_loss', 'bounded_iou_loss',
-    'weighted_iou_loss', 'iou_loss', 'accuracy'
-]
diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py
deleted file mode 100644
index 388b98f0f19d3d8e0ca431093e3649b2a1bf3f41..0000000000000000000000000000000000000000
--- a/mmdet/core/loss/losses.py
+++ /dev/null
@@ -1,261 +0,0 @@
-# TODO merge naive and weighted loss.
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from ..bbox import bbox_overlaps
-from ...ops import sigmoid_focal_loss
-
-
-def weighted_nll_loss(pred, label, weight, avg_factor=None):
-    if avg_factor is None:
-        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
-    raw = F.nll_loss(pred, label, reduction='none')
-    return torch.sum(raw * weight)[None] / avg_factor
-
-
-def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
-    if avg_factor is None:
-        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
-    raw = F.cross_entropy(pred, label, reduction='none')
-    if reduce:
-        return torch.sum(raw * weight)[None] / avg_factor
-    else:
-        return raw * weight / avg_factor
-
-
-def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
-    if pred.dim() != label.dim():
-        label, weight = _expand_binary_labels(label, weight, pred.size(-1))
-    if avg_factor is None:
-        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
-    return F.binary_cross_entropy_with_logits(
-        pred, label.float(), weight.float(),
-        reduction='sum')[None] / avg_factor
-
-
-def py_sigmoid_focal_loss(pred,
-                          target,
-                          weight,
-                          gamma=2.0,
-                          alpha=0.25,
-                          reduction='mean'):
-    pred_sigmoid = pred.sigmoid()
-    target = target.type_as(pred)
-    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
-    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
-    weight = weight * pt.pow(gamma)
-    loss = F.binary_cross_entropy_with_logits(
-        pred, target, reduction='none') * weight
-    reduction_enum = F._Reduction.get_enum(reduction)
-    # none: 0, mean:1, sum: 2
-    if reduction_enum == 0:
-        return loss
-    elif reduction_enum == 1:
-        return loss.mean()
-    elif reduction_enum == 2:
-        return loss.sum()
-
-
-def weighted_sigmoid_focal_loss(pred,
-                                target,
-                                weight,
-                                gamma=2.0,
-                                alpha=0.25,
-                                avg_factor=None,
-                                num_classes=80):
-    if avg_factor is None:
-        avg_factor = torch.sum(weight > 0).float().item() / num_classes + 1e-6
-    return torch.sum(
-        sigmoid_focal_loss(pred, target, gamma, alpha, 'none') *
-        weight.view(-1, 1))[None] / avg_factor
-
-
-def mask_cross_entropy(pred, target, label):
-    num_rois = pred.size()[0]
-    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
-    pred_slice = pred[inds, label].squeeze(1)
-    return F.binary_cross_entropy_with_logits(
-        pred_slice, target, reduction='mean')[None]
-
-
-def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
-    assert beta > 0
-    assert pred.size() == target.size() and target.numel() > 0
-    diff = torch.abs(pred - target)
-    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
-                       diff - 0.5 * beta)
-    reduction_enum = F._Reduction.get_enum(reduction)
-    # none: 0, mean:1, sum: 2
-    if reduction_enum == 0:
-        return loss
-    elif reduction_enum == 1:
-        return loss.sum() / pred.numel()
-    elif reduction_enum == 2:
-        return loss.sum()
-
-
-def weighted_smoothl1(pred, target, weight, beta=1.0, avg_factor=None):
-    if avg_factor is None:
-        avg_factor = torch.sum(weight > 0).float().item() / 4 + 1e-6
-    loss = smooth_l1_loss(pred, target, beta, reduction='none')
-    return torch.sum(loss * weight)[None] / avg_factor
-
-
-def balanced_l1_loss(pred,
-                     target,
-                     beta=1.0,
-                     alpha=0.5,
-                     gamma=1.5,
-                     reduction='none'):
-    assert beta > 0
-    assert pred.size() == target.size() and target.numel() > 0
-
-    diff = torch.abs(pred - target)
-    b = np.e**(gamma / alpha) - 1
-    loss = torch.where(
-        diff < beta, alpha / b *
-        (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
-        gamma * diff + gamma / b - alpha * beta)
-
-    reduction_enum = F._Reduction.get_enum(reduction)
-    # none: 0, elementwise_mean:1, sum: 2
-    if reduction_enum == 0:
-        return loss
-    elif reduction_enum == 1:
-        return loss.sum() / pred.numel()
-    elif reduction_enum == 2:
-        return loss.sum()
-
-    return loss
-
-
-def weighted_balanced_l1_loss(pred,
-                              target,
-                              weight,
-                              beta=1.0,
-                              alpha=0.5,
-                              gamma=1.5,
-                              avg_factor=None):
-    if avg_factor is None:
-        avg_factor = torch.sum(weight > 0).float().item() / 4 + 1e-6
-    loss = balanced_l1_loss(pred, target, beta, alpha, gamma, reduction='none')
-    return torch.sum(loss * weight)[None] / avg_factor
-
-
-def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3, reduction='mean'):
-    """Improving Object Localization with Fitness NMS and Bounded IoU Loss,
-    https://arxiv.org/abs/1711.00164.
-
-    Args:
-        pred (tensor): Predicted bboxes.
-        target (tensor): Target bboxes.
-        beta (float): beta parameter in smoothl1.
-        eps (float): eps to avoid NaN.
-        reduction (str): Reduction type.
-    """
-    pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
-    pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
-    pred_w = pred[:, 2] - pred[:, 0] + 1
-    pred_h = pred[:, 3] - pred[:, 1] + 1
-    with torch.no_grad():
-        target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
-        target_ctry = (target[:, 1] + target[:, 3]) * 0.5
-        target_w = target[:, 2] - target[:, 0] + 1
-        target_h = target[:, 3] - target[:, 1] + 1
-
-    dx = target_ctrx - pred_ctrx
-    dy = target_ctry - pred_ctry
-
-    loss_dx = 1 - torch.max(
-        (target_w - 2 * dx.abs()) /
-        (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
-    loss_dy = 1 - torch.max(
-        (target_h - 2 * dy.abs()) /
-        (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
-    loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
-                            (target_w + eps))
-    loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
-                            (target_h + eps))
-    loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
-                            dim=-1).view(loss_dx.size(0), -1)
-
-    loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
-                       loss_comb - 0.5 * beta)
-    reduction_enum = F._Reduction.get_enum(reduction)
-    # none: 0, mean:1, sum: 2
-    if reduction_enum == 0:
-        return loss
-    elif reduction_enum == 1:
-        return loss.sum() / pred.numel()
-    elif reduction_enum == 2:
-        return loss.sum()
-
-
-def weighted_iou_loss(pred,
-                      target,
-                      weight,
-                      style='naive',
-                      beta=0.2,
-                      eps=1e-3,
-                      avg_factor=None):
-    if style not in ['bounded', 'naive']:
-        raise ValueError('Only support bounded iou loss and naive iou loss.')
-    inds = torch.nonzero(weight[:, 0] > 0)
-    if avg_factor is None:
-        avg_factor = inds.numel() + 1e-6
-
-    if inds.numel() > 0:
-        inds = inds.squeeze(1)
-    else:
-        return (pred * weight).sum()[None] / avg_factor
-
-    if style == 'bounded':
-        loss = bounded_iou_loss(
-            pred[inds], target[inds], beta=beta, eps=eps, reduction='sum')
-    else:
-        loss = iou_loss(pred[inds], target[inds], eps=eps, reduction='sum')
-    loss = loss[None] / avg_factor
-    return loss
-
-
-def accuracy(pred, target, topk=1):
-    if isinstance(topk, int):
-        topk = (topk, )
-        return_single = True
-    else:
-        return_single = False
-
-    maxk = max(topk)
-    _, pred_label = pred.topk(maxk, 1, True, True)
-    pred_label = pred_label.t()
-    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
-
-    res = []
-    for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
-        res.append(correct_k.mul_(100.0 / pred.size(0)))
-    return res[0] if return_single else res
-
-
-def _expand_binary_labels(labels, label_weights, label_channels):
-    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
-    inds = torch.nonzero(labels >= 1).squeeze()
-    if inds.numel() > 0:
-        bin_labels[inds, labels[inds] - 1] = 1
-    bin_label_weights = label_weights.view(-1, 1).expand(
-        label_weights.size(0), label_channels)
-    return bin_labels, bin_label_weights
-
-
-def iou_loss(pred_bboxes, target_bboxes, eps=1e-6, reduction='mean'):
-    ious = bbox_overlaps(pred_bboxes, target_bboxes, is_aligned=True) + eps
-    loss = -ious.log()
-
-    reduction_enum = F._Reduction.get_enum(reduction)
-    if reduction_enum == 0:
-        return loss
-    elif reduction_enum == 1:
-        return loss.mean()
-    elif reduction_enum == 2:
-        return loss.sum()
diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py
index f16eb3ca3963e08b228cf2930bb75e1415ca864f..a5ad9bcc6a949ac0beb7dc5ef8cff18b65888591 100644
--- a/mmdet/models/anchor_heads/fcos_head.py
+++ b/mmdet/models/anchor_heads/fcos_head.py
@@ -1,10 +1,9 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import normal_init

-from mmdet.core import (sigmoid_focal_loss, iou_loss, multi_apply,
-                        multiclass_nms, distance2bbox)
+from mmdet.core import multi_apply, multiclass_nms, distance2bbox
+from ..builder import build_loss
 from ..registry import HEADS
 from ..utils import bias_init_with_prob, Scale, ConvModule

@@ -22,6 +21,17 @@ class FCOSHead(nn.Module):
                  strides=(4, 8, 16, 32, 64),
                  regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
                                  (512, INF)),
+                 loss_cls=dict(
+                     type='FocalLoss',
+                     use_sigmoid=True,
+                     gamma=2.0,
+                     alpha=0.25,
+                     loss_weight=1.0),
+                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+                 loss_centerness=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=True,
+                     loss_weight=1.0),
                  conv_cfg=None,
                  norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)):
         super(FCOSHead, self).__init__()
@@ -33,6 +43,9 @@ class FCOSHead(nn.Module):
         self.stacked_convs = stacked_convs
         self.strides = strides
         self.regress_ranges = regress_ranges
+        self.loss_cls = build_loss(loss_cls)
+        self.loss_bbox = build_loss(loss_bbox)
+        self.loss_centerness = build_loss(loss_centerness)
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
@@ -139,9 +152,9 @@ class FCOSHead(nn.Module):
         pos_inds = flatten_labels.nonzero().reshape(-1)
         num_pos = len(pos_inds)
-        loss_cls = sigmoid_focal_loss(
-            flatten_cls_scores, flatten_labels, cfg.gamma, cfg.alpha,
-            'none').sum()[None] / (num_pos + num_imgs)  # avoid num_pos is 0
+        loss_cls = self.loss_cls(
+            flatten_cls_scores, flatten_labels,
+            avg_factor=num_pos + num_imgs)  # avoid num_pos is 0

         pos_bbox_preds = flatten_bbox_preds[pos_inds]
         pos_bbox_targets = flatten_bbox_targets[pos_inds]
@@ -154,20 +167,20 @@ class FCOSHead(nn.Module):
             pos_decoded_target_preds = distance2bbox(pos_points,
                                                      pos_bbox_targets)
             # centerness weighted iou loss
-            loss_reg = ((iou_loss(
+            loss_bbox = self.loss_bbox(
                 pos_decoded_bbox_preds,
                 pos_decoded_target_preds,
-                reduction='none') * pos_centerness_targets).sum() /
-                        pos_centerness_targets.sum())[None]
-            loss_centerness = F.binary_cross_entropy_with_logits(
-                pos_centerness, pos_centerness_targets, reduction='mean')[None]
+                weight=pos_centerness_targets,
+                avg_factor=pos_centerness_targets.sum())
+            loss_centerness = self.loss_centerness(pos_centerness,
+                                                   pos_centerness_targets)
         else:
-            loss_reg = pos_bbox_preds.sum()[None]
-            loss_centerness = pos_centerness.sum()[None]
+            loss_bbox = pos_bbox_preds.sum()
+            loss_centerness = pos_centerness.sum()

         return dict(
             loss_cls=loss_cls,
-            loss_reg=loss_reg,
+            loss_bbox=loss_bbox,
             loss_centerness=loss_centerness)

     def get_bboxes(self,
@@ -196,9 +209,10 @@ class FCOSHead(nn.Module):
             ]
             img_shape = img_metas[img_id]['img_shape']
             scale_factor = img_metas[img_id]['scale_factor']
-            det_bboxes = self.get_bboxes_single(
-                cls_score_list, bbox_pred_list, centerness_pred_list,
-                mlvl_points, img_shape, scale_factor, cfg, rescale)
+            det_bboxes = self.get_bboxes_single(cls_score_list, bbox_pred_list,
+                                                centerness_pred_list,
+                                                mlvl_points, img_shape,
+                                                scale_factor, cfg, rescale)
             result_list.append(det_bboxes)
         return result_list
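The FCOS change above swaps a hand-rolled `.sum()[None] / (num_pos + num_imgs)` for the shared `avg_factor` path. A toy check that the two computations agree — `weight_reduce_loss` is re-implemented here in simplified form for illustration; the real one lives in `mmdet/models/losses/utils.py` below:

```python
import torch


def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    if weight is not None:
        loss = loss * weight
    if avg_factor is None:
        return loss.mean() if reduction == 'mean' else loss.sum()
    return loss.sum() / avg_factor  # the avg_factor path: sum / normalizer


elem_loss = torch.tensor([0.5, 1.5, 2.0])  # per-element focal-loss values
num_pos, num_imgs = 2, 2
old = elem_loss.sum()[None] / (num_pos + num_imgs)                  # tensor([1.])
new = weight_reduce_loss(elem_loss, avg_factor=num_pos + num_imgs)  # tensor(1.)
assert torch.allclose(old.squeeze(), new)
```

The only visible difference is the dropped `[None]`: losses are now plain scalars instead of 1-element tensors.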
diff --git a/mmdet/models/anchor_heads/guided_anchor_head.py b/mmdet/models/anchor_heads/guided_anchor_head.py
index 9660f17708b83e0a18fb1e6e69585e636ae23391..8b5dc54ec88aea627a6ea24c0536c82d0b3fc966 100644
--- a/mmdet/models/anchor_heads/guided_anchor_head.py
+++ b/mmdet/models/anchor_heads/guided_anchor_head.py
@@ -115,8 +115,7 @@ class GuidedAnchorHead(AnchorHead):
                      alpha=0.25,
                      loss_weight=1.0),
                  loss_shape=dict(
-                     type='IoULoss',
-                     style='bounded',
+                     type='BoundedIoULoss',
                      beta=0.2,
                      loss_weight=1.0),
                  loss_cls=dict(
diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py
index 92a3e9c5e1d6174b3a32f90b719206780f7dc2f8..c74a5988305bb291100bd71fce06c833aafd3517 100644
--- a/mmdet/models/anchor_heads/ssd_head.py
+++ b/mmdet/models/anchor_heads/ssd_head.py
@@ -4,9 +4,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import xavier_init

-from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
-                        multi_apply)
+from mmdet.core import AnchorGenerator, anchor_target, multi_apply
 from .anchor_head import AnchorHead
+from ..losses import smooth_l1_loss
 from ..registry import HEADS

@@ -123,7 +123,7 @@ class SSDHead(AnchorHead):
         loss_cls_neg = topk_loss_cls_neg.sum()
         loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

-        loss_bbox = weighted_smoothl1(
+        loss_bbox = smooth_l1_loss(
             bbox_pred,
             bbox_targets,
             bbox_weights,
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index d0d98ff9499e13ad8656d8fa4de1dc65a760b860..c67ea8a7fc75e48a1d4cd24b9c295b5590f80c98 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmdet.core import delta2bbox, multiclass_nms, bbox_target, accuracy
+from mmdet.core import delta2bbox, multiclass_nms, bbox_target
 from ..builder import build_loss
+from ..losses import accuracy
 from ..registry import HEADS

@@ -99,8 +100,9 @@ class BBoxHead(nn.Module):
              reduce=True):
         losses = dict()
         if cls_score is not None:
+            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
             losses['loss_cls'] = self.loss_cls(
-                cls_score, labels, label_weights, reduce=reduce)
+                cls_score, labels, label_weights, avg_factor=avg_factor)
             losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
             pos_inds = labels > 0
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
index 817f4d26da2567a2d31d2c3e4404dfc085d11969..45920dd3c209f1f756f0294bb1465a3faca61dcd 100644
--- a/mmdet/models/losses/__init__.py
+++ b/mmdet/models/losses/__init__.py
@@ -1,11 +1,18 @@
-from .cross_entropy_loss import CrossEntropyLoss
-from .focal_loss import FocalLoss
-from .smooth_l1_loss import SmoothL1Loss
+from .accuracy import accuracy, Accuracy
+from .cross_entropy_loss import (cross_entropy, binary_cross_entropy,
+                                 mask_cross_entropy, CrossEntropyLoss)
+from .focal_loss import sigmoid_focal_loss, FocalLoss
+from .smooth_l1_loss import smooth_l1_loss, SmoothL1Loss
 from .ghm_loss import GHMC, GHMR
-from .balanced_l1_loss import BalancedL1Loss
-from .iou_loss import IoULoss
+from .balanced_l1_loss import balanced_l1_loss, BalancedL1Loss
+from .iou_loss import iou_loss, bounded_iou_loss, IoULoss, BoundedIoULoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss

 __all__ = [
-    'CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'BalancedL1Loss',
-    'IoULoss', 'GHMC', 'GHMR'
+    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+    'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
+    'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
+    'BalancedL1Loss', 'iou_loss', 'bounded_iou_loss', 'IoULoss',
+    'BoundedIoULoss', 'GHMC', 'GHMR', 'reduce_loss', 'weight_reduce_loss',
+    'weighted_loss'
 ]
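The expanded `__init__.py` above shows the new convention: every loss is exported twice, a plain functional form (used directly, as `SSDHead` now does with `smooth_l1_loss`) and an `nn.Module` wrapper built from configs. A self-contained sketch of the pattern; `SmoothL1Module` is an illustrative stand-in, not the real `SmoothL1Loss`:

```python
import torch
import torch.nn as nn


def smooth_l1(pred, target, beta=1.0):
    # elementwise smooth L1, no reduction
    diff = torch.abs(pred - target)
    return torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)


class SmoothL1Module(nn.Module):  # stand-in for the registered SmoothL1Loss
    def __init__(self, beta=1.0, loss_weight=1.0):
        super().__init__()
        self.beta, self.loss_weight = beta, loss_weight

    def forward(self, pred, target):
        return self.loss_weight * smooth_l1(pred, target, self.beta).mean()


pred, target = torch.zeros(3), torch.tensor([0.5, 1.0, 2.0])
assert torch.allclose(SmoothL1Module()(pred, target),
                      smooth_l1(pred, target).mean())
```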
diff --git a/mmdet/models/losses/accuracy.py b/mmdet/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..20d0ad8cd3dbdd4a3707818daf9ffbd6e1a8a748
--- /dev/null
+++ b/mmdet/models/losses/accuracy.py
@@ -0,0 +1,31 @@
+import torch.nn as nn
+
+
+def accuracy(pred, target, topk=1):
+    assert isinstance(topk, (int, tuple))
+    if isinstance(topk, int):
+        topk = (topk, )
+        return_single = True
+    else:
+        return_single = False
+
+    maxk = max(topk)
+    _, pred_label = pred.topk(maxk, dim=1)
+    pred_label = pred_label.t()
+    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        res.append(correct_k.mul_(100.0 / pred.size(0)))
+    return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+
+    def __init__(self, topk=(1, )):
+        super().__init__()
+        self.topk = topk
+
+    def forward(self, pred, target):
+        return accuracy(pred, target, self.topk)
diff --git a/mmdet/models/losses/balanced_l1_loss.py b/mmdet/models/losses/balanced_l1_loss.py
index 7511e2668d5b8c8f8f24dda7d69c5b6732b7ab19..dccb9e5b89ff86711729890cd6c58d31f0836d9b 100644
--- a/mmdet/models/losses/balanced_l1_loss.py
+++ b/mmdet/models/losses/balanced_l1_loss.py
@@ -1,9 +1,31 @@
+import numpy as np
+import torch
 import torch.nn as nn

-from mmdet.core import weighted_balanced_l1_loss
+from .utils import weighted_loss
 from ..registry import LOSSES


+@weighted_loss
+def balanced_l1_loss(pred,
+                     target,
+                     beta=1.0,
+                     alpha=0.5,
+                     gamma=1.5):
+    # NOTE: return the unreduced elementwise loss; the weighted_loss
+    # decorator applies `weight`, `reduction` and `avg_factor`.
+    assert beta > 0
+    assert pred.size() == target.size() and target.numel() > 0
+
+    diff = torch.abs(pred - target)
+    b = np.e**(gamma / alpha) - 1
+    loss = torch.where(
+        diff < beta, alpha / b *
+        (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
+        gamma * diff + gamma / b - alpha * beta)
+
+    return loss
+
+
 @LOSSES.register_module
 class BalancedL1Loss(nn.Module):
     """Balanced L1 Loss

     arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
     """

@@ -11,21 +33,28 @@ class BalancedL1Loss(nn.Module):
-    def __init__(self, alpha=0.5, gamma=1.5, beta=1.0, loss_weight=1.0):
+    def __init__(self,
+                 alpha=0.5,
+                 gamma=1.5,
+                 beta=1.0,
+                 reduction='mean',
+                 loss_weight=1.0):
         super(BalancedL1Loss, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
         self.beta = beta
+        self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight, *args, **kwargs):
-        loss_bbox = self.loss_weight * weighted_balanced_l1_loss(
+    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+        loss_bbox = self.loss_weight * balanced_l1_loss(
             pred,
             target,
             weight,
             alpha=self.alpha,
             gamma=self.gamma,
             beta=self.beta,
-            *args,
+            reduction=self.reduction,
+            avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
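A quick check of the `accuracy()` semantics introduced above (top-k percentage over the batch), written out inline so the indexing is explicit:

```python
import torch

pred = torch.tensor([[0.1, 0.9],   # predicts class 1
                     [0.8, 0.2],   # predicts class 0
                     [0.3, 0.7]])  # predicts class 1
target = torch.tensor([1, 1, 1])

# top-1: 2 of 3 predictions are correct -> 66.67%
maxk = 2
_, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t()  # shape (maxk, batch)
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
top1 = correct[:1].reshape(-1).float().sum() * 100.0 / pred.size(0)
print(top1)  # tensor(66.6667)
```

Note on `balanced_l1_loss`: the decorated function returns the unreduced elementwise loss and leaves `weight`/`reduction`/`avg_factor` to the `weighted_loss` wrapper; reducing inside the function would happen before the weights are applied, since the wrapper consumes the `reduction` keyword itself.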
diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index 19539873c748536267b52d4c49156b299a9aad3e..1921978defd939248d3faa3fabc2b65f37316fa3 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -1,28 +1,85 @@
+import torch
 import torch.nn as nn
+import torch.nn.functional as F

-from mmdet.core import (weighted_cross_entropy, weighted_binary_cross_entropy,
-                        mask_cross_entropy)
+from .utils import weight_reduce_loss, weighted_loss
 from ..registry import LOSSES


+cross_entropy = weighted_loss(F.cross_entropy)
+
+
+def _expand_binary_labels(labels, label_weights, label_channels):
+    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+    inds = torch.nonzero(labels >= 1).squeeze()
+    if inds.numel() > 0:
+        bin_labels[inds, labels[inds] - 1] = 1
+    if label_weights is None:
+        bin_label_weights = None
+    else:
+        bin_label_weights = label_weights.view(-1, 1).expand(
+            label_weights.size(0), label_channels)
+    return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+                         label,
+                         weight=None,
+                         reduction='mean',
+                         avg_factor=None):
+    if pred.dim() != label.dim():
+        label, weight = _expand_binary_labels(label, weight, pred.size(-1))
+
+    # weighted element-wise losses
+    if weight is not None:
+        weight = weight.float()
+    loss = F.binary_cross_entropy_with_logits(
+        pred, label.float(), weight, reduction='none')
+    # do the reduction (the weight was already applied inside BCE above)
+    loss = weight_reduce_loss(
+        loss, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def mask_cross_entropy(pred, target, label, reduction='mean',
+                       avg_factor=None):
+    # TODO: handle these two reserved arguments
+    assert reduction == 'mean' and avg_factor is None
+    num_rois = pred.size()[0]
+    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+    pred_slice = pred[inds, label].squeeze(1)
+    return F.binary_cross_entropy_with_logits(
+        pred_slice, target, reduction='mean')[None]
+
+
 @LOSSES.register_module
 class CrossEntropyLoss(nn.Module):

-    def __init__(self, use_sigmoid=False, use_mask=False, loss_weight=1.0):
+    def __init__(self,
+                 use_sigmoid=False,
+                 use_mask=False,
+                 reduction='mean',
+                 loss_weight=1.0):
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
         self.use_sigmoid = use_sigmoid
         self.use_mask = use_mask
+        self.reduction = reduction
         self.loss_weight = loss_weight

         if self.use_sigmoid:
-            self.cls_criterion = weighted_binary_cross_entropy
+            self.cls_criterion = binary_cross_entropy
         elif self.use_mask:
             self.cls_criterion = mask_cross_entropy
         else:
-            self.cls_criterion = weighted_cross_entropy
+            self.cls_criterion = cross_entropy

-    def forward(self, cls_score, label, label_weight, *args, **kwargs):
+    def forward(self, cls_score, label, weight=None, avg_factor=None,
+                **kwargs):
         loss_cls = self.loss_weight * self.cls_criterion(
-            cls_score, label, label_weight, *args, **kwargs)
+            cls_score,
+            label,
+            weight,
+            reduction=self.reduction,
+            avg_factor=avg_factor,
+            **kwargs)
         return loss_cls
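`_expand_binary_labels` above turns 1-based category indices (0 = background) into one-hot targets for the sigmoid path. A minimal reproduction of just that transform:

```python
import torch


def expand_binary_labels(labels, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        # class k (1-based) maps to column k - 1
        bin_labels[inds, labels[inds] - 1] = 1
    return bin_labels


labels = torch.tensor([0, 2, 1])  # background, class 2, class 1
print(expand_binary_labels(labels, 3))
# tensor([[0, 0, 0],
#         [0, 1, 0],
#         [1, 0, 0]])
```

One detail worth flagging in `binary_cross_entropy`: the per-element `weight` is applied once, inside `F.binary_cross_entropy_with_logits`, so `weight_reduce_loss` is called with only `reduction` and `avg_factor` — passing the weight a second time would square its effect.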
diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py
index 12b17df082d81acae2ccf9cd6c8211b762551069..b8ccfa07d8b70744a51b3666ef1a84a7eb4ba647 100644
--- a/mmdet/models/losses/focal_loss.py
+++ b/mmdet/models/losses/focal_loss.py
@@ -1,35 +1,74 @@
 import torch.nn as nn
+import torch.nn.functional as F

-from mmdet.core import weighted_sigmoid_focal_loss
+from mmdet.ops import sigmoid_focal_loss as _sigmoid_focal_loss
+from .utils import weight_reduce_loss
 from ..registry import LOSSES


+# This method is only for debugging
+def py_sigmoid_focal_loss(pred,
+                          target,
+                          weight=None,
+                          gamma=2.0,
+                          alpha=0.25,
+                          reduction='mean',
+                          avg_factor=None):
+    pred_sigmoid = pred.sigmoid()
+    target = target.type_as(pred)
+    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+    focal_weight = (alpha * target + (1 - alpha) *
+                    (1 - target)) * pt.pow(gamma)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * focal_weight
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
+def sigmoid_focal_loss(pred,
+                       target,
+                       weight=None,
+                       gamma=2.0,
+                       alpha=0.25,
+                       reduction='mean',
+                       avg_factor=None):
+    # Function.apply does not accept keyword arguments, so the decorator
+    # "weighted_loss" is not applicable
+    loss = _sigmoid_focal_loss(pred, target, gamma, alpha)
+    # TODO: find a proper way to handle the shape of weight
+    if weight is not None:
+        weight = weight.view(-1, 1)
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
 @LOSSES.register_module
 class FocalLoss(nn.Module):

     def __init__(self,
-                 use_sigmoid=False,
-                 loss_weight=1.0,
+                 use_sigmoid=True,
                  gamma=2.0,
-                 alpha=0.25):
+                 alpha=0.25,
+                 reduction='mean',
+                 loss_weight=1.0):
         super(FocalLoss, self).__init__()
         assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
         self.use_sigmoid = use_sigmoid
-        self.loss_weight = loss_weight
         self.gamma = gamma
         self.alpha = alpha
-        self.cls_criterion = weighted_sigmoid_focal_loss
+        self.reduction = reduction
+        self.loss_weight = loss_weight

-    def forward(self, cls_score, label, label_weight, *args, **kwargs):
+    def forward(self, pred, target, weight=None, avg_factor=None):
         if self.use_sigmoid:
-            loss_cls = self.loss_weight * self.cls_criterion(
-                cls_score,
-                label,
-                label_weight,
+            loss_cls = self.loss_weight * sigmoid_focal_loss(
+                pred,
+                target,
+                weight,
                 gamma=self.gamma,
                 alpha=self.alpha,
-                *args,
-                **kwargs)
+                reduction=self.reduction,
+                avg_factor=avg_factor)
         else:
             raise NotImplementedError
         return loss_cls
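A sanity check for the pure-PyTorch `py_sigmoid_focal_loss` kept above for debugging: with `gamma=0` and `alpha=0.5` the focal weight collapses to `alpha * target + (1 - alpha) * (1 - target) = 0.5`, so the loss must equal `0.5 *` plain BCE-with-logits:

```python
import torch
import torch.nn.functional as F


def py_sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25):
    # unreduced focal loss, matching the implementation above
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    return F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight


pred = torch.randn(4, 3)
target = (torch.rand(4, 3) > 0.5).float()
focal = py_sigmoid_focal_loss(pred, target, gamma=0.0, alpha=0.5)
bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
assert torch.allclose(focal, 0.5 * bce)
```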
""" - def __init__( - self, - mu=0.02, - bins=10, - momentum=0, - loss_weight=1.0): + + def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0): super(GHMR, self).__init__() self.mu = mu self.bins = bins @@ -154,7 +146,7 @@ class GHMR(nn.Module): tot = max(label_weight.float().sum().item(), 1.0) n = 0 # n: valid bins for i in range(self.bins): - inds = (g >= edges[i]) & (g < edges[i+1]) & valid + inds = (g >= edges[i]) & (g < edges[i + 1]) & valid num_in_bin = inds.sum().item() if num_in_bin > 0: n += 1 diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py index edff11b0eb1b0f71d7a49278b209fe36c0bae3f9..7c235cd82ad0e2471107293482714026544c9727 100644 --- a/mmdet/models/losses/iou_loss.py +++ b/mmdet/models/losses/iou_loss.py @@ -1,27 +1,117 @@ +import torch import torch.nn as nn -from mmdet.core import weighted_iou_loss +from mmdet.core import bbox_overlaps +from .utils import weighted_loss from ..registry import LOSSES +@weighted_loss +def iou_loss(pred, target, eps=1e-6): + """IoU loss. + + Computing the IoU loss between a set of predicted bboxes and target bboxes. + The loss is calculated as negative log of IoU. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + eps (float): Eps to avoid log(0). + + Return: + Tensor: Loss tensor. + """ + ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps) + loss = -ious.log() + return loss + + +@weighted_loss +def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3): + """Improving Object Localization with Fitness NMS and Bounded IoU Loss, + https://arxiv.org/abs/1711.00164. + + Args: + pred (tensor): Predicted bboxes. + target (tensor): Target bboxes. + beta (float): beta parameter in smoothl1. + eps (float): eps to avoid NaN. 
+ """ + pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5 + pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5 + pred_w = pred[:, 2] - pred[:, 0] + 1 + pred_h = pred[:, 3] - pred[:, 1] + 1 + with torch.no_grad(): + target_ctrx = (target[:, 0] + target[:, 2]) * 0.5 + target_ctry = (target[:, 1] + target[:, 3]) * 0.5 + target_w = target[:, 2] - target[:, 0] + 1 + target_h = target[:, 3] - target[:, 1] + 1 + + dx = target_ctrx - pred_ctrx + dy = target_ctry - pred_ctry + + loss_dx = 1 - torch.max( + (target_w - 2 * dx.abs()) / + (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx)) + loss_dy = 1 - torch.max( + (target_h - 2 * dy.abs()) / + (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy)) + loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w / + (target_w + eps)) + loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h / + (target_h + eps)) + loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh], + dim=-1).view(loss_dx.size(0), -1) + + loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta, + loss_comb - 0.5 * beta) + return loss + + @LOSSES.register_module class IoULoss(nn.Module): - def __init__(self, style='naive', beta=0.2, eps=1e-3, loss_weight=1.0): + def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0): super(IoULoss, self).__init__() - self.style = style + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): + if weight is not None and not torch.any(weight > 0): + return (pred * weight).sum() # 0 + loss = self.loss_weight * iou_loss( + pred, + target, + weight, + eps=self.eps, + reduction=self.reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@LOSSES.register_module +class BoundedIoULoss(nn.Module): + + def __init__(self, beta=0.2, eps=1e-3, reduction='mean', loss_weight=1.0): + super(BoundedIoULoss, self).__init__() self.beta = beta self.eps = eps + self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight, *args, **kwargs): - loss = self.loss_weight * weighted_iou_loss( + def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): + if weight is not None and not torch.any(weight > 0): + return (pred * weight).sum() # 0 + loss = self.loss_weight * bounded_iou_loss( pred, target, weight, - style=self.style, beta=self.beta, eps=self.eps, - *args, + reduction=self.reduction, + avg_factor=avg_factor, **kwargs) return loss diff --git a/mmdet/models/losses/smooth_l1_loss.py b/mmdet/models/losses/smooth_l1_loss.py index 7c11aa50f1e72d448a5838eb3d65a892dda56f67..6a098fcd72cf3b39d941b1983795c82d0f9c4fe4 100644 --- a/mmdet/models/losses/smooth_l1_loss.py +++ b/mmdet/models/losses/smooth_l1_loss.py @@ -1,18 +1,36 @@ +import torch import torch.nn as nn -from mmdet.core import weighted_smoothl1 +from .utils import weighted_loss from ..registry import LOSSES +@weighted_loss +def smooth_l1_loss(pred, target, beta=1.0): + assert beta > 0 + assert pred.size() == target.size() and target.numel() > 0 + diff = torch.abs(pred - target) + loss = torch.where(diff < beta, 0.5 * diff * diff / beta, + diff - 0.5 * beta) + return loss + + @LOSSES.register_module class SmoothL1Loss(nn.Module): - def __init__(self, beta=1.0, loss_weight=1.0): + def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0): super(SmoothL1Loss, self).__init__() self.beta = beta + self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight, *args, **kwargs): - loss_bbox = 
diff --git a/mmdet/models/losses/smooth_l1_loss.py b/mmdet/models/losses/smooth_l1_loss.py
index 7c11aa50f1e72d448a5838eb3d65a892dda56f67..6a098fcd72cf3b39d941b1983795c82d0f9c4fe4 100644
--- a/mmdet/models/losses/smooth_l1_loss.py
+++ b/mmdet/models/losses/smooth_l1_loss.py
@@ -1,18 +1,36 @@
+import torch
 import torch.nn as nn

-from mmdet.core import weighted_smoothl1
+from .utils import weighted_loss
 from ..registry import LOSSES


+@weighted_loss
+def smooth_l1_loss(pred, target, beta=1.0):
+    assert beta > 0
+    assert pred.size() == target.size() and target.numel() > 0
+    diff = torch.abs(pred - target)
+    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
+                       diff - 0.5 * beta)
+    return loss
+
+
 @LOSSES.register_module
 class SmoothL1Loss(nn.Module):

-    def __init__(self, beta=1.0, loss_weight=1.0):
+    def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
         super(SmoothL1Loss, self).__init__()
         self.beta = beta
+        self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight, *args, **kwargs):
-        loss_bbox = self.loss_weight * weighted_smoothl1(
-            pred, target, weight, beta=self.beta, *args, **kwargs)
+    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+        loss_bbox = self.loss_weight * smooth_l1_loss(
+            pred,
+            target,
+            weight,
+            beta=self.beta,
+            reduction=self.reduction,
+            avg_factor=avg_factor,
+            **kwargs)
        return loss_bbox
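A small numeric note on the `beta` parameter of smooth L1: the quadratic and linear branches meet at `diff == beta` (`0.5 * beta**2 / beta == beta - 0.5 * beta`), so the loss is continuous there. Checked with `beta = 0.11`, the `smoothl1_beta` value removed from the FCOS configs at the top of this patch:

```python
import torch


def smooth_l1(diff, beta=1.0):
    return torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)


beta = 0.11
d = torch.tensor([beta - 1e-4, beta, beta + 1e-4])
print(smooth_l1(d, beta))  # all three values ~= 0.5 * beta = 0.055
```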
diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b902c64b8d8a258d3dd13f364515c640b2d2e838
--- /dev/null
+++ b/mmdet/models/losses/utils.py
@@ -0,0 +1,96 @@
+import functools
+
+import torch.nn.functional as F
+
+
+def reduce_loss(loss, reduction):
+    """Reduce loss as specified.
+
+    Args:
+        loss (Tensor): Elementwise loss tensor.
+        reduction (str): Options are "none", "mean" and "sum".
+
+    Return:
+        Tensor: Reduced loss tensor.
+    """
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, elementwise_mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    elif reduction_enum == 2:
+        return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+    """Apply element-wise weight and reduce loss.
+
+    Args:
+        loss (Tensor): Element-wise loss.
+        weight (Tensor): Element-wise weights.
+        reduction (str): Same as built-in losses of PyTorch.
+        avg_factor (float): Average factor when computing the mean of losses.
+
+    Returns:
+        Tensor: Processed loss values.
+    """
+    # if weight is specified, apply element-wise weight
+    if weight is not None:
+        loss = loss * weight
+
+    # if avg_factor is not specified, just reduce the loss
+    if avg_factor is None:
+        loss = reduce_loss(loss, reduction)
+    # otherwise average the loss by avg_factor
+    else:
+        if reduction != 'mean':
+            raise ValueError(
+                'avg_factor can only be used with reduction="mean"')
+        loss = loss.sum() / avg_factor
+    return loss
+
+
+def weighted_loss(loss_func):
+    """Create a weighted version of a given loss function.
+
+    To use this decorator, the loss function must have the signature like
+    `loss_func(pred, target, **kwargs)`. The function only needs to compute
+    element-wise loss without any reduction. This decorator will add weight
+    and reduction arguments to the function. The decorated function will have
+    the signature like `loss_func(pred, target, weight=None, reduction='mean',
+    avg_factor=None, **kwargs)`.
+
+    :Example:
+
+    >>> @weighted_loss
+    >>> def l1_loss(pred, target):
+    >>>     return (pred - target).abs()
+
+    >>> pred = torch.Tensor([0, 2, 3])
+    >>> target = torch.Tensor([1, 1, 1])
+    >>> weight = torch.Tensor([1, 0, 1])
+
+    >>> l1_loss(pred, target)
+    tensor(1.3333)
+    >>> l1_loss(pred, target, weight)
+    tensor(1.)
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, avg_factor=2)
+    tensor(1.5000)
+    """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred,
+                target,
+                weight=None,
+                reduction='mean',
+                avg_factor=None,
+                **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+        return loss
+
+    return wrapper
diff --git a/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
index 803df4153a4c5d06d7054bd8fcd463f8e1febcd0..e690f76305ba6a8a6aeb946664e7ddc7f72fcf6a 100644
--- a/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
+++ b/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
@@ -1,4 +1,3 @@
-import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable

@@ -8,7 +7,7 @@ from .. import sigmoid_focal_loss_cuda
 class SigmoidFocalLossFunction(Function):

     @staticmethod
-    def forward(ctx, input, target, gamma=2.0, alpha=0.25, reduction='mean'):
+    def forward(ctx, input, target, gamma=2.0, alpha=0.25):
         ctx.save_for_backward(input, target)
         num_classes = input.shape[1]
         ctx.num_classes = num_classes
@@ -17,14 +16,7 @@ class SigmoidFocalLossFunction(Function):

         loss = sigmoid_focal_loss_cuda.forward(input, target, num_classes,
                                                gamma, alpha)
-        reduction_enum = F._Reduction.get_enum(reduction)
-        # none: 0, mean:1, sum: 2
-        if reduction_enum == 0:
-            return loss
-        elif reduction_enum == 1:
-            return loss.mean()
-        elif reduction_enum == 2:
-            return loss.sum()
+        return loss

     @staticmethod
     @once_differentiable
diff --git a/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py
index 3caff39910623181409468af1ca61c2006cf876e..34202b566437a4d4c6fee5b0cf70f630cac29b3f 100644
--- a/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py
+++ b/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py
@@ -3,6 +3,7 @@ from torch import nn
 from ..functions.sigmoid_focal_loss import sigmoid_focal_loss


+# TODO: remove this module
 class SigmoidFocalLoss(nn.Module):

     def __init__(self, gamma, alpha):
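To close, a runnable standalone version of the `weighted_loss` doctest above, with the reduction logic inlined in simplified form (no `reduction`/`avg_factor` compatibility check), showing how the module wrappers such as `SmoothL1Loss` and `IoULoss` now drive their functional counterparts:

```python
import functools

import torch


def weighted_loss(loss_func):
    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean',
                avg_factor=None, **kwargs):
        loss = loss_func(pred, target, **kwargs)  # unreduced elementwise loss
        if weight is not None:
            loss = loss * weight
        if avg_factor is not None:
            return loss.sum() / avg_factor
        return {'none': loss, 'mean': loss.mean(), 'sum': loss.sum()}[reduction]

    return wrapper


@weighted_loss
def l1_loss(pred, target):
    return (pred - target).abs()


pred = torch.Tensor([0, 2, 3])
target = torch.Tensor([1, 1, 1])
weight = torch.Tensor([1, 0, 1])
print(l1_loss(pred, target))                        # tensor(1.3333)
print(l1_loss(pred, target, weight))                # tensor(1.)
print(l1_loss(pred, target, weight, avg_factor=2))  # tensor(1.5000)
```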