Commit d67a2e16 authored by Cao Yuhang, committed by Kai Chen

add cpp sigmoid focal loss (#490)

* add cpp sigmoid focal loss

* modify interface

* format cpp code, support, fix pep8 error

* format cpp code as google style
parent d5a6b5dc
Showing changed files with 320 additions and 25 deletions
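In short, this commit adds a compiled sigmoid_focal_loss_cuda extension, exposed through mmdet.ops as both a functional op and an nn.Module, while the previous pure-Python implementation is kept as py_sigmoid_focal_loss. A minimal usage sketch follows; the shapes, class count and values are illustrative, and a CUDA device plus the compiled extension are assumed:

    import torch
    from mmdet.ops import SigmoidFocalLoss, sigmoid_focal_loss

    pred = torch.randn(8, 80, device='cuda')            # logits, (num_samples, num_classes)
    labels = torch.randint(0, 81, (8,), device='cuda')  # long labels: 0 = background, 1..80 = class

    # functional form: per-element loss, reduced by the caller
    loss_map = sigmoid_focal_loss(pred, labels, 2.0, 0.25, 'none')  # shape (8, 80)
    loss = loss_map.sum()

    # module form: returns the summed loss directly
    criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
    loss = criterion(pred, labels)
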
@@ -29,3 +29,10 @@ if [ -d "build" ]; then
     rm -r build
 fi
 $PYTHON setup.py build_ext --inplace
+
+echo "Building sigmoid focal loss op..."
+cd ../sigmoid_focal_loss
+if [ -d "build" ]; then
+    rm -r build
+fi
+$PYTHON setup.py build_ext --inplace
@@ -152,9 +152,6 @@ def anchor_target_single(flat_anchors,
         num_total_anchors = flat_anchors.size(0)
         labels = unmap(labels, num_total_anchors, inside_flags)
         label_weights = unmap(label_weights, num_total_anchors, inside_flags)
-        if label_channels > 1:
-            labels, label_weights = expand_binary_labels(
-                labels, label_weights, label_channels)
         bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
         bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

-from .losses import (weighted_nll_loss, weighted_cross_entropy,
-                     weighted_binary_cross_entropy, sigmoid_focal_loss,
-                     weighted_sigmoid_focal_loss, mask_cross_entropy,
-                     smooth_l1_loss, weighted_smoothl1, accuracy)
+from .losses import (
+    weighted_nll_loss, weighted_cross_entropy, weighted_binary_cross_entropy,
+    sigmoid_focal_loss, py_sigmoid_focal_loss, weighted_sigmoid_focal_loss,
+    mask_cross_entropy, smooth_l1_loss, weighted_smoothl1, accuracy)
 
 __all__ = [
     'weighted_nll_loss', 'weighted_cross_entropy',
     'weighted_binary_cross_entropy', 'sigmoid_focal_loss',
-    'weighted_sigmoid_focal_loss', 'mask_cross_entropy', 'smooth_l1_loss',
-    'weighted_smoothl1', 'accuracy'
+    'py_sigmoid_focal_loss', 'weighted_sigmoid_focal_loss',
+    'mask_cross_entropy', 'smooth_l1_loss', 'weighted_smoothl1', 'accuracy'
 ]
@@ -2,6 +2,8 @@
 import torch
 import torch.nn.functional as F
 
+from ...ops import sigmoid_focal_loss
+
 
 def weighted_nll_loss(pred, label, weight, avg_factor=None):
     if avg_factor is None:
@@ -28,12 +30,12 @@ def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
         reduction='sum')[None] / avg_factor
 
 
-def sigmoid_focal_loss(pred,
-                       target,
-                       weight,
-                       gamma=2.0,
-                       alpha=0.25,
-                       reduction='mean'):
+def py_sigmoid_focal_loss(pred,
+                          target,
+                          weight,
+                          gamma=2.0,
+                          alpha=0.25,
+                          reduction='mean'):
     pred_sigmoid = pred.sigmoid()
     target = target.type_as(pred)
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
@@ -60,9 +62,9 @@ def weighted_sigmoid_focal_loss(pred,
                                 num_classes=80):
     if avg_factor is None:
         avg_factor = torch.sum(weight > 0).float().item() / num_classes + 1e-6
-    return sigmoid_focal_loss(
-        pred, target, weight, gamma=gamma, alpha=alpha,
-        reduction='sum')[None] / avg_factor
+    return torch.sum(
+        sigmoid_focal_loss(pred, target, gamma, alpha, 'none') * weight.view(
+            -1, 1))[None] / avg_factor
 
 
 def mask_cross_entropy(pred, target, label):
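Since the compiled op has no per-element weight argument, the weighting and reduction now happen in Python: with logits x[i, c], integer targets t[i] and per-sample weights w[i], the new weighted_sigmoid_focal_loss computes

    loss = sum_i sum_c FL(x[i, c], t[i]) * w[i] / avg_factor

and, when avg_factor is not supplied, falls back to avg_factor = (#{i : w[i] > 0}) / num_classes + 1e-6.
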
@@ -128,12 +128,8 @@ class AnchorHead(nn.Module):
     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                     bbox_targets, bbox_weights, num_total_samples, cfg):
         # classification loss
-        if self.use_sigmoid_cls:
-            labels = labels.reshape(-1, self.cls_out_channels)
-            label_weights = label_weights.reshape(-1, self.cls_out_channels)
-        else:
-            labels = labels.reshape(-1)
-            label_weights = label_weights.reshape(-1)
+        labels = labels.reshape(-1)
+        label_weights = label_weights.reshape(-1)
         cls_score = cls_score.permute(0, 2, 3, 1).reshape(
             -1, self.cls_out_channels)
         if self.use_sigmoid_cls:
@@ -5,11 +5,12 @@ from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
 from .nms import nms, soft_nms
 from .roi_align import RoIAlign, roi_align
 from .roi_pool import RoIPool, roi_pool
+from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
 
 __all__ = [
     'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
     'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
     'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
     'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
-    'deform_roi_pooling'
+    'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss'
 ]

from .modules.sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
__all__ = ['SigmoidFocalLoss', 'sigmoid_focal_loss']
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from .. import sigmoid_focal_loss_cuda
class SigmoidFocalLossFunction(Function):
@staticmethod
def forward(ctx, input, target, gamma=2.0, alpha=0.25, reduction='mean'):
ctx.save_for_backward(input, target)
num_classes = input.shape[1]
ctx.num_classes = num_classes
ctx.gamma = gamma
ctx.alpha = alpha
loss = sigmoid_focal_loss_cuda.forward(input, target, num_classes,
gamma, alpha)
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
@staticmethod
@once_differentiable
def backward(ctx, d_loss):
input, target = ctx.saved_tensors
num_classes = ctx.num_classes
gamma = ctx.gamma
alpha = ctx.alpha
d_loss = d_loss.contiguous()
d_input = sigmoid_focal_loss_cuda.backward(input, target, d_loss,
num_classes, gamma, alpha)
return d_input, None, None, None, None
sigmoid_focal_loss = SigmoidFocalLossFunction.apply
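
One caveat with the functional form: torch.autograd.Function.apply forwards positional arguments only, so gamma, alpha and reduction cannot be passed as keywords. This is why weighted_sigmoid_focal_loss above calls the op positionally, e.g.

    loss = sigmoid_focal_loss(pred, target, 2.0, 0.25, 'none')  # gamma, alpha, reduction
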
from torch import nn
from ..functions.sigmoid_focal_loss import sigmoid_focal_loss
class SigmoidFocalLoss(nn.Module):
def __init__(self, gamma, alpha):
super(SigmoidFocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, logits, targets):
assert logits.is_cuda
loss = sigmoid_focal_loss(logits, targets, self.gamma, self.alpha)
return loss.sum()
def __repr__(self):
tmpstr = self.__class__.__name__ + "("
tmpstr += "gamma=" + str(self.gamma)
tmpstr += ", alpha=" + str(self.alpha)
tmpstr += ")"
return tmpstr
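
SigmoidFocalLoss.forward returns an unnormalized sum over all elements, so normalization is left to the caller; a hedged sketch of a typical pattern (cls_logits and labels are illustrative names, not taken from this commit):

    criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
    num_pos = (labels > 0).sum().clamp(min=1).float()   # count of foreground samples
    loss_cls = criterion(cls_logits, labels) / num_pos  # normalize by positives
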
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='SigmoidFocalLoss',
ext_modules=[
CUDAExtension('sigmoid_focal_loss_cuda', [
'src/sigmoid_focal_loss.cpp',
'src/sigmoid_focal_loss_cuda.cu',
]),
],
cmdclass={'build_ext': BuildExtension})
// modify from
// https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h
#include <torch/extension.h>
at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
const at::Tensor &targets,
const int num_classes,
const float gamma, const float alpha);
at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
const at::Tensor &targets,
const at::Tensor &d_losses,
const int num_classes,
const float gamma, const float alpha);
// Interface for Python
at::Tensor SigmoidFocalLoss_forward(const at::Tensor &logits,
const at::Tensor &targets,
const int num_classes, const float gamma,
const float alpha) {
  if (logits.type().is_cuda()) {
    return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma,
                                         alpha);
  }
  // Only the CUDA path is implemented; fail loudly instead of falling off the
  // end of a non-void function.
  AT_ERROR("SigmoidFocalLoss is not implemented on the CPU");
}
at::Tensor SigmoidFocalLoss_backward(const at::Tensor &logits,
const at::Tensor &targets,
const at::Tensor &d_losses,
const int num_classes, const float gamma,
const float alpha) {
  if (logits.type().is_cuda()) {
    return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses,
                                          num_classes, gamma, alpha);
  }
  // Only the CUDA path is implemented; fail loudly instead of falling off the
  // end of a non-void function.
  AT_ERROR("SigmoidFocalLoss is not implemented on the CPU");
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &SigmoidFocalLoss_forward,
"SigmoidFocalLoss forward (CUDA)");
m.def("backward", &SigmoidFocalLoss_backward,
"SigmoidFocalLoss backward (CUDA)");
}
// modify from
// https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
// This file is modified from
// https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu
// Cheng-Yang Fu
// cyfu@cs.unc.edu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <cfloat>
// TODO make it in a common file
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
template <typename scalar_t>
__global__ void SigmoidFocalLossForward(const int nthreads,
const scalar_t *logits,
const long *targets,
const int num_classes,
const float gamma, const float alpha,
const int num, scalar_t *losses) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int n = i / num_classes;
int d = i % num_classes; // current class[0~79];
int t = targets[n]; // target class [1~80];
// Decide it is positive or negative case.
scalar_t c1 = (t == (d + 1));
scalar_t c2 = (t >= 0 & t != (d + 1));
scalar_t zn = (1.0 - alpha);
scalar_t zp = (alpha);
    // p = 1. / (1. + expf(-x)), i.e. p = sigmoid(x)
scalar_t p = 1. / (1. + expf(-logits[i]));
    // term1 = (1-p)**gamma * log(p)
scalar_t term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
// p**gamma * log(1-p)
scalar_t term2 =
powf(p, gamma) *
(-1. * logits[i] * (logits[i] >= 0) -
logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0))));
losses[i] = 0.0;
losses[i] += -c1 * term1 * zp;
losses[i] += -c2 * term2 * zn;
} // CUDA_1D_KERNEL_LOOP
} // SigmoidFocalLossForward
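// Note on the math above: with p = sigmoid(x), the per-element sigmoid focal
// loss (Lin et al., "Focal Loss for Dense Object Detection") is
//   positive class (t == d + 1):  -alpha       * (1 - p)^gamma * log(p)
//   negative class (t != d + 1):  -(1 - alpha) * p^gamma       * log(1 - p)
// term2 evaluates log(1 - p) = -log(1 + exp(x)) in the numerically stable form
//   -x * (x >= 0) - log(1 + exp(x - 2 * x * (x >= 0)))
// i.e. -max(x, 0) - log(1 + exp(-|x|)), which avoids overflow for large |x|.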
template <typename scalar_t>
__global__ void SigmoidFocalLossBackward(
const int nthreads, const scalar_t *logits, const long *targets,
const scalar_t *d_losses, const int num_classes, const float gamma,
const float alpha, const int num, scalar_t *d_logits) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int n = i / num_classes;
int d = i % num_classes; // current class[0~79];
int t = targets[n]; // target class [1~80], 0 is background;
// Decide it is positive or negative case.
scalar_t c1 = (t == (d + 1));
scalar_t c2 = (t >= 0 & t != (d + 1));
scalar_t zn = (1.0 - alpha);
scalar_t zp = (alpha);
    // p = 1. / (1. + expf(-x)), i.e. p = sigmoid(x)
scalar_t p = 1. / (1. + expf(-logits[i]));
    // (1-p)**g * (1 - p - g*p*log(p))
scalar_t term1 =
powf((1. - p), gamma) * (1. - p - (p * gamma * logf(max(p, FLT_MIN))));
// (p**g) * (g*(1-p)*log(1-p) - p)
scalar_t term2 =
powf(p, gamma) *
((-1. * logits[i] * (logits[i] >= 0) -
logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) *
(1. - p) * gamma -
p);
d_logits[i] = 0.0;
d_logits[i] += -c1 * term1 * zp;
d_logits[i] += -c2 * term2 * zn;
d_logits[i] = d_logits[i] * d_losses[i];
} // CUDA_1D_KERNEL_LOOP
} // SigmoidFocalLossBackward
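// Note on the gradient above: differentiating the forward loss w.r.t. the
// logit x (with p = sigmoid(x), dp/dx = p * (1 - p)) gives
//   positive class:  dL/dx = -alpha       * (1 - p)^gamma * (1 - p - gamma * p * log(p))
//   negative class:  dL/dx = -(1 - alpha) * p^gamma       * (gamma * (1 - p) * log(1 - p) - p)
// which is exactly -c1 * term1 * zp - c2 * term2 * zn; the incoming gradient
// d_losses[i] is applied by the final multiplication in the loop.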
at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
const at::Tensor &targets,
const int num_classes,
const float gamma, const float alpha) {
AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
const int num_samples = logits.size(0);
auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
auto losses_size = num_samples * logits.size(1);
dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L));
dim3 block(512);
if (losses.numel() == 0) {
THCudaCheck(cudaGetLastError());
return losses;
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logits.type(), "SigmoidFocalLoss_forward", [&] {
SigmoidFocalLossForward<scalar_t><<<grid, block>>>(
losses_size, logits.contiguous().data<scalar_t>(),
targets.contiguous().data<long>(), num_classes, gamma, alpha,
num_samples, losses.data<scalar_t>());
});
THCudaCheck(cudaGetLastError());
return losses;
}
at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
const at::Tensor &targets,
const at::Tensor &d_losses,
const int num_classes,
const float gamma,
const float alpha) {
AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor");
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
const int num_samples = logits.size(0);
AT_ASSERTM(logits.size(1) == num_classes,
"logits.size(1) should be num_classes");
auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
auto d_logits_size = num_samples * logits.size(1);
dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L));
dim3 block(512);
if (d_logits.numel() == 0) {
THCudaCheck(cudaGetLastError());
return d_logits;
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logits.type(), "SigmoidFocalLoss_backward", [&] {
SigmoidFocalLossBackward<scalar_t><<<grid, block>>>(
d_logits_size, logits.contiguous().data<scalar_t>(),
targets.contiguous().data<long>(),
d_losses.contiguous().data<scalar_t>(), num_classes, gamma, alpha,
num_samples, d_logits.data<scalar_t>());
});
THCudaCheck(cudaGetLastError());
return d_logits;
}