Skip to content
Snippets Groups Projects
Unverified Commit ef58bc62 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Rewrite Soft NMS with pytorch extension (#2056)

* add soft_nms_cpu

* remove cython dependency

* merge soft_nms_cpu.cpp into nms_cpu.cpp

* add docs

* fixed typo

* fixed typo

* fixed typo
parent 9790afd1
No related branches found
No related tags found
No related merge requests found
...@@ -103,8 +103,6 @@ venv.bak/ ...@@ -103,8 +103,6 @@ venv.bak/
# mypy # mypy
.mypy_cache/ .mypy_cache/
# cython generated cpp
mmdet/ops/nms/src/soft_nms_cpu.cpp
mmdet/version.py mmdet/version.py
data data
.vscode .vscode
......
...@@ -3,6 +3,6 @@ line_length = 79 ...@@ -3,6 +3,6 @@ line_length = 79
multi_line_output = 0 multi_line_output = 0
known_standard_library = setuptools known_standard_library = setuptools
known_first_party = mmdet known_first_party = mmdet
known_third_party = Cython,asynctest,cv2,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision known_third_party = asynctest,cv2,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import torch import torch
from . import nms_cpu, nms_cuda from . import nms_cpu, nms_cuda
from .soft_nms_cpu import soft_nms_cpu
def nms(dets, iou_thr, device_id=None): def nms(dets, iou_thr, device_id=None):
...@@ -31,8 +30,8 @@ def nms(dets, iou_thr, device_id=None): ...@@ -31,8 +30,8 @@ def nms(dets, iou_thr, device_id=None):
>>> [35.3, 11.5, 39.9, 14.5, 0.4], >>> [35.3, 11.5, 39.9, 14.5, 0.4],
>>> [35.2, 11.7, 39.7, 15.7, 0.3]], dtype=np.float32) >>> [35.2, 11.7, 39.7, 15.7, 0.3]], dtype=np.float32)
>>> iou_thr = 0.7 >>> iou_thr = 0.7
>>> supressed, inds = nms(dets, iou_thr) >>> suppressed, inds = nms(dets, iou_thr)
>>> assert len(inds) == len(supressed) == 3 >>> assert len(inds) == len(suppressed) == 3
""" """
# convert dets (tensor or numpy array) to tensor # convert dets (tensor or numpy array) to tensor
if isinstance(dets, torch.Tensor): if isinstance(dets, torch.Tensor):
...@@ -62,7 +61,22 @@ def nms(dets, iou_thr, device_id=None): ...@@ -62,7 +61,22 @@ def nms(dets, iou_thr, device_id=None):
def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3): def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
""" """Dispatch to only CPU Soft NMS implementations.
The input can be either a torch tensor or numpy array.
The returned type will always be the same as inputs.
Arguments:
dets (torch.Tensor or np.ndarray): bboxes with scores.
iou_thr (float): IoU threshold for Soft NMS.
method (str): either 'linear' or 'gaussian'
sigma (float): hyperparameter for gaussian method
min_score (float): score filter threshold
Returns:
tuple: new det bboxes and indice, which is always the same
data type as the input.
Example: Example:
>>> dets = np.array([[4., 3., 5., 3., 0.9], >>> dets = np.array([[4., 3., 5., 3., 0.9],
>>> [4., 3., 5., 4., 0.9], >>> [4., 3., 5., 4., 0.9],
...@@ -71,15 +85,16 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3): ...@@ -71,15 +85,16 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
>>> [3., 1., 3., 1., 0.4], >>> [3., 1., 3., 1., 0.4],
>>> [3., 1., 3., 1., 0.0]], dtype=np.float32) >>> [3., 1., 3., 1., 0.0]], dtype=np.float32)
>>> iou_thr = 0.7 >>> iou_thr = 0.7
>>> supressed, inds = soft_nms(dets, iou_thr, sigma=0.5) >>> new_dets, inds = soft_nms(dets, iou_thr, sigma=0.5)
>>> assert len(inds) == len(supressed) == 3 >>> assert len(inds) == len(new_dets) == 3
""" """
# convert dets (tensor or numpy array) to tensor
if isinstance(dets, torch.Tensor): if isinstance(dets, torch.Tensor):
is_tensor = True is_tensor = True
dets_np = dets.detach().cpu().numpy() dets_t = dets.detach().cpu()
elif isinstance(dets, np.ndarray): elif isinstance(dets, np.ndarray):
is_tensor = False is_tensor = False
dets_np = dets dets_t = torch.from_numpy(dets)
else: else:
raise TypeError( raise TypeError(
'dets must be either a Tensor or numpy array, but got {}'.format( 'dets must be either a Tensor or numpy array, but got {}'.format(
...@@ -88,15 +103,15 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3): ...@@ -88,15 +103,15 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
method_codes = {'linear': 1, 'gaussian': 2} method_codes = {'linear': 1, 'gaussian': 2}
if method not in method_codes: if method not in method_codes:
raise ValueError('Invalid method for SoftNMS: {}'.format(method)) raise ValueError('Invalid method for SoftNMS: {}'.format(method))
new_dets, inds = soft_nms_cpu( results = nms_cpu.soft_nms(dets_t, iou_thr, method_codes[method], sigma,
dets_np, min_score)
iou_thr,
method=method_codes[method], new_dets = results[:, :5]
sigma=sigma, inds = results[:, 5]
min_score=min_score)
if is_tensor: if is_tensor:
return dets.new_tensor(new_dets), dets.new_tensor( return dets.new_tensor(new_dets), dets.new_tensor(
inds, dtype=torch.long) inds, dtype=torch.long)
else: else:
return new_dets.astype(np.float32), inds.astype(np.int64) return new_dets.numpy().astype(dets.dtype), inds.numpy().astype(
np.int64)
// Modified from https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx, Soft-NMS is added
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <torch/extension.h> #include <torch/extension.h>
...@@ -66,6 +67,149 @@ at::Tensor nms(const at::Tensor& dets, const float threshold) { ...@@ -66,6 +67,149 @@ at::Tensor nms(const at::Tensor& dets, const float threshold) {
return result; return result;
} }
template <typename scalar_t>
at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
const unsigned char method, const float sigma, const float min_score) {
AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
}
auto x1_t = dets.select(1, 0).contiguous();
auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous();
auto scores_t = dets.select(1, 4).contiguous();
at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
auto ndets = dets.size(0);
auto x1 = x1_t.data<scalar_t>();
auto y1 = y1_t.data<scalar_t>();
auto x2 = x2_t.data<scalar_t>();
auto y2 = y2_t.data<scalar_t>();
auto scores = scores_t.data<scalar_t>();
auto areas = areas_t.data<scalar_t>();
int64_t pos = 0;
at::Tensor inds_t = at::arange(ndets, dets.options());
auto inds = inds_t.data<scalar_t>();
for (int64_t i = 0; i < ndets; i++) {
auto max_score = scores[i];
auto max_pos = i;
auto ix1 = x1[i];
auto iy1 = y1[i];
auto ix2 = x2[i];
auto iy2 = y2[i];
auto iscore = scores[i];
auto iarea = areas[i];
auto iind = inds[i];
pos = i + 1;
// get max box
while (pos < ndets){
if (max_score < scores[pos]) {
max_score = scores[pos];
max_pos = pos;
}
pos = pos + 1;
}
// add max box as a detection
x1[i] = x1[max_pos];
y1[i] = y1[max_pos];
x2[i] = x2[max_pos];
y2[i] = y2[max_pos];
scores[i] = scores[max_pos];
areas[i] = areas[max_pos];
inds[i] = inds[max_pos];
// swap ith box with position of max box
x1[max_pos] = ix1;
y1[max_pos] = iy1;
x2[max_pos] = ix2;
y2[max_pos] = iy2;
scores[max_pos] = iscore;
areas[max_pos] = iarea;
inds[max_pos] = iind;
ix1 = x1[i];
iy1 = y1[i];
ix2 = x2[i];
iy2 = y2[i];
iscore = scores[i];
iarea = areas[i];
pos = i + 1;
// NMS iterations, note that N changes if detection boxes fall below threshold
while (pos < ndets) {
auto xx1 = std::max(ix1, x1[pos]);
auto yy1 = std::max(iy1, y1[pos]);
auto xx2 = std::min(ix2, x2[pos]);
auto yy2 = std::min(iy2, y2[pos]);
auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[pos] - inter);
scalar_t weight = 1.;
if (method == 1) {
if (ovr > threshold) weight = 1 - ovr;
}
else if (method == 2) {
weight = std::exp(-(ovr * ovr) / sigma);
}
else {
// original NMS
if (ovr > threshold) {
weight = 0;
}
else {
weight = 1;
}
}
scores[pos] = weight * scores[pos];
// if box score falls below threshold, discard the box by
// swapping with last box update N
if (scores[pos] < min_score) {
x1[pos] = x1[ndets - 1];
y1[pos] = y1[ndets - 1];
x2[pos] = x2[ndets - 1];
y2[pos] = y2[ndets - 1];
scores[pos] = scores[ndets - 1];
areas[pos] = areas[ndets - 1];
inds[pos] = inds[ndets - 1];
ndets = ndets -1;
pos = pos - 1;
}
pos = pos + 1;
}
}
at::Tensor result = at::zeros({6, ndets}, dets.options());
result[0] = x1_t.slice(0, 0, ndets);
result[1] = y1_t.slice(0, 0, ndets);
result[2] = x2_t.slice(0, 0, ndets);
result[3] = y2_t.slice(0, 0, ndets);
result[4] = scores_t.slice(0, 0, ndets);
result[5] = inds_t.slice(0, 0, ndets);
result =result.t().contiguous();
return result;
}
at::Tensor soft_nms(const at::Tensor& dets, const float threshold,
const unsigned char method, const float sigma, const float min_score) {
at::Tensor result;
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "soft_nms", [&] {
result = soft_nms_cpu_kernel<scalar_t>(dets, threshold, method, sigma, min_score);
});
return result;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "non-maximum suppression"); m.def("nms", &nms, "non-maximum suppression");
} m.def("soft_nms", &soft_nms, "soft non-maximum suppression");
\ No newline at end of file }
# ----------------------------------------------------------
# Soft-NMS: Improving Object Detection With One Line of Code
# Copyright (c) University of Maryland, College Park
# Licensed under The MIT License [see LICENSE for details]
# Written by Navaneeth Bodla and Bharat Singh
# Modified by Kai Chen
# ----------------------------------------------------------
# cython: language_level=3, boundscheck=False
import numpy as np
cimport numpy as np
cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
return a if a >= b else b
cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
return a if a <= b else b
def soft_nms_cpu(
np.ndarray[float, ndim=2] boxes_in,
float iou_thr,
unsigned int method=1,
float sigma=0.5,
float min_score=0.001,
):
boxes = boxes_in.copy()
cdef int N = boxes.shape[0]
cdef float iw, ih, box_area
cdef float ua
cdef int pos = 0
cdef float maxscore = 0
cdef int maxpos = 0
cdef float x1, x2, y1, y2, tx1, tx2, ty1, ty2, ts, area, weight, ov
inds = np.arange(N)
for i in range(N):
maxscore = boxes[i, 4]
maxpos = i
tx1 = boxes[i, 0]
ty1 = boxes[i, 1]
tx2 = boxes[i, 2]
ty2 = boxes[i, 3]
ts = boxes[i, 4]
ti = inds[i]
pos = i + 1
# get max box
while pos < N:
if maxscore < boxes[pos, 4]:
maxscore = boxes[pos, 4]
maxpos = pos
pos = pos + 1
# add max box as a detection
boxes[i, 0] = boxes[maxpos, 0]
boxes[i, 1] = boxes[maxpos, 1]
boxes[i, 2] = boxes[maxpos, 2]
boxes[i, 3] = boxes[maxpos, 3]
boxes[i, 4] = boxes[maxpos, 4]
inds[i] = inds[maxpos]
# swap ith box with position of max box
boxes[maxpos, 0] = tx1
boxes[maxpos, 1] = ty1
boxes[maxpos, 2] = tx2
boxes[maxpos, 3] = ty2
boxes[maxpos, 4] = ts
inds[maxpos] = ti
tx1 = boxes[i, 0]
ty1 = boxes[i, 1]
tx2 = boxes[i, 2]
ty2 = boxes[i, 3]
ts = boxes[i, 4]
pos = i + 1
# NMS iterations, note that N changes if detection boxes fall below
# threshold
while pos < N:
x1 = boxes[pos, 0]
y1 = boxes[pos, 1]
x2 = boxes[pos, 2]
y2 = boxes[pos, 3]
s = boxes[pos, 4]
area = (x2 - x1 + 1) * (y2 - y1 + 1)
iw = (min(tx2, x2) - max(tx1, x1) + 1)
if iw > 0:
ih = (min(ty2, y2) - max(ty1, y1) + 1)
if ih > 0:
ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
ov = iw * ih / ua # iou between max box and detection box
if method == 1: # linear
if ov > iou_thr:
weight = 1 - ov
else:
weight = 1
elif method == 2: # gaussian
weight = np.exp(-(ov * ov) / sigma)
else: # original NMS
if ov > iou_thr:
weight = 0
else:
weight = 1
boxes[pos, 4] = weight * boxes[pos, 4]
# if box score falls below threshold, discard the box by
# swapping with last box update N
if boxes[pos, 4] < min_score:
boxes[pos, 0] = boxes[N-1, 0]
boxes[pos, 1] = boxes[N-1, 1]
boxes[pos, 2] = boxes[N-1, 2]
boxes[pos, 3] = boxes[N-1, 3]
boxes[pos, 4] = boxes[N-1, 4]
inds[pos] = inds[N - 1]
N = N - 1
pos = pos - 1
pos = pos + 1
return boxes[:N], inds[:N]
# These must be installed before building mmdetection # These must be installed before building mmdetection
cython
numpy numpy
torch>=1.1 torch>=1.1
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import platform
import subprocess import subprocess
import time import time
from setuptools import Extension, dist, find_packages, setup from setuptools import find_packages, setup
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1'])
import numpy as np # noqa: E402, isort:skip
from Cython.Build import cythonize # noqa: E402, isort:skip
def readme(): def readme():
with open('README.md', encoding='utf-8') as f: with open('README.md', encoding='utf-8') as f:
...@@ -116,23 +111,6 @@ def make_cuda_ext(name, module, sources): ...@@ -116,23 +111,6 @@ def make_cuda_ext(name, module, sources):
}) })
def make_cython_ext(name, module, sources):
extra_compile_args = None
if platform.system() != 'Windows':
extra_compile_args = {
'cxx': ['-Wno-unused-function', '-Wno-write-strings']
}
extension = Extension(
'{}.{}'.format(module, name),
[os.path.join(*module.split('.'), p) for p in sources],
include_dirs=[np.get_include()],
language='c++',
extra_compile_args=extra_compile_args)
extension, = cythonize(extension)
return extension
def parse_requirements(fname='requirements.txt', with_version=True): def parse_requirements(fname='requirements.txt', with_version=True):
""" """
Parse the package dependencies listed in a requirements file but strips Parse the package dependencies listed in a requirements file but strips
...@@ -249,10 +227,6 @@ if __name__ == '__main__': ...@@ -249,10 +227,6 @@ if __name__ == '__main__':
name='compiling_info', name='compiling_info',
module='mmdet.ops.utils', module='mmdet.ops.utils',
sources=['src/compiling_info.cpp']), sources=['src/compiling_info.cpp']),
make_cython_ext(
name='soft_nms_cpu',
module='mmdet.ops.nms',
sources=['src/soft_nms_cpu.pyx']),
make_cuda_ext( make_cuda_ext(
name='nms_cpu', name='nms_cpu',
module='mmdet.ops.nms', module='mmdet.ops.nms',
......
"""
CommandLine:
pytest tests/test_soft_nms.py
"""
import numpy as np
import torch
from mmdet.ops.nms.nms_wrapper import soft_nms
def test_soft_nms_device_and_dtypes_cpu():
"""
CommandLine:
xdoctest -m tests/test_soft_nms.py test_soft_nms_device_and_dtypes_cpu
"""
iou_thr = 0.7
base_dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
[49.3, 32.9, 51.0, 35.3, 0.9],
[35.3, 11.5, 39.9, 14.5, 0.4],
[35.2, 11.7, 39.7, 15.7, 0.3]])
# CPU can handle float32 and float64
dets = base_dets.astype(np.float32)
new_dets, inds = soft_nms(dets, iou_thr)
assert dets.dtype == new_dets.dtype
assert len(inds) == len(new_dets) == 4
dets = torch.FloatTensor(base_dets)
new_dets, inds = soft_nms(dets, iou_thr)
assert dets.dtype == new_dets.dtype
assert len(inds) == len(new_dets) == 4
dets = base_dets.astype(np.float64)
new_dets, inds = soft_nms(dets, iou_thr)
assert dets.dtype == new_dets.dtype
assert len(inds) == len(new_dets) == 4
dets = torch.DoubleTensor(base_dets)
new_dets, inds = soft_nms(dets, iou_thr)
assert dets.dtype == new_dets.dtype
assert len(inds) == len(new_dets) == 4
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment