Commit 108fc9e1 authored by Kai Chen

set up the codebase skeleton (WIP)

parent 6985ef31
@@ -102,3 +102,6 @@ venv.bak/
# mypy
.mypy_cache/
# cython generated cpp
mmdet/ops/nms/*.cpp
\ No newline at end of file
#!/usr/bin/env bash
PYTHON=${PYTHON:-"python"}
echo "Building roi align op..."
cd mmdet/ops/roi_align
if [ -d "build" ]; then
rm -r build
fi
$PYTHON setup.py build_ext --inplace
echo "Building roi pool op..."
cd ../roi_pool
if [ -d "build" ]; then
rm -r build
fi
$PYTHON setup.py build_ext --inplace
echo "Building nms op..."
cd ../nms
make clean
make PYTHON=${PYTHON}
from .version import __version__
from .anchor_generator import *
from .bbox_ops import *
from .mask_ops import *
from .eval import *
from .nn import *
from .targets import *
import torch
class AnchorGenerator(object):
def __init__(self, base_size, scales, ratios, scale_major=True):
self.base_size = base_size
self.scales = torch.Tensor(scales)
self.ratios = torch.Tensor(ratios)
self.scale_major = scale_major
self.base_anchors = self.gen_base_anchors()
@property
def num_base_anchors(self):
return self.base_anchors.size(0)
def gen_base_anchors(self):
base_anchor = torch.Tensor(
[0, 0, self.base_size - 1, self.base_size - 1])
w = base_anchor[2] - base_anchor[0] + 1
h = base_anchor[3] - base_anchor[1] + 1
x_ctr = base_anchor[0] + 0.5 * (w - 1)
y_ctr = base_anchor[1] + 0.5 * (h - 1)
h_ratios = torch.sqrt(self.ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1)
hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1)
else:
ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)
base_anchors = torch.stack(
[
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
],
dim=-1).round()
return base_anchors
def _meshgrid(self, x, y, row_major=True):
xx = x.repeat(len(y))
yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
if row_major:
return xx, yy
else:
return yy, xx
def grid_anchors(self, featmap_size, stride=16, device='cuda'):
feat_h, feat_w = featmap_size
shift_x = torch.arange(0, feat_w, device=device) * stride
shift_y = torch.arange(0, feat_h, device=device) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
base_anchors = self.base_anchors.to(device)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.view(-1, 4)
# first A rows correspond to A anchors of (0, 0) in feature map,
# then (0, 1), (0, 2), ...
return all_anchors
def valid_flags(self, featmap_size, valid_size, device='cuda'):
feat_h, feat_w = featmap_size
valid_h, valid_w = valid_size
assert valid_h <= feat_h and valid_w <= feat_w
valid_x = torch.zeros(feat_w, dtype=torch.uint8, device=device)
valid_y = torch.zeros(feat_h, dtype=torch.uint8, device=device)
valid_x[:valid_w] = 1
valid_y[:valid_h] = 1
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy
valid = valid[:, None].expand(
valid.size(0), self.num_base_anchors).contiguous().view(-1)
return valid
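# Usage sketch (illustrative, not part of the original file): 3 scales x
# 3 ratios give 9 base anchors per location; pass device='cpu' to avoid
# the CUDA default.
gen = AnchorGenerator(base_size=16, scales=[8, 16, 32], ratios=[0.5, 1.0, 2.0])
print(gen.num_base_anchors)  # 9
anchors = gen.grid_anchors((2, 3), stride=16, device='cpu')
print(anchors.shape)  # torch.Size([54, 4]): 2*3 locations x 9 anchors
flags = gen.valid_flags((2, 3), (2, 2), device='cpu')
print(flags.shape)  # torch.Size([54])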
from .geometry import bbox_overlaps
from .sampling import (random_choice, bbox_assign, bbox_assign_via_overlaps,
bbox_sampling, sample_positives, sample_negatives)
from .transforms import (bbox_transform, bbox_transform_inv, bbox_flip,
bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox)
__all__ = [
'bbox_overlaps', 'random_choice', 'bbox_assign',
'bbox_assign_via_overlaps', 'bbox_sampling', 'sample_positives',
'sample_negatives', 'bbox_transform', 'bbox_transform_inv', 'bbox_flip',
'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox'
]
import torch
def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False):
"""Calculate overlap between two set of bboxes.
If ``is_aligned`` is ``False``, then calculate the ious between each bbox
of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
bboxes1 and bboxes2.
Args:
bboxes1 (Tensor): shape (m, 4)
bboxes2 (Tensor): shape (n, 4), if is_aligned is ``True``, then m and n
must be equal.
mode (str): "iou" (intersection over union) or iof (intersection over
foreground).
Returns:
ious(Tensor): shape (n, k) if is_aligned == False else shape (n, 1)
"""
assert mode in ['iou', 'iof']
rows = bboxes1.size(0)
cols = bboxes2.size(0)
if is_aligned:
assert rows == cols
if rows * cols == 0:
return bboxes1.new(rows, 1) if is_aligned else bboxes1.new(rows, cols)
if is_aligned:
lt = torch.max(bboxes1[:, :2], bboxes2[:, :2]) # [rows, 2]
rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:]) # [rows, 2]
wh = (rb - lt + 1).clamp(min=0) # [rows, 2]
overlap = wh[:, 0] * wh[:, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
bboxes1[:, 3] - bboxes1[:, 1] + 1)
if mode == 'iou':
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (
bboxes2[:, 3] - bboxes2[:, 1] + 1)
ious = overlap / (area1 + area2 - overlap)
else:
ious = overlap / area1
else:
lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [rows, cols, 2]
rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [rows, cols, 2]
wh = (rb - lt + 1).clamp(min=0) # [rows, cols, 2]
overlap = wh[:, :, 0] * wh[:, :, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
bboxes1[:, 3] - bboxes1[:, 1] + 1)
if mode == 'iou':
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (
bboxes2[:, 3] - bboxes2[:, 1] + 1)
ious = overlap / (area1[:, None] + area2 - overlap)
else:
ious = overlap / (area1[:, None])
return ious
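# Quick check (illustrative): with the +1 pixel convention above, each box
# is 10x10 (area 100) and the intersection is 5x5 = 25, so
# IoU = 25 / (100 + 100 - 25) ~= 0.143.
b1 = torch.tensor([[0., 0., 9., 9.]])
b2 = torch.tensor([[5., 5., 14., 14.]])
print(bbox_overlaps(b1, b2))                   # tensor([[0.1429]])
print(bbox_overlaps(b1, b2, is_aligned=True))  # tensor([0.1429])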
import numpy as np
import torch
from .geometry import bbox_overlaps
def random_choice(gallery, num):
assert len(gallery) >= num
if isinstance(gallery, list):
gallery = np.array(gallery)
cands = np.arange(len(gallery))
np.random.shuffle(cands)
rand_inds = cands[:num]
if not isinstance(gallery, np.ndarray):
rand_inds = torch.from_numpy(rand_inds).long()
if gallery.is_cuda:
rand_inds = rand_inds.cuda(gallery.get_device())
return gallery[rand_inds]
def bbox_assign(proposals,
gt_bboxes,
gt_crowd_bboxes=None,
gt_labels=None,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=.0,
crowd_thr=-1):
"""Assign a corresponding gt bbox or background to each proposal/anchor
This function assign a gt bbox to every proposal, each proposals will be
assigned with -1, 0, or a positive number. -1 means don't care, 0 means
negative sample, positive number is the index (1-based) of assigned gt.
If gt_crowd_bboxes is not None, proposals which have iof(intersection over foreground)
with crowd bboxes over crowd_thr will be ignored
Args:
proposals(Tensor): proposals or RPN anchors, shape (n, 4)
gt_bboxes(Tensor): shape (k, 4)
gt_crowd_bboxes(Tensor): shape(m, 4)
gt_labels(Tensor, optional): shape (k, )
pos_iou_thr(float): iou threshold for positive bboxes
neg_iou_thr(float or tuple): iou threshold for negative bboxes
min_pos_iou(float): minimum iou for a bbox to be considered as a positive bbox,
for RPN, it is usually set as 0, for Fast R-CNN,
it is usually set as pos_iou_thr
crowd_thr: ignore proposals which have iof(intersection over foreground) with
crowd bboxes over crowd_thr
Returns:
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
"""
# calculate overlaps between the proposals and the gt boxes
overlaps = bbox_overlaps(proposals, gt_bboxes)
if overlaps.numel() == 0:
raise ValueError('No gt bbox or proposals')
# ignore proposals according to crowd bboxes
if (crowd_thr > 0) and (gt_crowd_bboxes is
not None) and (gt_crowd_bboxes.numel() > 0):
crowd_overlaps = bbox_overlaps(proposals, gt_crowd_bboxes, mode='iof')
crowd_max_overlaps, _ = crowd_overlaps.max(dim=1)
crowd_bboxes_inds = torch.nonzero(
crowd_max_overlaps > crowd_thr).long()
if crowd_bboxes_inds.numel() > 0:
overlaps[crowd_bboxes_inds, :] = -1
return bbox_assign_via_overlaps(overlaps, gt_labels, pos_iou_thr,
neg_iou_thr, min_pos_iou)
def bbox_assign_via_overlaps(overlaps,
gt_labels=None,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=.0):
"""Assign a corresponding gt bbox or background to each proposal/anchor
This function assign a gt bbox to every proposal, each proposals will be
assigned with -1, 0, or a positive number. -1 means don't care, 0 means
negative sample, positive number is the index (1-based) of assigned gt.
The assignment is done in following steps, the order matters:
1. assign every anchor to -1
2. assign proposals whose iou with all gts < neg_iou_thr to 0
3. for each anchor, if the iou with its nearest gt >= pos_iou_thr,
assign it to that bbox
4. for each gt bbox, assign its nearest proposals(may be more than one)
to itself
Args:
overlaps(Tensor): overlaps between n proposals and k gt_bboxes, shape(n, k)
gt_labels(Tensor, optional): shape (k, )
pos_iou_thr(float): iou threshold for positive bboxes
neg_iou_thr(float or tuple): iou threshold for negative bboxes
min_pos_iou(float): minimum iou for a bbox to be considered as a positive bbox,
for RPN, it is usually set as 0, for Fast R-CNN,
it is usually set as pos_iou_thr
Returns:
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
"""
num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)
# 1. assign -1 by default
assigned_gt_inds = overlaps.new(num_bboxes).long().fill_(-1)
if overlaps.numel() == 0:
raise ValueError('No gt bbox or proposals')
assert overlaps.size() == (num_bboxes, num_gts)
# for each anchor, which gt best overlaps with it
# for each anchor, the max iou of all gts
max_overlaps, argmax_overlaps = overlaps.max(dim=1)
# for each gt, which anchor best overlaps with it
# for each gt, the max iou of all proposals
gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=0)
# 2. assign negative: below
if isinstance(neg_iou_thr, float):
assigned_gt_inds[(max_overlaps >= 0)
& (max_overlaps < neg_iou_thr)] = 0
elif isinstance(neg_iou_thr, tuple):
assert len(neg_iou_thr) == 2
assigned_gt_inds[(max_overlaps >= neg_iou_thr[0])
& (max_overlaps < neg_iou_thr[1])] = 0
# 3. assign positive: above positive IoU threshold
pos_inds = max_overlaps >= pos_iou_thr
assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
# 4. assign fg: for each gt, proposals with highest IoU
for i in range(num_gts):
if gt_max_overlaps[i] >= min_pos_iou:
assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1
if gt_labels is None:
return assigned_gt_inds, argmax_overlaps, max_overlaps
else:
assigned_labels = assigned_gt_inds.new(num_bboxes).fill_(0)
pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
if pos_inds.numel() > 0:
assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
1]
return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps
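# Worked example (illustrative): 3 proposals vs 2 gts. Proposal 0 matches
# gt 1 (iou 0.9), proposal 1 matches gt 2 (iou 0.6), proposal 2 falls below
# neg_iou_thr and becomes background.
overlaps = torch.tensor([[0.9, 0.1],
                         [0.4, 0.6],
                         [0.2, 0.1]])
gt_labels = torch.tensor([3, 7])
inds, labels, argmax, max_ov = bbox_assign_via_overlaps(overlaps, gt_labels)
print(inds)    # tensor([1, 2, 0]): 1-based gt indices, 0 = negative
print(labels)  # tensor([3, 7, 0])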
def sample_positives(assigned_gt_inds, num_expected, balance_sampling=True):
"""Balance sampling for positive bboxes/anchors
1. calculate average positive num for each gt: num_per_gt
2. sample at most num_per_gt positives for each gt
3. random sampling from rest anchors if not enough fg
"""
pos_inds = torch.nonzero(assigned_gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
elif not balance_sampling:
return random_choice(pos_inds, num_expected)
else:
unique_gt_inds = torch.unique(assigned_gt_inds[pos_inds].cpu())
num_gts = len(unique_gt_inds)
num_per_gt = int(round(num_expected / float(num_gts)) + 1)
sampled_inds = []
for i in unique_gt_inds:
inds = torch.nonzero(assigned_gt_inds == i.item())
if inds.numel() != 0:
inds = inds.squeeze(1)
else:
continue
if len(inds) > num_per_gt:
inds = random_choice(inds, num_per_gt)
sampled_inds.append(inds)
sampled_inds = torch.cat(sampled_inds)
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(
list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
if len(extra_inds) > num_extra:
extra_inds = random_choice(extra_inds, num_extra)
extra_inds = torch.from_numpy(extra_inds).to(
assigned_gt_inds.device).long()
sampled_inds = torch.cat([sampled_inds, extra_inds])
elif len(sampled_inds) > num_expected:
sampled_inds = random_choice(sampled_inds, num_expected)
return sampled_inds
def sample_negatives(assigned_gt_inds,
num_expected,
max_overlaps=None,
balance_thr=0,
hard_fraction=0.5):
"""Balance sampling for negative bboxes/anchors
negative samples are split into 2 set: hard(balance_thr <= iou < neg_iou_thr)
and easy(iou < balance_thr), around equal number of bg are sampled
from each set.
"""
neg_inds = torch.nonzero(assigned_gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
elif balance_thr <= 0:
# uniform sampling among all negative samples
return random_choice(neg_inds, num_expected)
else:
assert max_overlaps is not None
max_overlaps = max_overlaps.cpu().numpy()
# balance sampling for negative samples
neg_set = set(neg_inds.cpu().numpy())
easy_set = set(
np.where(
np.logical_and(max_overlaps >= 0,
max_overlaps < balance_thr))[0])
hard_set = set(np.where(max_overlaps >= balance_thr)[0])
easy_neg_inds = list(easy_set & neg_set)
hard_neg_inds = list(hard_set & neg_set)
num_expected_hard = int(num_expected * hard_fraction)
if len(hard_neg_inds) > num_expected_hard:
sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard)
else:
sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int64)
num_expected_easy = num_expected - len(sampled_hard_inds)
if len(easy_neg_inds) > num_expected_easy:
sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy)
else:
sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int64)
sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds))
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(list(neg_set - set(sampled_inds)))
if len(extra_inds) > num_extra:
extra_inds = random_choice(extra_inds, num_extra)
sampled_inds = np.concatenate((sampled_inds, extra_inds))
sampled_inds = torch.from_numpy(sampled_inds).long().to(
assigned_gt_inds.device)
return sampled_inds
def bbox_sampling(assigned_gt_inds,
num_expected,
pos_fraction,
neg_pos_ub,
pos_balance_sampling=True,
max_overlaps=None,
neg_balance_thr=0,
neg_hard_fraction=0.5):
num_expected_pos = int(num_expected * pos_fraction)
pos_inds = sample_positives(assigned_gt_inds, num_expected_pos,
pos_balance_sampling)
num_sampled_pos = pos_inds.numel()
num_neg_max = int(
neg_pos_ub *
num_sampled_pos) if num_sampled_pos > 0 else int(neg_pos_ub)
num_expected_neg = min(num_neg_max, num_expected - num_sampled_pos)
neg_inds = sample_negatives(assigned_gt_inds, num_expected_neg,
max_overlaps, neg_balance_thr,
neg_hard_fraction)
return pos_inds, neg_inds
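# Sampling sketch (illustrative): with assignments [1, 2, 0], asking for 2
# samples at pos_fraction=0.5 yields 1 positive and 1 negative; which
# positive is kept is random when more are available than requested.
assigned = torch.tensor([1, 2, 0])
pos_inds, neg_inds = bbox_sampling(assigned, 2, 0.5, neg_pos_ub=100)
print(pos_inds.numel(), neg_inds)  # 1 tensor([2])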
import mmcv
import numpy as np
import torch
def bbox_transform(proposals, gt, means=[0, 0, 0, 0], stds=[1, 1, 1, 1]):
assert proposals.size() == gt.size()
proposals = proposals.float()
gt = gt.float()
px = (proposals[..., 0] + proposals[..., 2]) * 0.5
py = (proposals[..., 1] + proposals[..., 3]) * 0.5
pw = proposals[..., 2] - proposals[..., 0] + 1.0
ph = proposals[..., 3] - proposals[..., 1] + 1.0
gx = (gt[..., 0] + gt[..., 2]) * 0.5
gy = (gt[..., 1] + gt[..., 3]) * 0.5
gw = gt[..., 2] - gt[..., 0] + 1.0
gh = gt[..., 3] - gt[..., 1] + 1.0
dx = (gx - px) / pw
dy = (gy - py) / ph
dw = torch.log(gw / pw)
dh = torch.log(gh / ph)
deltas = torch.stack([dx, dy, dw, dh], dim=-1)
means = deltas.new_tensor(means).unsqueeze(0)
stds = deltas.new_tensor(stds).unsqueeze(0)
deltas = deltas.sub_(means).div_(stds)
return deltas
def bbox_transform_inv(rois,
deltas,
means=[0, 0, 0, 0],
stds=[1, 1, 1, 1],
max_shape=None,
wh_ratio_clip=16 / 1000):
means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
denorm_deltas = deltas * stds + means
dx = denorm_deltas[:, 0::4]
dy = denorm_deltas[:, 1::4]
dw = denorm_deltas[:, 2::4]
dh = denorm_deltas[:, 3::4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = dw.clamp(min=-max_ratio, max=max_ratio)
dh = dh.clamp(min=-max_ratio, max=max_ratio)
px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
gw = pw * dw.exp()
gh = ph * dh.exp()
gx = torch.addcmul(px, 1, pw, dx) # gx = px + pw * dx
gy = torch.addcmul(py, 1, ph, dy) # gy = py + ph * dy
x1 = gx - gw * 0.5 + 0.5
y1 = gy - gh * 0.5 + 0.5
x2 = gx + gw * 0.5 - 0.5
y2 = gy + gh * 0.5 - 0.5
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1] - 1)
y1 = y1.clamp(min=0, max=max_shape[0] - 1)
x2 = x2.clamp(min=0, max=max_shape[1] - 1)
y2 = y2.clamp(min=0, max=max_shape[0] - 1)
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
return bboxes
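# Round-trip sketch (illustrative): encoding a gt box against a proposal
# and decoding the deltas recovers the gt box exactly.
p = torch.tensor([[0., 0., 9., 9.]])
g = torch.tensor([[2., 2., 11., 11.]])
deltas = bbox_transform(p, g)         # tensor([[0.2000, 0.2000, 0., 0.]])
print(bbox_transform_inv(p, deltas))  # tensor([[ 2.,  2., 11., 11.]])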
def bbox_flip(bboxes, img_shape):
"""Flip bboxes horizontally
Args:
bboxes(Tensor): shape (..., 4*k)
img_shape(Tensor): image shape
"""
if isinstance(bboxes, torch.Tensor):
assert bboxes.shape[-1] % 4 == 0
flipped = bboxes.clone()
flipped[:, 0::4] = img_shape[1] - bboxes[:, 2::4] - 1
flipped[:, 2::4] = img_shape[1] - bboxes[:, 0::4] - 1
return flipped
elif isinstance(bboxes, np.ndarray):
return mmcv.bbox_flip(bboxes, img_shape)
def bbox_mapping(bboxes, img_shape, flip):
"""Map bboxes from the original image scale to testing scale"""
new_bboxes = bboxes * img_shape[-1]
if flip:
new_bboxes = bbox_flip(new_bboxes, img_shape)
return new_bboxes
def bbox_mapping_back(bboxes, img_shape, flip):
"""Map bboxes from testing scale to original image scale"""
new_bboxes = bbox_flip(bboxes, img_shape) if flip else bboxes
new_bboxes = new_bboxes / img_shape[-1]
return new_bboxes
def bbox2roi(bbox_list):
"""Convert a list of bboxes to roi format.
Args:
bbox_list (list[Tensor]): a list of per-image bboxes
Returns:
Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
"""
rois_list = []
for img_id, bboxes in enumerate(bbox_list):
if bboxes.size(0) > 0:
img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)
else:
rois = bboxes.new_zeros((0, 5))
rois_list.append(rois)
rois = torch.cat(rois_list, 0)
return rois
def roi2bbox(rois):
bbox_list = []
img_ids = torch.unique(rois[:, 0].cpu(), sorted=True)
for img_id in img_ids:
inds = (rois[:, 0] == img_id.item())
bbox = rois[inds, 1:]
bbox_list.append(bbox)
return bbox_list
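# Illustrative round trip for a batch of two images with one box each: the
# first roi column is the image index within the batch.
bbox_list = [torch.tensor([[0., 0., 9., 9.]]),
             torch.tensor([[5., 5., 14., 14.]])]
rois = bbox2roi(bbox_list)
print(rois)            # tensor([[ 0., 0., 0., 9., 9.], [ 1., 5., 5., 14., 14.]])
print(roi2bbox(rois))  # recovers the two per-image tensors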
from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes, coco_classes, dataset_aliases,
get_classes)
from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
plot_iou_recall)
__all__ = [
'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
'coco_classes', 'dataset_aliases', 'get_classes', 'average_precision',
'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall'
]
import numpy as np
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
"""Calculate the ious between each bbox of bboxes1 and bboxes2.
Args:
bboxes1(ndarray): shape (n, 4)
bboxes2(ndarray): shape (k, 4)
mode(str): iou (intersection over union) or iof (intersection
over foreground)
Returns:
ious(ndarray): shape (n, k)
"""
assert mode in ['iou', 'iof']
bboxes1 = bboxes1.astype(np.float32)
bboxes2 = bboxes2.astype(np.float32)
rows = bboxes1.shape[0]
cols = bboxes2.shape[0]
ious = np.zeros((rows, cols), dtype=np.float32)
if rows * cols == 0:
return ious
exchange = False
if bboxes1.shape[0] > bboxes2.shape[0]:
bboxes1, bboxes2 = bboxes2, bboxes1
ious = np.zeros((cols, rows), dtype=np.float32)
exchange = True
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
bboxes1[:, 3] - bboxes1[:, 1] + 1)
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (
bboxes2[:, 3] - bboxes2[:, 1] + 1)
for i in range(bboxes1.shape[0]):
x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
y_end - y_start + 1, 0)
if mode == 'iou':
union = area1[i] + area2 - overlap
else:
union = area1[i] if not exchange else area2
ious[i, :] = overlap / union
if exchange:
ious = ious.T
return ious
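# Illustrative 'iof' check: the 5x5 = 25 intersection is divided by the area
# of the bboxes1 box (10x10 = 100) instead of the union.
b1 = np.array([[0, 0, 9, 9]])
b2 = np.array([[5, 5, 14, 14]])
print(bbox_overlaps(b1, b2, mode='iof'))  # [[0.25]]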
import mmcv
def voc_classes():
return [
'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
def imagenet_det_classes():
return [
'accordion', 'airplane', 'ant', 'antelope', 'apple', 'armadillo',
'artichoke', 'axe', 'baby_bed', 'backpack', 'bagel', 'balance_beam',
'banana', 'band_aid', 'banjo', 'baseball', 'basketball', 'bathing_cap',
'beaker', 'bear', 'bee', 'bell_pepper', 'bench', 'bicycle', 'binder',
'bird', 'bookshelf', 'bow_tie', 'bow', 'bowl', 'brassiere', 'burrito',
'bus', 'butterfly', 'camel', 'can_opener', 'car', 'cart', 'cattle',
'cello', 'centipede', 'chain_saw', 'chair', 'chime', 'cocktail_shaker',
'coffee_maker', 'computer_keyboard', 'computer_mouse', 'corkscrew',
'cream', 'croquet_ball', 'crutch', 'cucumber', 'cup_or_mug', 'diaper',
'digital_clock', 'dishwasher', 'dog', 'domestic_cat', 'dragonfly',
'drum', 'dumbbell', 'electric_fan', 'elephant', 'face_powder', 'fig',
'filing_cabinet', 'flower_pot', 'flute', 'fox', 'french_horn', 'frog',
'frying_pan', 'giant_panda', 'goldfish', 'golf_ball', 'golfcart',
'guacamole', 'guitar', 'hair_dryer', 'hair_spray', 'hamburger',
'hammer', 'hamster', 'harmonica', 'harp', 'hat_with_a_wide_brim',
'head_cabbage', 'helmet', 'hippopotamus', 'horizontal_bar', 'horse',
'hotdog', 'iPod', 'isopod', 'jellyfish', 'koala_bear', 'ladle',
'ladybug', 'lamp', 'laptop', 'lemon', 'lion', 'lipstick', 'lizard',
'lobster', 'maillot', 'maraca', 'microphone', 'microwave', 'milk_can',
'miniskirt', 'monkey', 'motorcycle', 'mushroom', 'nail', 'neck_brace',
'oboe', 'orange', 'otter', 'pencil_box', 'pencil_sharpener', 'perfume',
'person', 'piano', 'pineapple', 'ping-pong_ball', 'pitcher', 'pizza',
'plastic_bag', 'plate_rack', 'pomegranate', 'popsicle', 'porcupine',
'power_drill', 'pretzel', 'printer', 'puck', 'punching_bag', 'purse',
'rabbit', 'racket', 'ray', 'red_panda', 'refrigerator',
'remote_control', 'rubber_eraser', 'rugby_ball', 'ruler',
'salt_or_pepper_shaker', 'saxophone', 'scorpion', 'screwdriver',
'seal', 'sheep', 'ski', 'skunk', 'snail', 'snake', 'snowmobile',
'snowplow', 'soap_dispenser', 'soccer_ball', 'sofa', 'spatula',
'squirrel', 'starfish', 'stethoscope', 'stove', 'strainer',
'strawberry', 'stretcher', 'sunglasses', 'swimming_trunks', 'swine',
'syringe', 'table', 'tape_player', 'tennis_ball', 'tick', 'tie',
'tiger', 'toaster', 'traffic_light', 'train', 'trombone', 'trumpet',
'turtle', 'tv_or_monitor', 'unicycle', 'vacuum', 'violin',
'volleyball', 'waffle_iron', 'washer', 'water_bottle', 'watercraft',
'whale', 'wine_bottle', 'zebra'
]
def imagenet_vid_classes():
return [
'airplane', 'antelope', 'bear', 'bicycle', 'bird', 'bus', 'car',
'cattle', 'dog', 'domestic_cat', 'elephant', 'fox', 'giant_panda',
'hamster', 'horse', 'lion', 'lizard', 'monkey', 'motorcycle', 'rabbit',
'red_panda', 'sheep', 'snake', 'squirrel', 'tiger', 'train', 'turtle',
'watercraft', 'whale', 'zebra'
]
def coco_classes():
return [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
dataset_aliases = {
'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
'coco': ['coco', 'mscoco', 'ms_coco']
}
def get_classes(dataset):
"""Get class names of a dataset."""
alias2name = {}
for name, aliases in dataset_aliases.items():
for alias in aliases:
alias2name[alias] = name
if mmcv.is_str(dataset):
if dataset in alias2name:
labels = eval(alias2name[dataset] + '_classes()')
else:
raise ValueError('Unrecognized dataset: {}'.format(dataset))
else:
raise TypeError('dataset must be a str, but got {}'.format(type(dataset)))
return labels
import numpy as np
from terminaltables import AsciiTable
from .bbox_overlaps import bbox_overlaps
from .class_names import get_classes
def average_precision(recalls, precisions, mode='area'):
"""Calculate average precision (for single or multiple scales).
Args:
recalls(ndarray): shape (num_scales, num_dets) or (num_dets, )
precisions(ndarray): shape (num_scales, num_dets) or (num_dets, )
mode(str): 'area' or '11points', 'area' means calculating the area
under precision-recall curve, '11points' means calculating
the average precision of recalls at [0, 0.1, ..., 1]
Returns:
float or ndarray: calculated average precision
"""
no_scale = False
if recalls.ndim == 1:
no_scale = True
recalls = recalls[np.newaxis, :]
precisions = precisions[np.newaxis, :]
assert recalls.shape == precisions.shape and recalls.ndim == 2
num_scales = recalls.shape[0]
ap = np.zeros(num_scales, dtype=np.float32)
if mode == 'area':
zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
ones = np.ones((num_scales, 1), dtype=recalls.dtype)
mrec = np.hstack((zeros, recalls, ones))
mpre = np.hstack((zeros, precisions, zeros))
for i in range(mpre.shape[1] - 1, 0, -1):
mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])
for i in range(num_scales):
ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0]
ap[i] = np.sum(
(mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1])
elif mode == '11points':
for i in range(num_scales):
for thr in np.arange(0, 1 + 1e-3, 0.1):
precs = precisions[i, recalls[i, :] >= thr]
prec = precs.max() if precs.size > 0 else 0
ap[i] += prec
ap /= 11
else:
raise ValueError(
'Unrecognized mode, only "area" and "11points" are supported')
if no_scale:
ap = ap[0]
return ap
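# Worked example (illustrative, mode='area'): the precision envelope over
# mrec = [0, 0.5, 1.0, 1.0] is mpre = [1.0, 1.0, 0.5, 0.0], so
# AP = 0.5 * 1.0 + 0.5 * 0.5 = 0.75.
recalls = np.array([0.5, 1.0])
precisions = np.array([1.0, 0.5])
print(average_precision(recalls, precisions))  # 0.75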
def tpfp_imagenet(det_bboxes,
gt_bboxes,
gt_ignore,
default_iou_thr,
area_ranges=None):
"""Check if detected bboxes are true positive or false positive.
Args:
det_bboxes(ndarray): detected bboxes of this image, shape (m, 5)
gt_bboxes(ndarray): ground truth bboxes of this image, shape (n, 4)
gt_ignore(ndarray): indicates whether each gt is ignored for evaluation
default_iou_thr(float): the iou threshold for medium and large bboxes
area_ranges(list or None): gt bbox area ranges
Returns:
tuple: two arrays (tp, fp) whose elements are 0 and 1
"""
num_dets = det_bboxes.shape[0]
num_gts = gt_bboxes.shape[0]
if area_ranges is None:
area_ranges = [(None, None)]
num_scales = len(area_ranges)
# tp and fp are of shape (num_scales, num_dets), each row is tp or fp
# of a certain scale.
tp = np.zeros((num_scales, num_dets), dtype=np.float32)
fp = np.zeros((num_scales, num_dets), dtype=np.float32)
if gt_bboxes.shape[0] == 0:
if area_ranges == [(None, None)]:
fp[...] = 1
else:
det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * (
det_bboxes[:, 3] - det_bboxes[:, 1] + 1)
for i, (min_area, max_area) in enumerate(area_ranges):
fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
return tp, fp
ious = bbox_overlaps(det_bboxes, gt_bboxes - 1)
gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1
gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1
iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)),
default_iou_thr)
# sort all detections by scores in descending order
sort_inds = np.argsort(-det_bboxes[:, -1])
for k, (min_area, max_area) in enumerate(area_ranges):
gt_covered = np.zeros(num_gts, dtype=bool)
# if no area range is specified, gt_area_ignore is all False
if min_area is None:
gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool)
else:
gt_areas = gt_w * gt_h
gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
for i in sort_inds:
max_iou = -1
matched_gt = -1
# find best overlapped available gt
for j in range(num_gts):
# different from PASCAL VOC: allow finding other gts if the
# best overlapped ones are already matched by other det bboxes
if gt_covered[j]:
continue
elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou:
max_iou = ious[i, j]
matched_gt = j
# there are 4 cases for a det bbox:
# 1. this det bbox matches a gt, tp = 1, fp = 0
# 2. this det bbox matches an ignored gt, tp = 0, fp = 0
# 3. this det bbox matches no gt and within area range, tp = 0, fp = 1
# 4. this det bbox matches no gt but is beyond area range, tp = 0, fp = 0
if matched_gt >= 0:
gt_covered[matched_gt] = 1
if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]):
tp[k, i] = 1
elif min_area is None:
fp[k, i] = 1
else:
bbox = det_bboxes[i, :4]
area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1)
if area >= min_area and area < max_area:
fp[k, i] = 1
return tp, fp
def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
"""Check if detected bboxes are true positive or false positive.
Args:
det_bboxes(ndarray): detected bboxes of this image, shape (m, 5)
gt_bboxes(ndarray): ground truth bboxes of this image, shape (n, 4)
gt_ignore(ndarray): indicates whether each gt is ignored for evaluation
iou_thr(float): the iou threshold
Returns:
tuple: (tp, fp), two arrays whose elements are 0 and 1
"""
num_dets = det_bboxes.shape[0]
num_gts = gt_bboxes.shape[0]
if area_ranges is None:
area_ranges = [(None, None)]
num_scales = len(area_ranges)
# tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
# a certain scale
tp = np.zeros((num_scales, num_dets), dtype=np.float32)
fp = np.zeros((num_scales, num_dets), dtype=np.float32)
# if there is no gt bboxes in this image, then all det bboxes
# within area range are false positives
if gt_bboxes.shape[0] == 0:
if area_ranges == [(None, None)]:
fp[...] = 1
else:
det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * (
det_bboxes[:, 3] - det_bboxes[:, 1] + 1)
for i, (min_area, max_area) in enumerate(area_ranges):
fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
return tp, fp
ious = bbox_overlaps(det_bboxes, gt_bboxes)
ious_max = ious.max(axis=1)
ious_argmax = ious.argmax(axis=1)
sort_inds = np.argsort(-det_bboxes[:, -1])
for k, (min_area, max_area) in enumerate(area_ranges):
gt_covered = np.zeros(num_gts, dtype=bool)
# if no area range is specified, gt_area_ignore is all False
if min_area is None:
gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool)
else:
gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
for i in sort_inds:
if ious_max[i] >= iou_thr:
matched_gt = ious_argmax[i]
if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]):
if not gt_covered[matched_gt]:
gt_covered[matched_gt] = True
tp[k, i] = 1
else:
fp[k, i] = 1
# otherwise ignore this detected bbox, tp = 0, fp = 0
elif min_area is None:
fp[k, i] = 1
else:
bbox = det_bboxes[i, :4]
area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1)
if area >= min_area and area < max_area:
fp[k, i] = 1
return tp, fp
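# Minimal check (illustrative): a single detection that exactly matches the
# only gt box counts as a true positive.
det = np.array([[0., 0., 9., 9., 0.9]])  # x1, y1, x2, y2, score
gt = np.array([[0., 0., 9., 9.]])
tp, fp = tpfp_default(det, gt, np.zeros(1, dtype=bool), iou_thr=0.5)
print(tp, fp)  # [[1.]] [[0.]]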
def get_cls_results(det_results, gt_bboxes, gt_labels, gt_ignore, class_id):
"""Get det results and gt information of a certain class."""
cls_dets = [det[class_id]
for det in det_results] # det bboxes of this class
cls_gts = [] # gt bboxes of this class
cls_gt_ignore = []
for j in range(len(gt_bboxes)):
gt_bbox = gt_bboxes[j]
cls_inds = (gt_labels[j] == class_id + 1)
cls_gt = gt_bbox[cls_inds, :] if gt_bbox.shape[0] > 0 else gt_bbox
cls_gts.append(cls_gt)
if gt_ignore is None:
cls_gt_ignore.append(np.zeros(cls_gt.shape[0], dtype=np.int32))
else:
cls_gt_ignore.append(gt_ignore[j][cls_inds])
return cls_dets, cls_gts, cls_gt_ignore
def eval_map(det_results,
gt_bboxes,
gt_labels,
gt_ignore=None,
scale_ranges=None,
iou_thr=0.5,
dataset=None,
print_summary=True):
"""Evaluate mAP of a dataset.
Args:
det_results(list): a list of list, [[cls1_det, cls2_det, ...], ...]
gt_bboxes(list): ground truth bboxes of each image, a list of K*4 array
gt_labels(list): ground truth labels of each image, a list of K array
gt_ignore(list): gt ignore indicators of each image, a list of K array
scale_ranges(list, optional): [(min1, max1), (min2, max2), ...]
iou_thr(float): IoU threshold
dataset(None or str): dataset name, there are minor differences in
metrics for different datasets, e.g. "voc07", "imagenet_det", etc.
print_summary(bool): whether to print the mAP summary
Returns:
tuple: (mAP, [dict, dict, ...])
"""
assert len(det_results) == len(gt_bboxes) == len(gt_labels)
if gt_ignore is not None:
assert len(gt_ignore) == len(gt_labels)
for i in range(len(gt_ignore)):
assert len(gt_labels[i]) == len(gt_ignore[i])
area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
if scale_ranges is not None else None)
num_scales = len(scale_ranges) if scale_ranges is not None else 1
eval_results = []
num_classes = len(det_results[0]) # positive class num
gt_labels = [
label if label.ndim == 1 else label[:, 0] for label in gt_labels
]
for i in range(num_classes):
# get gt and det bboxes of this class
cls_dets, cls_gts, cls_gt_ignore = get_cls_results(
det_results, gt_bboxes, gt_labels, gt_ignore, i)
# calculate tp and fp for each image
tpfp_func = (tpfp_imagenet
if dataset in ['det', 'vid'] else tpfp_default)
tpfp = [
tpfp_func(cls_dets[j], cls_gts[j], cls_gt_ignore[j], iou_thr,
area_ranges) for j in range(len(cls_dets))
]
tp, fp = tuple(zip(*tpfp))
# calculate gt number of each scale, gts ignored or beyond scale are not counted
num_gts = np.zeros(num_scales, dtype=int)
for j, bbox in enumerate(cls_gts):
if area_ranges is None:
num_gts[0] += np.sum(np.logical_not(cls_gt_ignore[j]))
else:
gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * (
bbox[:, 3] - bbox[:, 1] + 1)
for k, (min_area, max_area) in enumerate(area_ranges):
num_gts[k] += np.sum(
np.logical_not(cls_gt_ignore[j]) &
(gt_areas >= min_area) & (gt_areas < max_area))
# sort all det bboxes by score, also sort tp and fp
cls_dets = np.vstack(cls_dets)
num_dets = cls_dets.shape[0]
sort_inds = np.argsort(-cls_dets[:, -1])
tp = np.hstack(tp)[:, sort_inds]
fp = np.hstack(fp)[:, sort_inds]
# calculate recall and precision with tp and fp
tp = np.cumsum(tp, axis=1)
fp = np.cumsum(fp, axis=1)
eps = np.finfo(np.float32).eps
recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
precisions = tp / np.maximum((tp + fp), eps)
# calculate AP
if scale_ranges is None:
recalls = recalls[0, :]
precisions = precisions[0, :]
num_gts = num_gts.item()
mode = 'area' if dataset != 'voc07' else '11points'
ap = average_precision(recalls, precisions, mode)
eval_results.append({
'num_gts': num_gts,
'num_dets': num_dets,
'recall': recalls,
'precision': precisions,
'ap': ap
})
if scale_ranges is not None:
# shape (num_classes, num_scales)
all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
all_num_gts = np.vstack(
[cls_result['num_gts'] for cls_result in eval_results])
mean_ap = [
all_ap[all_num_gts[:, i] > 0, i].mean()
if np.any(all_num_gts[:, i] > 0) else 0.0
for i in range(num_scales)
]
else:
aps = []
for cls_result in eval_results:
if cls_result['num_gts'] > 0:
aps.append(cls_result['ap'])
mean_ap = np.array(aps).mean().item() if aps else 0.0
if print_summary:
print_map_summary(mean_ap, eval_results, dataset)
return mean_ap, eval_results
def print_map_summary(mean_ap, results, dataset=None):
"""Print mAP and results of each class.
Args:
mean_ap(float): calculated from `eval_map`
results(list): calculated from `eval_map`
dataset(None or str or list): dataset name.
"""
num_scales = len(results[0]['ap']) if isinstance(results[0]['ap'],
np.ndarray) else 1
num_classes = len(results)
recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
precisions = np.zeros((num_scales, num_classes), dtype=np.float32)
aps = np.zeros((num_scales, num_classes), dtype=np.float32)
num_gts = np.zeros((num_scales, num_classes), dtype=int)
for i, cls_result in enumerate(results):
if cls_result['recall'].size > 0:
recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
precisions[:, i] = np.array(
cls_result['precision'], ndmin=2)[:, -1]
aps[:, i] = cls_result['ap']
num_gts[:, i] = cls_result['num_gts']
if dataset is None:
label_names = [str(i) for i in range(1, num_classes + 1)]
else:
label_names = get_classes(dataset)
if not isinstance(mean_ap, list):
mean_ap = [mean_ap]
header = ['class', 'gts', 'dets', 'recall', 'precision', 'ap']
for i in range(num_scales):
table_data = [header]
for j in range(num_classes):
row_data = [
label_names[j], num_gts[i, j], results[j]['num_dets'],
'{:.3f}'.format(recalls[i, j]), '{:.3f}'.format(
precisions[i, j]), '{:.3f}'.format(aps[i, j])
]
table_data.append(row_data)
table_data.append(['mAP', '', '', '', '', '{:.3f}'.format(mean_ap[i])])
table = AsciiTable(table_data)
table.inner_footing_row_border = True
print(table.table)
import numpy as np
from terminaltables import AsciiTable
from .bbox_overlaps import bbox_overlaps
def _recalls(all_ious, proposal_nums, thrs):
img_num = all_ious.shape[0]
total_gt_num = sum([ious.shape[0] for ious in all_ious])
_ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32)
for k, proposal_num in enumerate(proposal_nums):
tmp_ious = np.zeros(0)
for i in range(img_num):
ious = all_ious[i][:, :proposal_num].copy()
gt_ious = np.zeros((ious.shape[0]))
if ious.size == 0:
tmp_ious = np.hstack((tmp_ious, gt_ious))
continue
for j in range(ious.shape[0]):
gt_max_overlaps = ious.argmax(axis=1)
max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps]
gt_idx = max_ious.argmax()
gt_ious[j] = max_ious[gt_idx]
box_idx = gt_max_overlaps[gt_idx]
ious[gt_idx, :] = -1
ious[:, box_idx] = -1
tmp_ious = np.hstack((tmp_ious, gt_ious))
_ious[k, :] = tmp_ious
_ious = np.fliplr(np.sort(_ious, axis=1))
recalls = np.zeros((proposal_nums.size, thrs.size))
for i, thr in enumerate(thrs):
recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num)
return recalls
def set_recall_param(proposal_nums, iou_thrs):
"""Check proposal_nums and iou_thrs and set correct format.
"""
if isinstance(proposal_nums, list):
_proposal_nums = np.array(proposal_nums)
elif isinstance(proposal_nums, int):
_proposal_nums = np.array([proposal_nums])
else:
_proposal_nums = proposal_nums
if iou_thrs is None:
_iou_thrs = np.array([0.5])
elif isinstance(iou_thrs, list):
_iou_thrs = np.array(iou_thrs)
elif isinstance(iou_thrs, float):
_iou_thrs = np.array([iou_thrs])
else:
_iou_thrs = iou_thrs
return _proposal_nums, _iou_thrs
def eval_recalls(gts,
proposals,
proposal_nums=None,
iou_thrs=None,
print_summary=True):
"""Calculate recalls.
Args:
gts(list or ndarray): a list of arrays of shape (n, 4)
proposals(list or ndarray): a list of arrays of shape (k, 4) or (k, 5)
proposal_nums(int or list of int or ndarray): top N proposals
iou_thrs(float or list or ndarray): iou thresholds
Returns:
ndarray: recalls of different ious and proposal nums
"""
img_num = len(gts)
assert img_num == len(proposals)
proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs)
all_ious = []
for i in range(img_num):
if proposals[i].ndim == 2 and proposals[i].shape[1] == 5:
scores = proposals[i][:, 4]
sort_idx = np.argsort(scores)[::-1]
img_proposal = proposals[i][sort_idx, :]
else:
img_proposal = proposals[i]
prop_num = min(img_proposal.shape[0], proposal_nums[-1])
if gts[i] is None or gts[i].shape[0] == 0:
ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32)
else:
ious = bbox_overlaps(gts[i], img_proposal[:prop_num, :4])
all_ious.append(ious)
all_ious = np.array(all_ious)
recalls = _recalls(all_ious, proposal_nums, iou_thrs)
if print_summary:
print_recall_summary(recalls, proposal_nums, iou_thrs)
return recalls
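# Minimal check (illustrative): one image whose single proposal matches the
# gt box perfectly gives recall 1.0 at iou 0.5.
gts = [np.array([[0., 0., 9., 9.]])]
proposals = [np.array([[0., 0., 9., 9., 1.0]])]
print(eval_recalls(gts, proposals, proposal_nums=1, iou_thrs=0.5,
                   print_summary=False))  # [[1.]]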
def print_recall_summary(recalls,
proposal_nums,
iou_thrs,
row_idxs=None,
col_idxs=None):
"""Print recalls in a table.
Args:
recalls(ndarray): calculated from `eval_recalls`
proposal_nums(ndarray or list): top N proposals
iou_thrs(ndarray or list): iou thresholds
row_idxs(ndarray): which rows(proposal nums) to print
col_idxs(ndarray): which cols(iou thresholds) to print
"""
proposal_nums = np.array(proposal_nums, dtype=np.int32)
iou_thrs = np.array(iou_thrs)
if row_idxs is None:
row_idxs = np.arange(proposal_nums.size)
if col_idxs is None:
col_idxs = np.arange(iou_thrs.size)
row_header = [''] + iou_thrs[col_idxs].tolist()
table_data = [row_header]
for i, num in enumerate(proposal_nums[row_idxs]):
row = [
'{:.3f}'.format(val)
for val in recalls[row_idxs[i], col_idxs].tolist()
]
row.insert(0, num)
table_data.append(row)
table = AsciiTable(table_data)
print(table.table)
def plot_num_recall(recalls, proposal_nums):
"""Plot Proposal_num-Recalls curve.
Args:
recalls(ndarray or list): shape (k,)
proposal_nums(ndarray or list): same shape as `recalls`
"""
if isinstance(proposal_nums, np.ndarray):
_proposal_nums = proposal_nums.tolist()
else:
_proposal_nums = proposal_nums
if isinstance(recalls, np.ndarray):
_recalls = recalls.tolist()
else:
_recalls = recalls
import matplotlib.pyplot as plt
f = plt.figure()
plt.plot([0] + _proposal_nums, [0] + _recalls)
plt.xlabel('Proposal num')
plt.ylabel('Recall')
plt.axis([0, proposal_nums.max(), 0, 1])
f.show()
def plot_iou_recall(recalls, iou_thrs):
"""Plot IoU-Recalls curve.
Args:
recalls(ndarray or list): shape (k,)
iou_thrs(ndarray or list): same shape as `recalls`
"""
if isinstance(iou_thrs, np.ndarray):
_iou_thrs = iou_thrs.tolist()
else:
_iou_thrs = iou_thrs
if isinstance(recalls, np.ndarray):
_recalls = recalls.tolist()
else:
_recalls = recalls
import matplotlib.pyplot as plt
f = plt.figure()
plt.plot(_iou_thrs + [1.0], _recalls + [0.])
plt.xlabel('IoU')
plt.ylabel('Recall')
plt.axis([iou_thrs.min(), 1, 0, 1])
f.show()
import os
import os.path as osp
import shutil
import time
import mmcv
import numpy as np
import torch
from mmcv.torchpack import Hook
from mmdet import collate, scatter
from pycocotools.cocoeval import COCOeval
from .eval import eval_recalls
class EmptyCacheHook(Hook):
def before_epoch(self, runner):
torch.cuda.empty_cache()
def after_epoch(self, runner):
torch.cuda.empty_cache()
class DistEvalHook(Hook):
def __init__(self, dataset, interval=1):
self.dataset = dataset
self.interval = interval
self.lock_dir = None
def _barrier(self, rank, world_size):
"""Due to some issues with `torch.distributed.barrier()`, we have to
implement this ugly barrier function.
"""
if rank == 0:
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
while not (osp.exists(tmp)):
time.sleep(1)
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
os.remove(tmp)
else:
tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank))
mmcv.dump([], tmp)
while osp.exists(tmp):
time.sleep(1)
def before_run(self, runner):
self.lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
if runner.rank == 0:
if osp.exists(self.lock_dir):
shutil.rmtree(self.lock_dir)
mmcv.mkdir_or_exist(self.lock_dir)
def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
return
runner.model.eval()
results = [None for _ in range(len(self.dataset))]
prog_bar = mmcv.ProgressBar(len(self.dataset))
for idx in range(runner.rank, len(self.dataset), runner.world_size):
data = self.dataset[idx]
device_id = torch.cuda.current_device()
imgs_data = tuple(
scatter(collate([data], samples_per_gpu=1), [device_id])[0])
# compute output
with torch.no_grad():
result = runner.model(
*imgs_data,
return_loss=False,
return_bboxes=True,
rescale=True)
results[idx] = result
batch_size = runner.world_size
for _ in range(batch_size):
prog_bar.update()
if runner.rank == 0:
print('\n')
self._barrier(runner.rank, runner.world_size)
for i in range(1, runner.world_size):
tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
tmp_results = mmcv.load(tmp_file)
for idx in range(i, len(results), runner.world_size):
results[idx] = tmp_results[idx]
os.remove(tmp_file)
self.evaluate(runner, results)
else:
tmp_file = osp.join(runner.work_dir,
'temp_{}.pkl'.format(runner.rank))
mmcv.dump(results, tmp_file)
self._barrier(runner.rank, runner.world_size)
self._barrier(runner.rank, runner.world_size)
def evaluate(self):
raise NotImplementedError
class CocoEvalMixin(object):
def _xyxy2xywh(self, bbox):
_bbox = bbox.tolist()
return [
_bbox[0],
_bbox[1],
_bbox[2] - _bbox[0] + 1,
_bbox[3] - _bbox[1] + 1,
]
def det2json(self, dataset, results):
json_results = []
for idx in range(len(dataset)):
img_id = dataset.img_ids[idx]
result = results[idx]
for label in range(len(result)):
bboxes = result[label]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self._xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = dataset.cat_ids[label]
json_results.append(data)
return json_results
def segm2json(self, dataset, results):
json_results = []
for idx in range(len(dataset)):
img_id = dataset.img_ids[idx]
det, seg = results[idx]
for label in range(len(det)):
bboxes = det[label]
segms = seg[label]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self._xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = dataset.cat_ids[label]
segms[i]['counts'] = segms[i]['counts'].decode()
data['segmentation'] = segms[i]
json_results.append(data)
return json_results
def proposal2json(self, dataset, results):
json_results = []
for idx in range(len(dataset)):
img_id = dataset.img_ids[idx]
bboxes = results[idx]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self._xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = 1
json_results.append(data)
return json_results
def results2json(self, dataset, results, out_file):
if isinstance(results[0], list):
json_results = self.det2json(dataset, results)
elif isinstance(results[0], tuple):
json_results = self.segm2json(dataset, results)
elif isinstance(results[0], np.ndarray):
json_results = self.proposal2json(dataset, results)
else:
raise TypeError('invalid type of results')
mmcv.dump(json_results, out_file, file_format='json')
class DistEvalRecallHook(DistEvalHook):
def __init__(self,
dataset,
proposal_nums=(100, 300, 1000),
iou_thrs=np.arange(0.5, 0.96, 0.05)):
super(DistEvalRecallHook, self).__init__(dataset)
self.proposal_nums = np.array(proposal_nums, dtype=np.int32)
self.iou_thrs = np.array(iou_thrs, dtype=np.float32)
def evaluate(self, runner, results):
# official coco evaluation is too slow, here we use our own
# implementation, which may get slightly different results
gt_bboxes = []
for i in range(len(self.dataset)):
img_id = self.dataset.img_ids[i]
ann_ids = self.dataset.coco.getAnnIds(imgIds=img_id)
ann_info = self.dataset.coco.loadAnns(ann_ids)
if len(ann_info) == 0:
gt_bboxes.append(np.zeros((0, 4)))
continue
bboxes = []
for ann in ann_info:
if ann.get('ignore', False) or ann['iscrowd']:
continue
x1, y1, w, h = ann['bbox']
bboxes.append([x1, y1, x1 + w - 1, y1 + h - 1])
bboxes = np.array(bboxes, dtype=np.float32)
if bboxes.shape[0] == 0:
bboxes = np.zeros((0, 4))
gt_bboxes.append(bboxes)
recalls = eval_recalls(
gt_bboxes,
results,
self.proposal_nums,
self.iou_thrs,
print_summary=False)
ar = recalls.mean(axis=1)
for i, num in enumerate(self.proposal_nums):
runner.log_buffer.output['AR@{}'.format(num)] = ar[i]
runner.log_buffer.ready = True
class CocoDistEvalmAPHook(DistEvalHook, CocoEvalMixin):
def evaluate(self, runner, results):
tmp_file = osp.join(runner.work_dir, 'temp_0.json')
self.results2json(self.dataset, results, tmp_file)
res_types = ['bbox', 'segm'] if runner.model.with_mask else ['bbox']
cocoGt = self.dataset.coco
cocoDt = cocoGt.loadRes(tmp_file)
imgIds = cocoGt.getImgIds()
for res_type in res_types:
iou_type = res_type
cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
cocoEval.params.imgIds = imgIds
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
field = '{}_mAP'.format(res_type)
runner.log_buffer.output[field] = cocoEval.stats[0]
runner.log_buffer.ready = True
os.remove(tmp_file)
class CocoDistCascadeEvalmAPHook(CocoDistEvalmAPHook):
def evaluate(self, runner, results):
results = [res[-1] for res in results]
super(CocoDistCascadeEvalmAPHook, self).evaluate(runner, results)
from .segms import (flip_segms, polys_to_mask, mask_to_bbox,
polys_to_mask_wrt_box, polys_to_boxes, rle_mask_voting,
rle_mask_nms, rle_masks_to_boxes)
from .utils import split_combined_gt_polys
__all__ = [
'flip_segms', 'polys_to_mask', 'mask_to_bbox', 'polys_to_mask_wrt_box',
'polys_to_boxes', 'rle_mask_voting', 'rle_mask_nms', 'rle_masks_to_boxes',
'split_combined_gt_polys'
]
# This file is copied from Detectron.
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Functions for interacting with segmentation masks in the COCO format.
The following terms are used in this module
mask: a binary mask encoded as a 2D numpy array
segm: a segmentation mask in one of the two COCO formats (polygon or RLE)
polygon: COCO's polygon format
RLE: COCO's run length encoding format
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import pycocotools.mask as mask_util
def flip_segms(segms, height, width):
"""Left/right flip each mask in a list of masks."""
def _flip_poly(poly, width):
flipped_poly = np.array(poly)
flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
return flipped_poly.tolist()
def _flip_rle(rle, height, width):
if 'counts' in rle and type(rle['counts']) == list:
# Magic RLE format handling painfully discovered by looking at the
# COCO API showAnns function.
rle = mask_util.frPyObjects([rle], height, width)
mask = mask_util.decode(rle)
mask = mask[:, ::-1, :]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
flipped_segms = []
for segm in segms:
if type(segm) == list:
# Polygon format
flipped_segms.append([_flip_poly(poly, width) for poly in segm])
else:
# RLE format
assert type(segm) == dict
flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms
def polys_to_mask(polygons, height, width):
"""Convert from the COCO polygon segmentation format to a binary mask
encoded as a 2D array of data type numpy.float32. The polygon segmentation
is understood to be enclosed inside a height x width image. The resulting
mask is therefore of shape (height, width).
"""
rle = mask_util.frPyObjects(polygons, height, width)
mask = np.array(mask_util.decode(rle), dtype=np.float32)
# Flatten in case polygons was a list
mask = np.sum(mask, axis=2)
mask = np.array(mask > 0, dtype=np.float32)
return mask
def mask_to_bbox(mask):
"""Compute the tight bounding box of a binary mask."""
xs = np.where(np.sum(mask, axis=0) > 0)[0]
ys = np.where(np.sum(mask, axis=1) > 0)[0]
if len(xs) == 0 or len(ys) == 0:
return None
x0 = xs[0]
x1 = xs[-1]
y0 = ys[0]
y1 = ys[-1]
return np.array((x0, y0, x1, y1), dtype=np.float32)
def polys_to_mask_wrt_box(polygons, box, M):
"""Convert from the COCO polygon segmentation format to a binary mask
encoded as a 2D array of data type numpy.float32. The polygon segmentation
is understood to be enclosed in the given box and rasterized to an M x M
mask. The resulting mask is therefore of shape (M, M).
"""
w = box[2] - box[0]
h = box[3] - box[1]
w = np.maximum(w, 1)
h = np.maximum(h, 1)
polygons_norm = []
for poly in polygons:
p = np.array(poly, dtype=np.float32)
p[0::2] = (p[0::2] - box[0]) * M / w
p[1::2] = (p[1::2] - box[1]) * M / h
polygons_norm.append(p)
rle = mask_util.frPyObjects(polygons_norm, M, M)
mask = np.array(mask_util.decode(rle), dtype=np.float32)
# Flatten in case polygons was a list
mask = np.sum(mask, axis=2)
mask = np.array(mask > 0, dtype=np.float32)
return mask
def polys_to_boxes(polys):
"""Convert a list of polygons into an array of tight bounding boxes."""
boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32)
for i in range(len(polys)):
poly = polys[i]
x0 = min(min(p[::2]) for p in poly)
x1 = max(max(p[::2]) for p in poly)
y0 = min(min(p[1::2]) for p in poly)
y1 = max(max(p[1::2]) for p in poly)
boxes_from_polys[i, :] = [x0, y0, x1, y1]
return boxes_from_polys
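# Minimal check (illustrative): a mask made of one square polygon
# [x0, y0, x1, y0, x1, y1, x0, y1] yields the tight box [x0, y0, x1, y1].
polys = [[[0., 0., 4., 0., 4., 4., 0., 4.]]]  # one mask with one polygon
print(polys_to_boxes(polys))  # [[0. 0. 4. 4.]]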
def rle_mask_voting(top_masks,
all_masks,
all_dets,
iou_thresh,
binarize_thresh,
method='AVG'):
"""Returns new masks (in correspondence with `top_masks`) by combining
multiple overlapping masks coming from the pool of `all_masks`. Two methods
for combining masks are supported: 'AVG' uses a weighted average of
overlapping mask pixels; 'UNION' takes the union of all mask pixels.
"""
if len(top_masks) == 0:
return
all_not_crowd = [False] * len(all_masks)
top_to_all_overlaps = mask_util.iou(top_masks, all_masks, all_not_crowd)
decoded_all_masks = [
np.array(mask_util.decode(rle), dtype=np.float32) for rle in all_masks
]
decoded_top_masks = [
np.array(mask_util.decode(rle), dtype=np.float32) for rle in top_masks
]
all_boxes = all_dets[:, :4].astype(np.int32)
all_scores = all_dets[:, 4]
# Fill box support with weights
mask_shape = decoded_all_masks[0].shape
mask_weights = np.zeros((len(all_masks), mask_shape[0], mask_shape[1]))
for k in range(len(all_masks)):
ref_box = all_boxes[k]
x_0 = max(ref_box[0], 0)
x_1 = min(ref_box[2] + 1, mask_shape[1])
y_0 = max(ref_box[1], 0)
y_1 = min(ref_box[3] + 1, mask_shape[0])
mask_weights[k, y_0:y_1, x_0:x_1] = all_scores[k]
mask_weights = np.maximum(mask_weights, 1e-5)
top_segms_out = []
for k in range(len(top_masks)):
# Corner case of empty mask
if decoded_top_masks[k].sum() == 0:
top_segms_out.append(top_masks[k])
continue
inds_to_vote = np.where(top_to_all_overlaps[k] >= iou_thresh)[0]
# Only matches itself
if len(inds_to_vote) == 1:
top_segms_out.append(top_masks[k])
continue
masks_to_vote = [decoded_all_masks[i] for i in inds_to_vote]
if method == 'AVG':
ws = mask_weights[inds_to_vote]
soft_mask = np.average(masks_to_vote, axis=0, weights=ws)
mask = np.array(soft_mask > binarize_thresh, dtype=np.uint8)
elif method == 'UNION':
# Any pixel that's on joins the mask
soft_mask = np.sum(masks_to_vote, axis=0)
mask = np.array(soft_mask > 1e-5, dtype=np.uint8)
else:
raise NotImplementedError('Method {} is unknown'.format(method))
rle = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
top_segms_out.append(rle)
return top_segms_out
def rle_mask_nms(masks, dets, thresh, mode='IOU'):
"""Performs greedy non-maximum suppression based on an overlap measurement
between masks. The type of measurement is determined by `mode` and can be
either 'IOU' (standard intersection over union) or 'IOMA' (intersection over
minimum area).
"""
if len(masks) == 0:
return []
if len(masks) == 1:
return [0]
if mode == 'IOU':
# Computes ious[m1, m2] = area(intersect(m1, m2)) / area(union(m1, m2))
all_not_crowds = [False] * len(masks)
ious = mask_util.iou(masks, masks, all_not_crowds)
elif mode == 'IOMA':
# Computes ious[m1, m2] = area(intersect(m1, m2)) / min(area(m1), area(m2))
all_crowds = [True] * len(masks)
# ious[m1, m2] = area(intersect(m1, m2)) / area(m2)
ious = mask_util.iou(masks, masks, all_crowds)
# ... = max(area(intersect(m1, m2)) / area(m2),
# area(intersect(m2, m1)) / area(m1))
ious = np.maximum(ious, ious.transpose())
elif mode == 'CONTAINMENT':
# Computes ious[m1, m2] = area(intersect(m1, m2)) / area(m2)
# Which measures how much m2 is contained inside m1
all_crowds = [True] * len(masks)
ious = mask_util.iou(masks, masks, all_crowds)
else:
raise NotImplementedError('Mode {} is unknown'.format(mode))
scores = dets[:, 4]
order = np.argsort(-scores)
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
ovr = ious[i, order[1:]]
inds_to_keep = np.where(ovr <= thresh)[0]
order = order[inds_to_keep + 1]
return keep
def rle_masks_to_boxes(masks):
"""Computes the bounding box of each mask in a list of RLE encoded masks."""
if len(masks) == 0:
return []
decoded_masks = [
np.array(mask_util.decode(rle), dtype=np.float32) for rle in masks
]
def get_bounds(flat_mask):
inds = np.where(flat_mask > 0)[0]
return inds.min(), inds.max()
boxes = np.zeros((len(decoded_masks), 4))
keep = [True] * len(decoded_masks)
for i, mask in enumerate(decoded_masks):
if mask.sum() == 0:
keep[i] = False
continue
flat_mask = mask.sum(axis=0)
x0, x1 = get_bounds(flat_mask)
flat_mask = mask.sum(axis=1)
y0, y1 = get_bounds(flat_mask)
boxes[i, :] = (x0, y0, x1, y1)
return boxes, np.where(keep)[0]
import mmcv
def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
"""Split the combined 1-D polys into masks.
A mask is represented as a list of polys, and a poly is represented as
a 1-D array. In the dataset, all masks are concatenated into a single 1-D
tensor; here we split the tensor back into the original representation.
Args:
gt_polys (list): a list (length = image num) of 1-D tensors
gt_poly_lens (list): a list (length = image num) of poly lengths
num_polys_per_mask (list): a list (length = image num) of the number of
polys in each mask
Returns:
list: a list (length = image num) of list (length = mask num) of
list (length = poly num) of numpy array
"""
mask_polys_list = []
for img_id in range(len(gt_polys)):
gt_polys_single = gt_polys[img_id].cpu().numpy()
gt_polys_lens_single = gt_poly_lens[img_id].cpu().numpy().tolist()
num_polys_per_mask_single = num_polys_per_mask[
img_id].cpu().numpy().tolist()
split_gt_polys = mmcv.slice_list(gt_polys_single, gt_polys_lens_single)
mask_polys = mmcv.slice_list(split_gt_polys, num_polys_per_mask_single)
mask_polys_list.append(mask_polys)
return mask_polys_list
from .bbox_nms import multiclass_nms
from .merge_augs import (merge_aug_proposals, merge_aug_bboxes,
merge_aug_scores, merge_aug_masks)
__all__ = [
'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
'merge_aug_scores', 'merge_aug_masks'
]
import torch
from mmdet.ops import nms
def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
"""NMS for multi-class bboxes.
Args:
multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
multi_scores (Tensor): shape (n, #class)
score_thr (float): bbox threshold, bboxes with scores lower than it
will not be considered.
nms_thr (float): NMS IoU threshold
max_num (int): if there are more than max_num bboxes after NMS,
only top max_num will be kept.
Returns:
tuple: (bboxes, labels), tensors of shape (k, 5) and (k, ). Labels
are 0-based.
"""
num_classes = multi_scores.shape[1]
bboxes, labels = [], []
for i in range(1, num_classes):
cls_inds = multi_scores[:, i] > score_thr
if not cls_inds.any():
continue
# get bboxes and scores of this class
if multi_bboxes.shape[1] == 4:
_bboxes = multi_bboxes[cls_inds, :]
else:
_bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]
_scores = multi_scores[cls_inds, i]
cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
# perform nms
nms_keep = nms(cls_dets, nms_thr)
cls_dets = cls_dets[nms_keep, :]
cls_labels = multi_bboxes.new_full(
(len(nms_keep), ), i - 1, dtype=torch.long)
bboxes.append(cls_dets)
labels.append(cls_labels)
if bboxes:
bboxes = torch.cat(bboxes)
labels = torch.cat(labels)
if bboxes.shape[0] > max_num:
_, inds = bboxes[:, -1].sort(descending=True)
inds = inds[:max_num]
bboxes = bboxes[inds]
labels = labels[inds]
else:
bboxes = multi_bboxes.new_zeros((0, 5))
labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
return bboxes, labels
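# Usage sketch (illustrative; assumes the nms extension from compile.sh has
# been built and a CUDA device is available -- exact device support of the
# op may vary). Two overlapping boxes of the same class (iou ~= 0.68): the
# lower-scoring one is suppressed at nms_thr=0.5.
multi_bboxes = torch.tensor([[0., 0., 9., 9.], [1., 1., 10., 10.]]).cuda()
multi_scores = torch.tensor([[0.1, 0.9], [0.2, 0.8]]).cuda()  # column 0 = background
dets, labels = multiclass_nms(multi_bboxes, multi_scores, score_thr=0.05,
                              nms_thr=0.5, max_num=10)
print(dets.shape, labels)  # torch.Size([1, 5]) tensor([0], device='cuda:0')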