From d67a2e16e4b6c91a3e15f5879e13c6b2e6f02423 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Sun, 14 Apr 2019 13:40:59 +0800
Subject: [PATCH] add cpp sigmoid focal loss (#490)

* add cpp sigmoid focal loss

* modify interface

* format cpp code, fix pep8 errors

* format cpp code as google style
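
A minimal usage sketch of the new op (assuming the extension has been built
via compile.sh; the tensor shapes below are illustrative only):

    import torch
    from mmdet.ops import SigmoidFocalLoss, sigmoid_focal_loss

    logits = torch.randn(8, 80, device='cuda', requires_grad=True)
    # targets are long class indices: 0 is background, 1..80 are classes
    targets = torch.randint(0, 81, (8,), device='cuda')

    # module form: returns the loss summed over all samples and classes
    loss = SigmoidFocalLoss(gamma=2.0, alpha=0.25)(logits, targets)
    loss.backward()

    # functional form with an explicit reduction
    per_elem = sigmoid_focal_loss(logits, targets, 2.0, 0.25, 'none')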
---
 compile.sh                                    |   7 +
 mmdet/core/anchor/anchor_target.py            |   3 -
 mmdet/core/loss/__init__.py                   |  12 +-
 mmdet/core/loss/losses.py                     |  20 ++-
 mmdet/models/anchor_heads/anchor_head.py      |   8 +-
 mmdet/ops/__init__.py                         |   3 +-
 mmdet/ops/sigmoid_focal_loss/__init__.py      |   3 +
 .../sigmoid_focal_loss/functions/__init__.py  |   0
 .../functions/sigmoid_focal_loss.py           |  42 +++++
 .../sigmoid_focal_loss/modules/__init__.py    |   0
 .../modules/sigmoid_focal_loss.py             |  23 +++
 mmdet/ops/sigmoid_focal_loss/setup.py         |  12 ++
 .../src/sigmoid_focal_loss.cpp                |  43 +++++
 .../src/sigmoid_focal_loss_cuda.cu            | 169 ++++++++++++++++++
 14 files changed, 320 insertions(+), 25 deletions(-)
 create mode 100644 mmdet/ops/sigmoid_focal_loss/__init__.py
 create mode 100644 mmdet/ops/sigmoid_focal_loss/functions/__init__.py
 create mode 100644 mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
 create mode 100644 mmdet/ops/sigmoid_focal_loss/modules/__init__.py
 create mode 100644 mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py
 create mode 100644 mmdet/ops/sigmoid_focal_loss/setup.py
 create mode 100644 mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp
 create mode 100644 mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu

diff --git a/compile.sh b/compile.sh
index 9ae7d043..335cf51d 100755
--- a/compile.sh
+++ b/compile.sh
@@ -29,3 +29,10 @@ if [ -d "build" ]; then
     rm -r build
 fi
 $PYTHON setup.py build_ext --inplace
+
+echo "Building sigmoid focal loss op..."
+cd ../sigmoid_focal_loss
+if [ -d "build" ]; then
+    rm -r build
+fi
+$PYTHON setup.py build_ext --inplace
diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py
index 7a5bf4ec..26489771 100644
--- a/mmdet/core/anchor/anchor_target.py
+++ b/mmdet/core/anchor/anchor_target.py
@@ -152,9 +152,6 @@ def anchor_target_single(flat_anchors,
         num_total_anchors = flat_anchors.size(0)
         labels = unmap(labels, num_total_anchors, inside_flags)
         label_weights = unmap(label_weights, num_total_anchors, inside_flags)
-        if label_channels > 1:
-            labels, label_weights = expand_binary_labels(
-                labels, label_weights, label_channels)
         bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
         bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
 
diff --git a/mmdet/core/loss/__init__.py b/mmdet/core/loss/__init__.py
index 661f0d64..477906e5 100644
--- a/mmdet/core/loss/__init__.py
+++ b/mmdet/core/loss/__init__.py
@@ -1,11 +1,11 @@
-from .losses import (weighted_nll_loss, weighted_cross_entropy,
-                     weighted_binary_cross_entropy, sigmoid_focal_loss,
-                     weighted_sigmoid_focal_loss, mask_cross_entropy,
-                     smooth_l1_loss, weighted_smoothl1, accuracy)
+from .losses import (
+    weighted_nll_loss, weighted_cross_entropy, weighted_binary_cross_entropy,
+    sigmoid_focal_loss, py_sigmoid_focal_loss, weighted_sigmoid_focal_loss,
+    mask_cross_entropy, smooth_l1_loss, weighted_smoothl1, accuracy)
 
 __all__ = [
     'weighted_nll_loss', 'weighted_cross_entropy',
     'weighted_binary_cross_entropy', 'sigmoid_focal_loss',
-    'weighted_sigmoid_focal_loss', 'mask_cross_entropy', 'smooth_l1_loss',
-    'weighted_smoothl1', 'accuracy'
+    'py_sigmoid_focal_loss', 'weighted_sigmoid_focal_loss',
+    'mask_cross_entropy', 'smooth_l1_loss', 'weighted_smoothl1', 'accuracy'
 ]
diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py
index 560dac43..922e058a 100644
--- a/mmdet/core/loss/losses.py
+++ b/mmdet/core/loss/losses.py
@@ -2,6 +2,8 @@
 import torch
 import torch.nn.functional as F
 
+from ...ops import sigmoid_focal_loss
+
 
 def weighted_nll_loss(pred, label, weight, avg_factor=None):
     if avg_factor is None:
@@ -28,12 +30,12 @@ def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
         reduction='sum')[None] / avg_factor
 
 
-def sigmoid_focal_loss(pred,
-                       target,
-                       weight,
-                       gamma=2.0,
-                       alpha=0.25,
-                       reduction='mean'):
+def py_sigmoid_focal_loss(pred,
+                          target,
+                          weight,
+                          gamma=2.0,
+                          alpha=0.25,
+                          reduction='mean'):
     pred_sigmoid = pred.sigmoid()
     target = target.type_as(pred)
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
@@ -60,9 +62,9 @@ def weighted_sigmoid_focal_loss(pred,
                                 num_classes=80):
     if avg_factor is None:
         avg_factor = torch.sum(weight > 0).float().item() / num_classes + 1e-6
-    return sigmoid_focal_loss(
-        pred, target, weight, gamma=gamma, alpha=alpha,
-        reduction='sum')[None] / avg_factor
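+    # the CUDA op consumes integer class labels directly (0 = background),
+    # so the earlier per-class binary label expansion is no longer needed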
+    return torch.sum(
+        sigmoid_focal_loss(pred, target, gamma, alpha, 'none') * weight.view(
+            -1, 1))[None] / avg_factor
 
 
 def mask_cross_entropy(pred, target, label):
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index d3ab2d21..d059c65a 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -128,12 +128,8 @@ class AnchorHead(nn.Module):
     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                     bbox_targets, bbox_weights, num_total_samples, cfg):
         # classification loss
-        if self.use_sigmoid_cls:
-            labels = labels.reshape(-1, self.cls_out_channels)
-            label_weights = label_weights.reshape(-1, self.cls_out_channels)
-        else:
-            labels = labels.reshape(-1)
-            label_weights = label_weights.reshape(-1)
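+        # keep labels as flat integer class indices; the new CUDA focal
+        # loss consumes them directly instead of expanded binary targets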
+        labels = labels.reshape(-1)
+        label_weights = label_weights.reshape(-1)
         cls_score = cls_score.permute(0, 2, 3, 1).reshape(
             -1, self.cls_out_channels)
         if self.use_sigmoid_cls:
diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py
index 9576335a..b3cbc266 100644
--- a/mmdet/ops/__init__.py
+++ b/mmdet/ops/__init__.py
@@ -5,11 +5,12 @@ from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
 from .nms import nms, soft_nms
 from .roi_align import RoIAlign, roi_align
 from .roi_pool import RoIPool, roi_pool
+from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
 
 __all__ = [
     'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
     'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
     'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
     'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
-    'deform_roi_pooling'
+    'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss'
 ]
diff --git a/mmdet/ops/sigmoid_focal_loss/__init__.py b/mmdet/ops/sigmoid_focal_loss/__init__.py
new file mode 100644
index 00000000..d0e5abd9
--- /dev/null
+++ b/mmdet/ops/sigmoid_focal_loss/__init__.py
@@ -0,0 +1,3 @@
+from .modules.sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
+
+__all__ = ['SigmoidFocalLoss', 'sigmoid_focal_loss']
diff --git a/mmdet/ops/sigmoid_focal_loss/functions/__init__.py b/mmdet/ops/sigmoid_focal_loss/functions/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
new file mode 100644
index 00000000..803df415
--- /dev/null
+++ b/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py
@@ -0,0 +1,42 @@
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from .. import sigmoid_focal_loss_cuda
+
+
+class SigmoidFocalLossFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input, target, gamma=2.0, alpha=0.25, reduction='mean'):
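+        # input: (N, num_classes) logits; target: (N,) long tensor of class
+        # indices, where 0 is background and [1, num_classes] are foreground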
+        ctx.save_for_backward(input, target)
+        num_classes = input.shape[1]
+        ctx.num_classes = num_classes
+        ctx.gamma = gamma
+        ctx.alpha = alpha
+
+        loss = sigmoid_focal_loss_cuda.forward(input, target, num_classes,
+                                               gamma, alpha)
+        reduction_enum = F._Reduction.get_enum(reduction)
+        # none: 0, mean: 1, sum: 2
+        if reduction_enum == 0:
+            return loss
+        elif reduction_enum == 1:
+            return loss.mean()
+        elif reduction_enum == 2:
+            return loss.sum()
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, d_loss):
+        input, target = ctx.saved_tensors
+        num_classes = ctx.num_classes
+        gamma = ctx.gamma
+        alpha = ctx.alpha
+        d_loss = d_loss.contiguous()
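+        # the CUDA backward returns a gradient only for the logits; target,
+        # gamma, alpha and reduction are treated as non-differentiable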
+        d_input = sigmoid_focal_loss_cuda.backward(input, target, d_loss,
+                                                   num_classes, gamma, alpha)
+        return d_input, None, None, None, None
+
+
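+# functional alias; used by mmdet.core.loss.losses and the module wrapper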
+sigmoid_focal_loss = SigmoidFocalLossFunction.apply
diff --git a/mmdet/ops/sigmoid_focal_loss/modules/__init__.py b/mmdet/ops/sigmoid_focal_loss/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py
new file mode 100644
index 00000000..3caff399
--- /dev/null
+++ b/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py
@@ -0,0 +1,23 @@
+from torch import nn
+
+from ..functions.sigmoid_focal_loss import sigmoid_focal_loss
+
+
+class SigmoidFocalLoss(nn.Module):
+
+    def __init__(self, gamma, alpha):
+        super(SigmoidFocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+
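+    # note: the module applies a fixed sum reduction over all samples and
+    # classes; normalize outside the module if a mean is needed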
+    def forward(self, logits, targets):
+        assert logits.is_cuda
+        loss = sigmoid_focal_loss(logits, targets, self.gamma, self.alpha)
+        return loss.sum()
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "gamma=" + str(self.gamma)
+        tmpstr += ", alpha=" + str(self.alpha)
+        tmpstr += ")"
+        return tmpstr
diff --git a/mmdet/ops/sigmoid_focal_loss/setup.py b/mmdet/ops/sigmoid_focal_loss/setup.py
new file mode 100644
index 00000000..a70c6545
--- /dev/null
+++ b/mmdet/ops/sigmoid_focal_loss/setup.py
@@ -0,0 +1,12 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
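+# built in place by compile.sh via: $PYTHON setup.py build_ext --inplace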
+setup(
+    name='SigmoidFocalLoss',
+    ext_modules=[
+        CUDAExtension('sigmoid_focal_loss_cuda', [
+            'src/sigmoid_focal_loss.cpp',
+            'src/sigmoid_focal_loss_cuda.cu',
+        ]),
+    ],
+    cmdclass={'build_ext': BuildExtension})
diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp
new file mode 100644
index 00000000..20427518
--- /dev/null
+++ b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp
@@ -0,0 +1,43 @@
+// modify from
+// https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h
+#include <torch/extension.h>
+
+at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
+                                         const at::Tensor &targets,
+                                         const int num_classes,
+                                         const float gamma, const float alpha);
+
+at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
+                                          const at::Tensor &targets,
+                                          const at::Tensor &d_losses,
+                                          const int num_classes,
+                                          const float gamma, const float alpha);
+
+// Interface for Python
+at::Tensor SigmoidFocalLoss_forward(const at::Tensor &logits,
+                                    const at::Tensor &targets,
+                                    const int num_classes, const float gamma,
+                                    const float alpha) {
+  if (logits.type().is_cuda()) {
+    return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma,
+                                         alpha);
+  }
+  AT_ERROR("SigmoidFocalLoss is not implemented on the CPU");
+}
+
+at::Tensor SigmoidFocalLoss_backward(const at::Tensor &logits,
+                                     const at::Tensor &targets,
+                                     const at::Tensor &d_losses,
+                                     const int num_classes, const float gamma,
+                                     const float alpha) {
+  if (logits.type().is_cuda()) {
+    return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses,
+                                          num_classes, gamma, alpha);
+  }
+  AT_ERROR("SigmoidFocalLoss is not implemented on the CPU");
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &SigmoidFocalLoss_forward,
+        "SigmoidFocalLoss forward (CUDA)");
+  m.def("backward", &SigmoidFocalLoss_backward,
+        "SigmoidFocalLoss backward (CUDA)");
+}
diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
new file mode 100644
index 00000000..aa1e4b9d
--- /dev/null
+++ b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
@@ -0,0 +1,169 @@
+// modify from
+// https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
+
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This file is modified from
+// https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu
+// Cheng-Yang Fu
+// cyfu@cs.unc.edu
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <cfloat>
+
+// TODO make it in a common file
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+template <typename scalar_t>
+__global__ void SigmoidFocalLossForward(const int nthreads,
+                                        const scalar_t *logits,
+                                        const long *targets,
+                                        const int num_classes,
+                                        const float gamma, const float alpha,
+                                        const int num, scalar_t *losses) {
+  CUDA_1D_KERNEL_LOOP(i, nthreads) {
+    int n = i / num_classes;
+    int d = i % num_classes;  // current class [0~79];
+    int t = targets[n];       // target class [1~80];
+
+    // Decide it is positive or negative case.
+    scalar_t c1 = (t == (d + 1));
+    scalar_t c2 = ((t >= 0) & (t != (d + 1)));
+
+    scalar_t zn = (1.0 - alpha);
+    scalar_t zp = (alpha);
+
+    // p = 1. / (1. + expf(-x)), i.e. p = sigmoid(x)
+    scalar_t p = 1. / (1. + expf(-logits[i]));
+
+    // (1-p)**gamma * log(p)
+    scalar_t term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
+
+    // p**gamma * log(1-p)
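+    // log(1 - p) is evaluated in the numerically stable form
+    // -max(x, 0) - log(1 + exp(-|x|)) to avoid overflow in expf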
+    scalar_t term2 =
+        powf(p, gamma) *
+        (-1. * logits[i] * (logits[i] >= 0) -
+         logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0))));
+
+    losses[i] = 0.0;
+    losses[i] += -c1 * term1 * zp;
+    losses[i] += -c2 * term2 * zn;
+
+  }  // CUDA_1D_KERNEL_LOOP
+}  // SigmoidFocalLossForward
+
+template <typename scalar_t>
+__global__ void SigmoidFocalLossBackward(
+    const int nthreads, const scalar_t *logits, const long *targets,
+    const scalar_t *d_losses, const int num_classes, const float gamma,
+    const float alpha, const int num, scalar_t *d_logits) {
+  CUDA_1D_KERNEL_LOOP(i, nthreads) {
+    int n = i / num_classes;
+    int d = i % num_classes;  // current class [0~79];
+    int t = targets[n];       // target class [1~80], 0 is background;
+
+    // Decide it is positive or negative case.
+    scalar_t c1 = (t == (d + 1));
+    scalar_t c2 = ((t >= 0) & (t != (d + 1)));
+
+    scalar_t zn = (1.0 - alpha);
+    scalar_t zp = (alpha);
+    // p = 1. / (1. + expf(-x)), i.e. p = sigmoid(x)
+    scalar_t p = 1. / (1. + expf(-logits[i]));
+
+    // (1-p)**g * (1 - p - g*p*log(p))
+    scalar_t term1 =
+        powf((1. - p), gamma) * (1. - p - (p * gamma * logf(max(p, FLT_MIN))));
+
+    // (p**g) * (g*(1-p)*log(1-p) - p)
+    scalar_t term2 =
+        powf(p, gamma) *
+        ((-1. * logits[i] * (logits[i] >= 0) -
+          logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) *
+             (1. - p) * gamma -
+         p);
+    d_logits[i] = 0.0;
+    d_logits[i] += -c1 * term1 * zp;
+    d_logits[i] += -c2 * term2 * zn;
+    d_logits[i] = d_logits[i] * d_losses[i];
+
+  }  // CUDA_1D_KERNEL_LOOP
+}  // SigmoidFocalLossBackward
+
+at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
+                                         const at::Tensor &targets,
+                                         const int num_classes,
+                                         const float gamma, const float alpha) {
+  AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
+  AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
+  AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
+
+  const int num_samples = logits.size(0);
+  AT_ASSERTM(logits.size(1) == num_classes,
+             "logits.size(1) should be num_classes");
+
+  auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
+  auto losses_size = num_samples * logits.size(1);
+
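+  // one thread per (sample, class) element; the grid is capped at 4096
+  // blocks of 512 threads and CUDA_1D_KERNEL_LOOP strides over the rest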
+  dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L));
+  dim3 block(512);
+
+  if (losses.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return losses;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      logits.type(), "SigmoidFocalLoss_forward", [&] {
+        SigmoidFocalLossForward<scalar_t><<<grid, block>>>(
+            losses_size, logits.contiguous().data<scalar_t>(),
+            targets.contiguous().data<long>(), num_classes, gamma, alpha,
+            num_samples, losses.data<scalar_t>());
+      });
+  THCudaCheck(cudaGetLastError());
+  return losses;
+}
+
+at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
+                                          const at::Tensor &targets,
+                                          const at::Tensor &d_losses,
+                                          const int num_classes,
+                                          const float gamma,
+                                          const float alpha) {
+  AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
+  AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
+  AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor");
+
+  AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
+
+  const int num_samples = logits.size(0);
+  AT_ASSERTM(logits.size(1) == num_classes,
+             "logits.size(1) should be num_classes");
+
+  auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
+  auto d_logits_size = num_samples * logits.size(1);
+
+  dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L));
+  dim3 block(512);
+
+  if (d_logits.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return d_logits;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      logits.type(), "SigmoidFocalLoss_backward", [&] {
+        SigmoidFocalLossBackward<scalar_t><<<grid, block>>>(
+            d_logits_size, logits.contiguous().data<scalar_t>(),
+            targets.contiguous().data<long>(),
+            d_losses.contiguous().data<scalar_t>(), num_classes, gamma, alpha,
+            num_samples, d_logits.data<scalar_t>());
+      });
+
+  THCudaCheck(cudaGetLastError());
+  return d_logits;
+}
-- 
GitLab