Skip to content
Snippets Groups Projects
Unverified Commit 58f04245 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

add condition on RoIPool backward (#2332)

* add condition on RoIPool backward

* format

* delete space

* reformat

* add newline end of file
parent d132051d
No related branches found
No related tags found
No related merge requests found
......@@ -93,16 +93,14 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
scalar_t *top_data = output.data<scalar_t>();
int *argmax_data = argmax.data<int>();
ROIPoolForward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
channels, height, width, pooled_h, pooled_w, top_data,
argmax_data);
ROIPoolForward<scalar_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK,
0, at::cuda::getCurrentCUDAStream()>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
channels, height, width, pooled_h, pooled_w, top_data, argmax_data);
}));
THCudaCheck(cudaGetLastError());
return 1;
}
template <typename scalar_t>
__global__ void ROIPoolBackward(const int nthreads, const scalar_t *top_diff,
const scalar_t *rois, const int *argmax_data,
......@@ -115,17 +113,16 @@ __global__ void ROIPoolBackward(const int nthreads, const scalar_t *top_diff,
int ph = (index / pooled_w) % pooled_h;
int c = (index / pooled_w / pooled_h) % channels;
int n = index / pooled_w / pooled_h / channels;
int roi_batch_ind = rois[n * 5];
int bottom_index = argmax_data[(n * channels + c) * pooled_h * pooled_w +
ph * pooled_w + pw];
atomicAdd(bottom_diff + (roi_batch_ind * channels + c) * height * width +
bottom_index,
top_diff[index]);
if (bottom_index != -1) {
atomicAdd(bottom_diff + (roi_batch_ind * channels + c) * height * width +
bottom_index,
top_diff[index]);
}
}
}
int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const at::Tensor argmax, const float spatial_scale,
const int batch_size, const int channels,
......@@ -133,24 +130,21 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int num_rois, const int pooled_h,
const int pooled_w, at::Tensor bottom_grad) {
const int output_size = num_rois * pooled_h * pooled_w * channels;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
top_grad.scalar_type(), "ROIPoolLaucherBackward", ([&] {
const scalar_t *top_diff = top_grad.data<scalar_t>();
const scalar_t *rois_data = rois.data<scalar_t>();
const int *argmax_data = argmax.data<int>();
scalar_t *bottom_diff = bottom_grad.data<scalar_t>();
if (sizeof(scalar_t) == sizeof(double)) {
fprintf(stderr, "double is not supported\n");
exit(-1);
}
ROIPoolBackward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
output_size, top_diff, rois_data, argmax_data,
scalar_t(spatial_scale), channels, height, width, pooled_h,
pooled_w, bottom_diff);
ROIPoolBackward<scalar_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK,
0, at::cuda::getCurrentCUDAStream()>>>(
output_size, top_diff, rois_data, argmax_data,
scalar_t(spatial_scale), channels, height, width, pooled_h,
pooled_w, bottom_diff);
}));
THCudaCheck(cudaGetLastError());
return 1;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment