From 8387aba8b1f5d576c6d982b283f35c5b70a30a25 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Sun, 28 Jul 2019 22:15:52 +0800
Subject: [PATCH] use .scalar_type() instead of .type() to suppress some
 warnings (#1070)

---
 mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu   | 14 +++++++-------
 mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu   |  4 ++--
 .../masked_conv/src/masked_conv2d_kernel.cu    |  4 ++--
 mmdet/ops/nms/src/nms_cpu.cpp                  |  2 +-
 mmdet/ops/nms/src/soft_nms_cpu.pyx             |  2 +-
 mmdet/ops/roi_align/src/roi_align_kernel.cu    |  4 ++--
 mmdet/ops/roi_pool/src/roi_pool_kernel.cu      |  4 ++--
 .../src/sigmoid_focal_loss.cpp                 |  2 ++
 .../src/sigmoid_focal_loss_cuda.cu             |  6 +++---
 setup.py                                       | 18 +++++++++---------
 10 files changed, 31 insertions(+), 29 deletions(-)

diff --git a/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu b/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
index fd560163..a2b94286 100644
--- a/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+++ b/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -58,7 +58,7 @@
  * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
  */
 
-// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
 
 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>
@@ -256,7 +256,7 @@ void deformable_im2col(
   int channel_per_deformable_group = channels / deformable_group;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      data_im.type(), "deformable_im2col_gpu", ([&] {
+      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
         const scalar_t *data_im_ = data_im.data<scalar_t>();
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *data_col_ = data_col.data<scalar_t>();
@@ -350,7 +350,7 @@ void deformable_col2im(
   int channel_per_deformable_group = channels / deformable_group;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      data_col.type(), "deformable_col2im_gpu", ([&] {
+      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
         const scalar_t *data_col_ = data_col.data<scalar_t>();
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *grad_im_ = grad_im.data<scalar_t>();
@@ -448,7 +448,7 @@ void deformable_col2im_coord(
   int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      data_col.type(), "deformable_col2im_coord_gpu", ([&] {
+      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
         const scalar_t *data_col_ = data_col.data<scalar_t>();
         const scalar_t *data_im_ = data_im.data<scalar_t>();
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
@@ -778,7 +778,7 @@ void modulated_deformable_im2col_cuda(
   const int num_kernels = channels * batch_size * height_col * width_col;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      data_im.type(), "modulated_deformable_im2col_gpu", ([&] {
+      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
         const scalar_t *data_im_ = data_im.data<scalar_t>();
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
@@ -810,7 +810,7 @@ void modulated_deformable_col2im_cuda(
   const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      data_col.type(), "modulated_deformable_col2im_gpu", ([&] {
+      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
         const scalar_t *data_col_ = data_col.data<scalar_t>();
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
@@ -843,7 +843,7 @@ void modulated_deformable_col2im_coord_cuda(
   const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
         const scalar_t *data_col_ = data_col.data<scalar_t>();
         const scalar_t *data_im_ = data_im.data<scalar_t>();
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
diff --git a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu b/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
index e4944600..1922d724 100644
--- a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
+++ b/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
@@ -289,7 +289,7 @@ void DeformablePSROIPoolForward(const at::Tensor data,
   const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      data.type(), "deformable_psroi_pool_forward", ([&] {
+      data.scalar_type(), "deformable_psroi_pool_forward", ([&] {
         const scalar_t *bottom_data = data.data<scalar_t>();
         const scalar_t *bottom_rois = bbox.data<scalar_t>();
         const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
@@ -340,7 +340,7 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
   const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] {
+      out_grad.scalar_type(), "deformable_psroi_pool_backward_acc", ([&] {
         const scalar_t *top_diff = out_grad.data<scalar_t>();
         const scalar_t *bottom_data = data.data<scalar_t>();
         const scalar_t *bottom_rois = bbox.data<scalar_t>();
diff --git a/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu b/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
index 394af13e..a0a949dd 100644
--- a/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
+++ b/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
@@ -57,7 +57,7 @@ int MaskedIm2colForwardLaucher(const at::Tensor bottom_data, const int height,
   const int output_size = mask_cnt * channels;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      bottom_data.type(), "MaskedIm2colLaucherForward", ([&] {
+      bottom_data.scalar_type(), "MaskedIm2colLaucherForward", ([&] {
         const scalar_t *bottom_data_ = bottom_data.data<scalar_t>();
         const long *mask_h_idx_ = mask_h_idx.data<long>();
         const long *mask_w_idx_ = mask_w_idx.data<long>();
@@ -97,7 +97,7 @@ int MaskedCol2imForwardLaucher(const at::Tensor bottom_data, const int height,
   const int output_size = mask_cnt * channels;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      bottom_data.type(), "MaskedCol2imLaucherForward", ([&] {
+      bottom_data.scalar_type(), "MaskedCol2imLaucherForward", ([&] {
         const scalar_t *bottom_data_ = bottom_data.data<scalar_t>();
         const long *mask_h_idx_ = mask_h_idx.data<long>();
         const long *mask_w_idx_ = mask_w_idx.data<long>();
diff --git a/mmdet/ops/nms/src/nms_cpu.cpp b/mmdet/ops/nms/src/nms_cpu.cpp
index 65546ef4..f7cffb49 100644
--- a/mmdet/ops/nms/src/nms_cpu.cpp
+++ b/mmdet/ops/nms/src/nms_cpu.cpp
@@ -60,7 +60,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
 
 at::Tensor nms(const at::Tensor& dets, const float threshold) {
   at::Tensor result;
-  AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
+  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
     result = nms_cpu_kernel<scalar_t>(dets, threshold);
   });
   return result;
diff --git a/mmdet/ops/nms/src/soft_nms_cpu.pyx b/mmdet/ops/nms/src/soft_nms_cpu.pyx
index c35f8f10..97f53f18 100644
--- a/mmdet/ops/nms/src/soft_nms_cpu.pyx
+++ b/mmdet/ops/nms/src/soft_nms_cpu.pyx
@@ -27,7 +27,7 @@ def soft_nms_cpu(
     float min_score=0.001,
 ):
     boxes = boxes_in.copy()
-    cdef unsigned int N = boxes.shape[0]
+    cdef int N = boxes.shape[0]
     cdef float iw, ih, box_area
     cdef float ua
     cdef int pos = 0
diff --git a/mmdet/ops/roi_align/src/roi_align_kernel.cu b/mmdet/ops/roi_align/src/roi_align_kernel.cu
index 46556405..6d3b2790 100644
--- a/mmdet/ops/roi_align/src/roi_align_kernel.cu
+++ b/mmdet/ops/roi_align/src/roi_align_kernel.cu
@@ -131,7 +131,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
                            at::Tensor output) {
   const int output_size = num_rois * pooled_height * pooled_width * channels;
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      features.type(), "ROIAlignLaucherForward", ([&] {
+      features.scalar_type(), "ROIAlignLaucherForward", ([&] {
         const scalar_t *bottom_data = features.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
         scalar_t *top_data = output.data<scalar_t>();
@@ -274,7 +274,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
   const int output_size = num_rois * pooled_height * pooled_width * channels;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      top_grad.type(), "ROIAlignLaucherBackward", ([&] {
+      top_grad.scalar_type(), "ROIAlignLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
         scalar_t *bottom_diff = bottom_grad.data<scalar_t>();
diff --git a/mmdet/ops/roi_pool/src/roi_pool_kernel.cu b/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
index b51bb043..25ba9853 100644
--- a/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
+++ b/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
@@ -86,7 +86,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
   const int output_size = num_rois * channels * pooled_h * pooled_w;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      features.type(), "ROIPoolLaucherForward", ([&] {
+      features.scalar_type(), "ROIPoolLaucherForward", ([&] {
         const scalar_t *bottom_data = features.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
         scalar_t *top_data = output.data<scalar_t>();
@@ -134,7 +134,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
   const int output_size = num_rois * pooled_h * pooled_w * channels;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      top_grad.type(), "ROIPoolLaucherBackward", ([&] {
+      top_grad.scalar_type(), "ROIPoolLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
         const int *argmax_data = argmax.data<int>();
diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp
index 20427518..b5fef270 100644
--- a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp
+++ b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp
@@ -22,6 +22,7 @@ at::Tensor SigmoidFocalLoss_forward(const at::Tensor &logits,
     return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma,
                                          alpha);
   }
+  AT_ERROR("SigmoidFocalLoss is not implemented on the CPU");
 }
 
 at::Tensor SigmoidFocalLoss_backward(const at::Tensor &logits,
@@ -33,6 +34,7 @@ at::Tensor SigmoidFocalLoss_backward(const at::Tensor &logits,
     return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses,
                                           num_classes, gamma, alpha);
   }
+  AT_ERROR("SigmoidFocalLoss is not implemented on the CPU");
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
index 7b9b8050..c8db6df7 100644
--- a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
+++ b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
@@ -1,4 +1,4 @@
-// modify from
+// modified from
 // https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
 
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
@@ -118,7 +118,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
   }
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      logits.type(), "SigmoidFocalLoss_forward", [&] {
+      logits.scalar_type(), "SigmoidFocalLoss_forward", [&] {
         SigmoidFocalLossForward<scalar_t><<<grid, block>>>(
             losses_size, logits.contiguous().data<scalar_t>(),
             targets.contiguous().data<long>(), num_classes, gamma, alpha,
@@ -156,7 +156,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
   }
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      logits.type(), "SigmoidFocalLoss_backward", [&] {
+      logits.scalar_type(), "SigmoidFocalLoss_backward", [&] {
         SigmoidFocalLossBackward<scalar_t><<<grid, block>>>(
             d_logits_size, logits.contiguous().data<scalar_t>(),
             targets.contiguous().data<long>(),
diff --git a/setup.py b/setup.py
index 73990e40..4c3c9a31 100644
--- a/setup.py
+++ b/setup.py
@@ -122,7 +122,7 @@ if __name__ == '__main__':
     setup(
         name='mmdet',
         version=get_version(),
-        description='Open MMLab Detection Toolbox',
+        description='Open MMLab Detection Toolbox and Benchmark',
         long_description=readme(),
         keywords='computer vision, object detection',
         url='https://github.com/open-mmlab/mmdetection',
@@ -151,14 +151,6 @@ if __name__ == '__main__':
                 name='soft_nms_cpu',
                 module='mmdet.ops.nms',
                 sources=['src/soft_nms_cpu.pyx']),
-            make_cuda_ext(
-                name='roi_align_cuda',
-                module='mmdet.ops.roi_align',
-                sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu']),
-            make_cuda_ext(
-                name='roi_pool_cuda',
-                module='mmdet.ops.roi_pool',
-                sources=['src/roi_pool_cuda.cpp', 'src/roi_pool_kernel.cu']),
             make_cuda_ext(
                 name='nms_cpu',
                 module='mmdet.ops.nms',
@@ -167,6 +159,14 @@ if __name__ == '__main__':
                 name='nms_cuda',
                 module='mmdet.ops.nms',
                 sources=['src/nms_cuda.cpp', 'src/nms_kernel.cu']),
+            make_cuda_ext(
+                name='roi_align_cuda',
+                module='mmdet.ops.roi_align',
+                sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu']),
+            make_cuda_ext(
+                name='roi_pool_cuda',
+                module='mmdet.ops.roi_pool',
+                sources=['src/roi_pool_cuda.cpp', 'src/roi_pool_kernel.cu']),
             make_cuda_ext(
                 name='deform_conv_cuda',
                 module='mmdet.ops.dcn',
-- 
GitLab