diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py index cd227ad0665ce705a79a3a5328d2fbba2155b114..e7926858ef5db381357423a0eb9e9f7dee94644b 100644 --- a/mmdet/core/anchor/anchor_generator.py +++ b/mmdet/core/anchor/anchor_generator.py @@ -8,10 +8,10 @@ class AnchorGenerator(object): >>> self = AnchorGenerator(9, [1.], [1.]) >>> all_anchors = self.grid_anchors((2, 2), device='cpu') >>> print(all_anchors) - tensor([[ 0., 0., 8., 8.], - [16., 0., 24., 8.], - [ 0., 16., 8., 24.], - [16., 16., 24., 24.]]) + tensor([[-4.5000, -4.5000, 4.5000, 4.5000], + [11.5000, -4.5000, 20.5000, 4.5000], + [-4.5000, 11.5000, 4.5000, 20.5000], + [11.5000, 11.5000, 20.5000, 20.5000]]) """ def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None): @@ -30,8 +30,8 @@ class AnchorGenerator(object): w = self.base_size h = self.base_size if self.ctr is None: - x_ctr = 0.5 * (w - 1) - y_ctr = 0.5 * (h - 1) + x_ctr = 0. + y_ctr = 0. else: x_ctr, y_ctr = self.ctr @@ -44,14 +44,13 @@ class AnchorGenerator(object): ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1) hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1) - # yapf: disable - base_anchors = torch.stack( - [ - x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), - x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) - ], - dim=-1).round() - # yapf: enable + # use float anchor and the anchor's center is aligned with the + # pixel center + base_anchors = [ + x_ctr - 0.5 * ws, y_ctr - 0.5 * hs, x_ctr + 0.5 * ws, + y_ctr + 0.5 * hs + ] + base_anchors = torch.stack(base_anchors, dim=-1) return base_anchors diff --git a/mmdet/core/anchor/guided_anchor_target.py b/mmdet/core/anchor/guided_anchor_target.py index 21162eb9e894993626d2f24db853ebe6681b392a..45d0177b089b350c2eb22637db817036ccf55962 100644 --- a/mmdet/core/anchor/guided_anchor_target.py +++ b/mmdet/core/anchor/guided_anchor_target.py @@ -22,10 +22,10 @@ def calc_region(bbox, ratio, featmap_size=None): x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long() y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long() if featmap_size is not None: - x1 = x1.clamp(min=0, max=featmap_size[1] - 1) - y1 = y1.clamp(min=0, max=featmap_size[0] - 1) - x2 = x2.clamp(min=0, max=featmap_size[1] - 1) - y2 = y2.clamp(min=0, max=featmap_size[0] - 1) + x1 = x1.clamp(min=0, max=featmap_size[1]) + y1 = y1.clamp(min=0, max=featmap_size[0]) + x2 = x2.clamp(min=0, max=featmap_size[1]) + y2 = y2.clamp(min=0, max=featmap_size[0]) return (x1, y1, x2, y2) @@ -76,8 +76,8 @@ def ga_loc_target(gt_bboxes_list, all_ignore_map.append(ignore_map) for img_id in range(img_per_gpu): gt_bboxes = gt_bboxes_list[img_id] - scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * - (gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)) + scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * + (gt_bboxes[:, 3] - gt_bboxes[:, 1])) min_anchor_size = scale.new_full( (1, ), float(anchor_scale * anchor_strides[0])) # assign gt bboxes to different feature levels w.r.t. 
their scales diff --git a/mmdet/core/bbox/geometry.py b/mmdet/core/bbox/geometry.py index ff7c5d4fa0ddd6417fdf6ed835884e2c38271624..6fd791ed6e60a3aa6a0de9bc36305806844f4314 100644 --- a/mmdet/core/bbox/geometry.py +++ b/mmdet/core/bbox/geometry.py @@ -30,8 +30,8 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): >>> [10, 10, 20, 20], >>> ]) >>> bbox_overlaps(bboxes1, bboxes2) - tensor([[0.5238, 0.0500, 0.0041], - [0.0323, 0.0452, 1.0000], + tensor([[0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 1.0000], [0.0000, 0.0000, 0.0000]]) Example: @@ -58,14 +58,14 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): lt = torch.max(bboxes1[:, :2], bboxes2[:, :2]) # [rows, 2] rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:]) # [rows, 2] - wh = (rb - lt + 1).clamp(min=0) # [rows, 2] + wh = (rb - lt).clamp(min=0) # [rows, 2] overlap = wh[:, 0] * wh[:, 1] - area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( - bboxes1[:, 3] - bboxes1[:, 1] + 1) + area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * ( + bboxes1[:, 3] - bboxes1[:, 1]) if mode == 'iou': - area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( - bboxes2[:, 3] - bboxes2[:, 1] + 1) + area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * ( + bboxes2[:, 3] - bboxes2[:, 1]) ious = overlap / (area1 + area2 - overlap) else: ious = overlap / area1 @@ -73,14 +73,14 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [rows, cols, 2] rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [rows, cols, 2] - wh = (rb - lt + 1).clamp(min=0) # [rows, cols, 2] + wh = (rb - lt).clamp(min=0) # [rows, cols, 2] overlap = wh[:, :, 0] * wh[:, :, 1] - area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( - bboxes1[:, 3] - bboxes1[:, 1] + 1) + area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * ( + bboxes1[:, 3] - bboxes1[:, 1]) if mode == 'iou': - area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( - bboxes2[:, 3] - bboxes2[:, 1] + 1) + area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * ( + bboxes2[:, 3] - bboxes2[:, 1]) ious = overlap / (area1[:, None] + area2 - overlap) else: ious = overlap / (area1[:, None]) diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index 58aa13ad5ce1eb289c31f828b3ba2b2dcc3663c0..288d9e79fa132a9dab63227172f276fa14f0cdb5 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -10,13 +10,13 @@ def bbox2delta(proposals, gt, means=[0, 0, 0, 0], stds=[1, 1, 1, 1]): gt = gt.float() px = (proposals[..., 0] + proposals[..., 2]) * 0.5 py = (proposals[..., 1] + proposals[..., 3]) * 0.5 - pw = proposals[..., 2] - proposals[..., 0] + 1.0 - ph = proposals[..., 3] - proposals[..., 1] + 1.0 + pw = proposals[..., 2] - proposals[..., 0] + ph = proposals[..., 3] - proposals[..., 1] gx = (gt[..., 0] + gt[..., 2]) * 0.5 gy = (gt[..., 1] + gt[..., 3]) * 0.5 - gw = gt[..., 2] - gt[..., 0] + 1.0 - gh = gt[..., 3] - gt[..., 1] + 1.0 + gw = gt[..., 2] - gt[..., 0] + gh = gt[..., 3] - gt[..., 1] dx = (gx - px) / pw dy = (gy - py) / ph @@ -71,9 +71,9 @@ def delta2bbox(rois, >>> [ 0.7, -1.9, -0.5, 0.3]]) >>> delta2bbox(rois, deltas, max_shape=(32, 32)) tensor([[0.0000, 0.0000, 1.0000, 1.0000], - [0.2817, 0.2817, 4.7183, 4.7183], - [0.0000, 0.6321, 7.3891, 0.3679], - [5.8967, 2.9251, 5.5033, 3.2749]]) + [0.1409, 0.1409, 2.8591, 2.8591], + [0.0000, 0.3161, 4.1945, 0.6839], + [5.0000, 5.0000, 5.0000, 5.0000]]) """ means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) @@ -89,8 +89,8 @@ def delta2bbox(rois, 
px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) # Compute width/height of each roi - pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw) - ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh) + pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw) + ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh) # Use exp(network energy) to enlarge/shrink each roi gw = pw * dw.exp() gh = ph * dh.exp() @@ -98,15 +98,15 @@ def delta2bbox(rois, gx = torch.addcmul(px, 1, pw, dx) # gx = px + pw * dx gy = torch.addcmul(py, 1, ph, dy) # gy = py + ph * dy # Convert center-xy/width/height to top-left, bottom-right - x1 = gx - gw * 0.5 + 0.5 - y1 = gy - gh * 0.5 + 0.5 - x2 = gx + gw * 0.5 - 0.5 - y2 = gy + gh * 0.5 - 0.5 + x1 = gx - gw * 0.5 + y1 = gy - gh * 0.5 + x2 = gx + gw * 0.5 + y2 = gy + gh * 0.5 if max_shape is not None: - x1 = x1.clamp(min=0, max=max_shape[1] - 1) - y1 = y1.clamp(min=0, max=max_shape[0] - 1) - x2 = x2.clamp(min=0, max=max_shape[1] - 1) - y2 = y2.clamp(min=0, max=max_shape[0] - 1) + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) return bboxes @@ -124,8 +124,8 @@ def bbox_flip(bboxes, img_shape): if isinstance(bboxes, torch.Tensor): assert bboxes.shape[-1] % 4 == 0 flipped = bboxes.clone() - flipped[:, 0::4] = img_shape[1] - bboxes[:, 2::4] - 1 - flipped[:, 2::4] = img_shape[1] - bboxes[:, 0::4] - 1 + flipped[:, 0::4] = img_shape[1] - bboxes[:, 2::4] + flipped[:, 2::4] = img_shape[1] - bboxes[:, 0::4] return flipped elif isinstance(bboxes, np.ndarray): return mmcv.bbox_flip(bboxes, img_shape) @@ -216,8 +216,8 @@ def distance2bbox(points, distance, max_shape=None): x2 = points[:, 0] + distance[:, 2] y2 = points[:, 1] + distance[:, 3] if max_shape is not None: - x1 = x1.clamp(min=0, max=max_shape[1] - 1) - y1 = y1.clamp(min=0, max=max_shape[0] - 1) - x2 = x2.clamp(min=0, max=max_shape[1] - 1) - y2 = y2.clamp(min=0, max=max_shape[0] - 1) + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) return torch.stack([x1, y1, x2, y2], -1) diff --git a/mmdet/core/evaluation/bbox_overlaps.py b/mmdet/core/evaluation/bbox_overlaps.py index ad4c70523fdaa5d89a2b80ada559e1822d0ecd22..5507e88c007f946ac689d1ef541493249539554a 100644 --- a/mmdet/core/evaluation/bbox_overlaps.py +++ b/mmdet/core/evaluation/bbox_overlaps.py @@ -28,17 +28,15 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou'): bboxes1, bboxes2 = bboxes2, bboxes1 ious = np.zeros((cols, rows), dtype=np.float32) exchange = True - area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( - bboxes1[:, 3] - bboxes1[:, 1] + 1) - area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( - bboxes2[:, 3] - bboxes2[:, 1] + 1) + area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) + area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) for i in range(bboxes1.shape[0]): x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) - overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( - y_end - y_start + 1, 0) + overlap = np.maximum(x_end - x_start, 0) * np.maximum( + y_end - y_start, 0) if mode == 
'iou': union = area1[i] + area2 - overlap else: diff --git a/mmdet/core/evaluation/mean_ap.py b/mmdet/core/evaluation/mean_ap.py index fedc51d631372104bb51a9f3e8dc3f0f4b0af0c0..b9843a78ed5cfbbb726b6022785a628b815b935d 100644 --- a/mmdet/core/evaluation/mean_ap.py +++ b/mmdet/core/evaluation/mean_ap.py @@ -98,14 +98,14 @@ def tpfp_imagenet(det_bboxes, if area_ranges == [(None, None)]: fp[...] = 1 else: - det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * ( - det_bboxes[:, 3] - det_bboxes[:, 1] + 1) + det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0]) * ( + det_bboxes[:, 3] - det_bboxes[:, 1]) for i, (min_area, max_area) in enumerate(area_ranges): fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 return tp, fp ious = bbox_overlaps(det_bboxes, gt_bboxes - 1) - gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1 - gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1 + gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)), default_iou_thr) # sort all detections by scores in descending order @@ -144,7 +144,7 @@ def tpfp_imagenet(det_bboxes, fp[k, i] = 1 else: bbox = det_bboxes[i, :4] - area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1) + area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) if area >= min_area and area < max_area: fp[k, i] = 1 return tp, fp @@ -194,8 +194,8 @@ def tpfp_default(det_bboxes, if area_ranges == [(None, None)]: fp[...] = 1 else: - det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * ( - det_bboxes[:, 3] - det_bboxes[:, 1] + 1) + det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0]) * ( + det_bboxes[:, 3] - det_bboxes[:, 1]) for i, (min_area, max_area) in enumerate(area_ranges): fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 return tp, fp @@ -213,8 +213,8 @@ def tpfp_default(det_bboxes, if min_area is None: gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) else: - gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * ( - gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1) + gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) for i in sort_inds: if ious_max[i] >= iou_thr: @@ -231,7 +231,7 @@ def tpfp_default(det_bboxes, fp[k, i] = 1 else: bbox = det_bboxes[i, :4] - area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1) + area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) if area >= min_area and area < max_area: fp[k, i] = 1 return tp, fp @@ -332,8 +332,8 @@ def eval_map(det_results, if area_ranges is None: num_gts[0] += bbox.shape[0] else: - gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * ( - bbox[:, 3] - bbox[:, 1] + 1) + gt_areas = (bbox[:, 2] - bbox[:, 0]) * ( + bbox[:, 3] - bbox[:, 1]) for k, (min_area, max_area) in enumerate(area_ranges): num_gts[k] += np.sum((gt_areas >= min_area) & (gt_areas < max_area)) diff --git a/mmdet/core/mask/mask_target.py b/mmdet/core/mask/mask_target.py index bfc154592ee4b32c1a9a9149bd7b38c8a69448fb..68014889b1898ef1f468ffa0a88f66043c48e55a 100644 --- a/mmdet/core/mask/mask_target.py +++ b/mmdet/core/mask/mask_target.py @@ -13,21 +13,21 @@ def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg): + device = pos_proposals.device mask_size = _pair(cfg.mask_size) num_pos = pos_proposals.size(0) if num_pos > 0: proposals_np = pos_proposals.cpu().numpy() maxh, maxw = gt_masks.height, gt_masks.width - proposals_np[:, [0, 2]] = np.clip(proposals_np[:, 
[0, 2]], 0, maxw - 1) - proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh - 1) - proposals_np = proposals_np.astype(np.int32) + proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw) + proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh) pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() mask_targets = gt_masks.crop_and_resize( - proposals_np, mask_size, inds=pos_assigned_gt_inds).to_ndarray() + proposals_np, mask_size, device=device, + inds=pos_assigned_gt_inds).to_ndarray() - mask_targets = torch.from_numpy(np.stack(mask_targets)).float().to( - pos_proposals.device) + mask_targets = torch.from_numpy(mask_targets).float().to(device) else: mask_targets = pos_proposals.new_zeros((0, ) + mask_size) diff --git a/mmdet/core/mask/structures.py b/mmdet/core/mask/structures.py index 7975f50212c45dfc4f6b945462746642ce5bd2d2..350d30907a24d5060384b1abe7b13b5dab909649 100644 --- a/mmdet/core/mask/structures.py +++ b/mmdet/core/mask/structures.py @@ -5,6 +5,8 @@ import numpy as np import pycocotools.mask as maskUtils import torch +from mmdet.ops.roi_align import roi_align + class BaseInstanceMasks(metaclass=ABCMeta): @@ -185,11 +187,11 @@ class BitmapMasks(BaseInstanceMasks): # clip the boundary bbox = bbox.copy() - bbox[0::2] = np.clip(bbox[0::2], 0, self.width - 1) - bbox[1::2] = np.clip(bbox[1::2], 0, self.height - 1) + bbox[0::2] = np.clip(bbox[0::2], 0, self.width) + bbox[1::2] = np.clip(bbox[1::2], 0, self.height) x1, y1, x2, y2 = bbox - w = np.maximum(x2 - x1 + 1, 1) - h = np.maximum(y2 - y1 + 1, 1) + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) if len(self.masks) == 0: cropped_masks = np.empty((0, h, w), dtype=np.uint8) @@ -201,6 +203,7 @@ class BitmapMasks(BaseInstanceMasks): bboxes, out_shape, inds, + device='cpu', interpolation='bilinear'): """Crop and resize masks by the given bboxes. 
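Every hunk above makes the same substitution: a box is now a pair of continuous corner coordinates, so a width is `x2 - x1` rather than the old inclusive-pixel `x2 - x1 + 1`, and boundary clips go to `W`/`H` instead of `W - 1`/`H - 1`. A minimal standalone sketch (plain PyTorch, not mmdet code) that reproduces the updated `bbox_overlaps` docstring values under both conventions:

```python
import torch

def iou(a, b, extra=0.):
    # extra=1. reproduces the old "+1" (inclusive pixel index) convention,
    # extra=0. the new continuous-coordinate convention used by this diff
    lt = torch.max(a[:2], b[:2])
    rb = torch.min(a[2:], b[2:])
    wh = (rb - lt + extra).clamp(min=0)
    inter = wh[0] * wh[1]
    area_a = (a[2] - a[0] + extra) * (a[3] - a[1] + extra)
    area_b = (b[2] - b[0] + extra) * (b[3] - b[1] + extra)
    return float(inter / (area_a + area_b - inter))

a = torch.tensor([0., 0., 10., 10.])
b = torch.tensor([0., 0., 10., 20.])
print(round(iou(a, b, extra=1.), 4))  # 0.5238, the old docstring value
print(round(iou(a, b, extra=0.), 4))  # 0.5, the updated docstring value
```

The same reasoning drives the flip changes elsewhere in the diff: with continuous coordinates an image of width `W` spans `[0, W]`, so a horizontal flip maps `x` to `W - x` with no `- 1`.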
@@ -209,9 +212,10 @@ class BitmapMasks(BaseInstanceMasks): assigned bbox and resize to the size of (mask_h, mask_w) Args: - bboxes (ndarray): bboxes in format [x1, y1, x2, y2], shape (N, 4) + bboxes (Tensor): bboxes in format [x1, y1, x2, y2], shape (N, 4) out_shape (tuple[int]): target (h, w) of resized mask inds (ndarray): indexes to assign masks to each bbox + device (str): device of bboxes interpolation (str): see `mmcv.imresize` Return: @@ -221,19 +225,26 @@ class BitmapMasks(BaseInstanceMasks): empty_masks = np.empty((0, *out_shape), dtype=np.uint8) return BitmapMasks(empty_masks, *out_shape) - resized_masks = [] - for i in range(len(bboxes)): - mask = self.masks[inds[i]] - bbox = bboxes[i, :].astype(np.int32) - x1, y1, x2, y2 = bbox - w = np.maximum(x2 - x1 + 1, 1) - h = np.maximum(y2 - y1 + 1, 1) - resized_masks.append( - mmcv.imresize( - mask[y1:y1 + h, x1:x1 + w], - out_shape, - interpolation=interpolation)) - return BitmapMasks(np.stack(resized_masks), *out_shape) + # convert bboxes to tensor + if isinstance(bboxes, np.ndarray): + bboxes = torch.from_numpy(bboxes).to(device=device) + if isinstance(inds, np.ndarray): + inds = torch.from_numpy(inds).to(device=device) + + num_bbox = bboxes.shape[0] + fake_inds = torch.arange( + num_bbox, device=device).to(dtype=bboxes.dtype)[:, None] + rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5 + rois = rois.to(device=device) + if num_bbox > 0: + gt_masks_th = torch.from_numpy(self.masks).to(device).index_select( + 0, inds).to(dtype=rois.dtype) + targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape, + 1.0, 0, True).squeeze(1) + resized_masks = (targets >= 0.5).cpu().numpy() + else: + resized_masks = [] + return BitmapMasks(resized_masks, *out_shape) def expand(self, expanded_h, expanded_w, top, left): """see `transforms.Expand`.""" @@ -355,7 +366,7 @@ class PolygonMasks(BaseInstanceMasks): flipped_poly_per_obj = [] for p in poly_per_obj: p = p.copy() - p[idx::2] = dim - p[idx::2] - 1 + p[idx::2] = dim - p[idx::2] flipped_poly_per_obj.append(p) flipped_masks.append(flipped_poly_per_obj) flipped_masks = PolygonMasks(flipped_masks, self.height, @@ -369,11 +380,11 @@ class PolygonMasks(BaseInstanceMasks): # clip the boundary bbox = bbox.copy() - bbox[0::2] = np.clip(bbox[0::2], 0, self.width - 1) - bbox[1::2] = np.clip(bbox[1::2], 0, self.height - 1) + bbox[0::2] = np.clip(bbox[0::2], 0, self.width) + bbox[1::2] = np.clip(bbox[1::2], 0, self.height) x1, y1, x2, y2 = bbox - w = np.maximum(x2 - x1 + 1, 1) - h = np.maximum(y2 - y1 + 1, 1) + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) if len(self.masks) == 0: cropped_masks = PolygonMasks([], h, w) @@ -402,6 +413,7 @@ class PolygonMasks(BaseInstanceMasks): bboxes, out_shape, inds, + device='cpu', interpolation='bilinear'): """see BitmapMasks.crop_and_resize""" out_h, out_w = out_shape @@ -413,8 +425,8 @@ class PolygonMasks(BaseInstanceMasks): mask = self.masks[inds[i]] bbox = bboxes[i, :].astype(np.int32) x1, y1, x2, y2 = bbox - w = np.maximum(x2 - x1 + 1, 1) - h = np.maximum(y2 - y1 + 1, 1) + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) h_scale = out_h / h w_scale = out_w / w diff --git a/mmdet/datasets/cityscapes.py b/mmdet/datasets/cityscapes.py index 56bf2881fd779ea7a6360693fe64f8c3e260bf42..abd03541fab1db546eb334c680d1a6bf333f4e98 100644 --- a/mmdet/datasets/cityscapes.py +++ b/mmdet/datasets/cityscapes.py @@ -60,7 +60,7 @@ class CityscapesDataset(CocoDataset): x1, y1, w, h = ann['bbox'] if ann['area'] <= 0 or w < 1 or h < 1: continue - bbox = [x1, y1, x1 + w - 
1, y1 + h - 1] + bbox = [x1, y1, x1 + w, y1 + h] if ann.get('iscrowd', False): gt_bboxes_ignore.append(bbox) else: diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index efd5b57dbad1475bf6c62777e97bc83e7992c040..e45c8c3a99c0a321f34274ed7b9cf3af82cb3706 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -86,7 +86,7 @@ class CocoDataset(CustomDataset): x1, y1, w, h = ann['bbox'] if ann['area'] <= 0 or w < 1 or h < 1: continue - bbox = [x1, y1, x1 + w - 1, y1 + h - 1] + bbox = [x1, y1, x1 + w, y1 + h] if ann.get('iscrowd', False): gt_bboxes_ignore.append(bbox) else: @@ -122,8 +122,8 @@ class CocoDataset(CustomDataset): return [ _bbox[0], _bbox[1], - _bbox[2] - _bbox[0] + 1, - _bbox[3] - _bbox[1] + 1, + _bbox[2] - _bbox[0], + _bbox[3] - _bbox[1], ] def _proposal2json(self, results): @@ -249,7 +249,7 @@ class CocoDataset(CustomDataset): if ann.get('ignore', False) or ann['iscrowd']: continue x1, y1, w, h = ann['bbox'] - bboxes.append([x1, y1, x1 + w - 1, y1 + h - 1]) + bboxes.append([x1, y1, x1 + w, y1 + h]) bboxes = np.array(bboxes, dtype=np.float32) if bboxes.shape[0] == 0: bboxes = np.zeros((0, 4)) diff --git a/mmdet/datasets/pipelines/instaboost.py b/mmdet/datasets/pipelines/instaboost.py index 6777d4425ee698548e9b20985bc18dd0b720317a..3b7479552ef62eeb885f0d7e2c839ec99dd05d47 100644 --- a/mmdet/datasets/pipelines/instaboost.py +++ b/mmdet/datasets/pipelines/instaboost.py @@ -44,7 +44,8 @@ class InstaBoost(object): bbox = bboxes[i] mask = masks[i] x1, y1, x2, y2 = bbox - bbox = [x1, y1, x2 - x1 + 1, y2 - y1 + 1] + # assert (x2 - x1) >= 1 and (y2 - y1) >= 1 + bbox = [x1, y1, x2 - x1, y2 - y1] anns.append({ 'category_id': label, 'segmentation': mask, @@ -59,7 +60,10 @@ class InstaBoost(object): gt_masks_ann = [] for ann in anns: x1, y1, w, h = ann['bbox'] - bbox = [x1, y1, x1 + w - 1, y1 + h - 1] + # TODO: a more essential bug in instaboost needs to be fixed + if w <= 0 or h <= 0: + continue + bbox = [x1, y1, x1 + w, y1 + h] gt_bboxes.append(bbox) gt_labels.append(ann['category_id']) gt_masks_ann.append(ann['segmentation']) @@ -73,6 +77,7 @@ class InstaBoost(object): def __call__(self, results): img = results['img'] + orig_type = img.dtype anns = self._load_anns(results) if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]): try: @@ -81,8 +86,9 @@ class InstaBoost(object): raise ImportError('Please run "pip install instaboostfast" ' 'to install instaboostfast first.') anns, img = instaboost.get_new_data( - anns, img, self.cfg, background=None) - results = self._parse_anns(results, anns, img) + anns, img.astype(np.uint8), self.cfg, background=None) + + results = self._parse_anns(results, anns, img.astype(orig_type)) return results def __repr__(self): diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index f7ef65b0c3481d5d96f75e31dbb6db5fffab249d..81099015b85671c9d64bfdb543827ef9cdefef8b 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -143,8 +143,8 @@ class Resize(object): img_shape = results['img_shape'] for key in results.get('bbox_fields', []): bboxes = results[key] * results['scale_factor'] - bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1] - 1) - bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0] - 1) + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) results[key] = bboxes def _resize_masks(self, results): @@ -215,12 +215,12 @@ class RandomFlip(object): flipped =
bboxes.copy() if direction == 'horizontal': w = img_shape[1] - flipped[..., 0::4] = w - bboxes[..., 2::4] - 1 - flipped[..., 2::4] = w - bboxes[..., 0::4] - 1 + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] elif direction == 'vertical': h = img_shape[0] - flipped[..., 1::4] = h - bboxes[..., 3::4] - 1 - flipped[..., 3::4] = h - bboxes[..., 1::4] - 1 + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] else: raise ValueError( 'Invalid flipping direction "{}"'.format(direction)) @@ -372,8 +372,8 @@ class RandomCrop(object): bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h], dtype=np.float32) bboxes = results[key] - bbox_offset - bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1] - 1) - bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0] - 1) + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) results[key] = bboxes # crop semantic seg diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py index b99ca751fd8e8b69884534044d9660339b88074c..41abbd513554dea0217ffccb662cfb69a75ecb31 100644 --- a/mmdet/datasets/xml_style.py +++ b/mmdet/datasets/xml_style.py @@ -47,6 +47,7 @@ class XMLDataset(CustomDataset): label = self.cat2label[name] difficult = int(obj.find('difficult').text) bnd_box = obj.find('bndbox') + # TODO: check whether it is necessary to use int # Coordinates may be float type bbox = [ int(float(bnd_box.find('xmin').text)), diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py index 58f27f772ce556be9f8e747884b150194755e7e9..1f4a69f7399617bf169261a540367554fa308b4c 100644 --- a/mmdet/models/anchor_heads/fcos_head.py +++ b/mmdet/models/anchor_heads/fcos_head.py @@ -369,8 +369,8 @@ class FCOSHead(nn.Module): return gt_labels.new_zeros(num_points), \ gt_bboxes.new_zeros((num_points, 4)) - areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * ( - gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1) + areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) # TODO: figure out why these two are different # areas = areas[None].expand(num_points, num_gts) areas = areas[None].repeat(num_points, 1) diff --git a/mmdet/models/anchor_heads/ga_rpn_head.py b/mmdet/models/anchor_heads/ga_rpn_head.py index 11512ffc578fd4b10f17f6129b517bc887594440..d6e3b44741db2256835bf2ec1def151fd185ecde 100644 --- a/mmdet/models/anchor_heads/ga_rpn_head.py +++ b/mmdet/models/anchor_heads/ga_rpn_head.py @@ -103,8 +103,8 @@ class GARPNHead(GuidedAnchorHead): self.target_stds, img_shape) # filter out too small bboxes if cfg.min_bbox_size > 0: - w = proposals[:, 2] - proposals[:, 0] + 1 - h = proposals[:, 3] - proposals[:, 1] + 1 + w = proposals[:, 2] - proposals[:, 0] + h = proposals[:, 3] - proposals[:, 1] valid_inds = torch.nonzero((w >= cfg.min_bbox_size) & (h >= cfg.min_bbox_size)).squeeze() proposals = proposals[valid_inds, :] diff --git a/mmdet/models/anchor_heads/rpn_head.py b/mmdet/models/anchor_heads/rpn_head.py index f88b949cf8b3051610697d15772bf1b7ea938a06..adcc7d1c15dc28d01a25ecbd09886d2b089de2ff 100644 --- a/mmdet/models/anchor_heads/rpn_head.py +++ b/mmdet/models/anchor_heads/rpn_head.py @@ -82,8 +82,8 @@ class RPNHead(AnchorHead): proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means, self.target_stds, img_shape) if cfg.min_bbox_size > 0: - w = proposals[:, 2] - proposals[:, 0] + 1 - h = proposals[:, 3] - proposals[:, 1] + 1 + w = proposals[:, 2] - 
proposals[:, 0] + h = proposals[:, 3] - proposals[:, 1] valid_inds = torch.nonzero((w >= cfg.min_bbox_size) & (h >= cfg.min_bbox_size)).squeeze() proposals = proposals[valid_inds, :] diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py index 57113679b45cd4a3da521acc3c21a43077465eb7..dd21a79eff0df577156bc17a7ab9b0d3a58dc51b 100644 --- a/mmdet/models/anchor_heads/ssd_head.py +++ b/mmdet/models/anchor_heads/ssd_head.py @@ -75,7 +75,7 @@ class SSDHead(AnchorHead): for k in range(len(anchor_strides)): base_size = min_sizes[k] stride = anchor_strides[k] - ctr = ((stride - 1) / 2., (stride - 1) / 2.) + ctr = ((stride) / 2., (stride) / 2.) scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])] ratios = [1.] for r in anchor_ratios[k]: diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index 36cc74f2ee7c52a02991ac2eb7fa0d9b5672a920..4867f36c80e3c4fcd05f5f59a979ea34b1ea3ac8 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -154,8 +154,8 @@ class BBoxHead(nn.Module): else: bboxes = rois[:, 1:].clone() if img_shape is not None: - bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1) - bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1) + bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1]) + bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0]) if rescale: if isinstance(scale_factor, float): diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py index c19c1d1d6af86926c31c9a28e974ebfc046c2b33..26c042ae695398208af94f27135d5cc565dc1ba9 100644 --- a/mmdet/models/losses/iou_loss.py +++ b/mmdet/models/losses/iou_loss.py @@ -40,13 +40,13 @@ def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3): """ pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5 pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5 - pred_w = pred[:, 2] - pred[:, 0] + 1 - pred_h = pred[:, 3] - pred[:, 1] + 1 + pred_w = pred[:, 2] - pred[:, 0] + pred_h = pred[:, 3] - pred[:, 1] with torch.no_grad(): target_ctrx = (target[:, 0] + target[:, 2]) * 0.5 target_ctry = (target[:, 1] + target[:, 3]) * 0.5 - target_w = target[:, 2] - target[:, 0] + 1 - target_h = target[:, 3] - target[:, 1] + 1 + target_w = target[:, 2] - target[:, 0] + target_h = target[:, 3] - target[:, 1] dx = target_ctrx - pred_ctrx dy = target_ctry - pred_ctry @@ -91,12 +91,12 @@ def giou_loss(pred, target, eps=1e-7): # overlap lt = torch.max(pred[:, :2], target[:, :2]) rb = torch.min(pred[:, 2:], target[:, 2:]) - wh = (rb - lt + 1).clamp(min=0) + wh = (rb - lt).clamp(min=0) overlap = wh[:, 0] * wh[:, 1] # union - ap = (pred[:, 2] - pred[:, 0] + 1) * (pred[:, 3] - pred[:, 1] + 1) - ag = (target[:, 2] - target[:, 0] + 1) * (target[:, 3] - target[:, 1] + 1) + ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1]) + ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) union = ap + ag - overlap + eps # IoU @@ -105,7 +105,7 @@ def giou_loss(pred, target, eps=1e-7): # enclose area enclose_x1y1 = torch.min(pred[:, :2], target[:, :2]) enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:]) - enclose_wh = (enclose_x2y2 - enclose_x1y1 + 1).clamp(min=0) + enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0) enclose_area = enclose_wh[:, 0] * enclose_wh[:, 1] + eps # GIoU diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py index e94b209eb8d6cb2c5d207b9c27de1cc66f89bae8..9353cf0fcc58a68ecc0e576b0f095a81bca27361 100644 --- a/mmdet/models/mask_heads/fcn_mask_head.py +++ b/mmdet/models/mask_heads/fcn_mask_head.py 
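The hunks below rewrite `FCNMaskHead.get_seg_masks`: instead of resizing each predicted mask to its box with `mmcv.imresize`, masks are pasted onto the full image grid with batched `grid_sample` calls. A toy single-mask sketch of the pasting idea (illustrative names and shapes, not the mmdet API; the real `_do_paste_mask` below is batched, chunked, and can skip empty regions):

```python
import torch
import torch.nn.functional as F

def paste_one_mask(mask, box, img_h, img_w):
    # mask: (1, 1, 28, 28) probabilities; box: (x1, y1, x2, y2) in
    # continuous image coordinates
    x1, y1, x2, y2 = box.tolist()
    # pixel centers, mapped into the box's normalized [-1, 1] range
    ys = (torch.arange(img_h).float() + 0.5 - y1) / (y2 - y1) * 2 - 1
    xs = (torch.arange(img_w).float() + 0.5 - x1) / (x2 - x1) * 2 - 1
    gy = ys[:, None].expand(img_h, img_w)
    gx = xs[None, :].expand(img_h, img_w)
    grid = torch.stack([gx, gy], dim=-1)[None]  # (1, H, W, 2), (x, y) order
    # pixels outside the box fall outside [-1, 1] and sample to zero
    return F.grid_sample(mask, grid, align_corners=False)[0, 0]

pasted = paste_one_mask(torch.rand(1, 1, 28, 28),
                        torch.tensor([40., 30., 100., 90.]), 120, 160)
print(pasted.shape)  # torch.Size([120, 160])
```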
@@ -1,4 +1,3 @@ -import mmcv import numpy as np import pycocotools.mask as mask_util import torch @@ -8,9 +7,15 @@ from torch.nn.modules.utils import _pair from mmdet.core import auto_fp16, force_fp32, mask_target from mmdet.ops import ConvModule, build_upsample_layer from mmdet.ops.carafe import CARAFEPack +from mmdet.ops.grid_sampler import grid_sample from ..builder import build_loss from ..registry import HEADS +BYTES_PER_FLOAT = 4 +# TODO: This memory limit may be too much or too little. It would be better to +# determine it based on available resources. +GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit + @HEADS.register_module class FCNMaskHead(nn.Module): @@ -144,7 +149,7 @@ class FCNMaskHead(nn.Module): """Get segmentation masks from mask_pred and bboxes. Args: - mask_pred (Tensor or ndarray): shape (n, #class+1, h, w). + mask_pred (Tensor or ndarray): shape (n, #class, h, w). For single-scale testing, mask_pred is the direct output of model, whose type is Tensor, while for multi-scale testing, it will be converted to numpy array outside of this method. @@ -158,15 +163,15 @@ class FCNMaskHead(nn.Module): list[list]: encoded masks """ if isinstance(mask_pred, torch.Tensor): - mask_pred = mask_pred.sigmoid().cpu().numpy() - assert isinstance(mask_pred, np.ndarray) - # when enabling mixed precision training, mask_pred may be float16 - # numpy array - mask_pred = mask_pred.astype(np.float32) + mask_pred = mask_pred.sigmoid() + else: + mask_pred = det_bboxes.new_tensor(mask_pred) - cls_segms = [[] for _ in range(self.num_classes - 1)] - bboxes = det_bboxes.cpu().numpy()[:, :4] - labels = det_labels.cpu().numpy() + 1 + device = mask_pred.device + cls_segms = [[] for _ in range(self.num_classes) + ] # BG is not included in num_classes + bboxes = det_bboxes[:, :4] + labels = det_labels + 1 # TODO: remove + 1 in cat -1 if rescale: img_h, img_w = ori_shape[:2] @@ -175,34 +180,130 @@ class FCNMaskHead(nn.Module): img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32) scale_factor = 1.0 - for i in range(bboxes.shape[0]): - if not isinstance(scale_factor, (float, np.ndarray)): - scale_factor = scale_factor.cpu().numpy() - bbox = (bboxes[i, :] / scale_factor).astype(np.int32) - label = labels[i] - w = max(bbox[2] - bbox[0] + 1, 1) - h = max(bbox[3] - bbox[1] + 1, 1) + if not isinstance(scale_factor, (float, torch.Tensor)): + scale_factor = bboxes.new_tensor(scale_factor) + bboxes = bboxes / scale_factor - if not self.class_agnostic: - mask_pred_ = mask_pred[i, label, :, :] - else: - mask_pred_ = mask_pred[i, 0, :, :] + N = len(mask_pred) + # The actual implementation splits the input into chunks, + # and pastes them chunk by chunk. + if device.type == 'cpu': + # CPU is most efficient when they are pasted one by one with + # skip_empty=True, so that it performs a minimal number of + # operations. + num_chunks = N + else: + # GPU benefits from parallelism for larger chunks, + # but may have memory issue + num_chunks = int( + np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT)) + assert (num_chunks <= + N), 'Default GPU_MEM_LIMIT is too small; try increasing it' + chunks = torch.chunk(torch.arange(N, device=device), num_chunks) - bbox_mask = mmcv.imresize(mask_pred_, (w, h)) - bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype( - np.uint8) + threshold = rcnn_test_cfg.mask_thr_binary + im_mask = torch.zeros( + N, + img_h, + img_w, + device=device, + dtype=torch.bool if threshold >= 0 else torch.uint8) - if rcnn_test_cfg.get('crop_mask', False): - im_mask = bbox_mask - else: - im_mask = np.zeros((img_h, img_w), dtype=np.uint8) - im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask + if not self.class_agnostic: + mask_pred = mask_pred[range(N), labels][:, None] - if rcnn_test_cfg.get('rle_mask_encode', True): - rle = mask_util.encode( - np.array(im_mask[:, :, np.newaxis], order='F'))[0] - cls_segms[label - 1].append(rle) + for inds in chunks: + masks_chunk, spatial_inds = _do_paste_mask( + mask_pred[inds], + bboxes[inds], + img_h, + img_w, + skip_empty=device.type == 'cpu') + + if threshold >= 0: + masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) else: - cls_segms[label - 1].append(im_mask) + # for visualization and debugging + masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) + + im_mask[(inds, ) + spatial_inds] = masks_chunk + for i in range(N): + rle = mask_util.encode( + np.array( + im_mask[i][:, :, None].cpu().numpy(), + order='F', + dtype='uint8'))[0] + cls_segms[labels[i] - 1].append(rle) # TODO: remove -1 in cat -1 return cls_segms
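The chunking above exists purely to bound peak memory: pasting `N` masks at full image resolution materializes `N * img_h * img_w` float32 values, so the work is split into the smallest number of chunks that fits under `GPU_MEM_LIMIT`. A back-of-the-envelope check with made-up numbers:

```python
import numpy as np

BYTES_PER_FLOAT = 4
GPU_MEM_LIMIT = 1024**3  # 1 GB, matching the constant above

N, img_h, img_w = 100, 800, 1333
paste_bytes = N * img_h * img_w * BYTES_PER_FLOAT      # ~0.4 GB
print(int(np.ceil(paste_bytes / GPU_MEM_LIMIT)))       # 1: a single chunk suffices
print(int(np.ceil(10 * paste_bytes / GPU_MEM_LIMIT)))  # 4: 1000 masks need 4 chunks
```

On CPU the code instead uses `num_chunks = N`, since pasting one mask at a time with `skip_empty=True` touches only the pixels inside each box.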
+ + +def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): + """Paste instance masks according to boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bounds all boxes, and return the results for this region + only. An important optimization for CPU. + + Returns: + tuple: (Tensor, tuple). The first item is the mask tensor, the second + one is the slice object. + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + If skip_empty == True, only the area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on a COCO-scale dataset.
+ device = masks.device + if skip_empty: + x0_int, y0_int = torch.clamp( + boxes.min(dim=0).values.floor()[:2] - 1, + min=0).to(dtype=torch.int32) + x1_int = torch.clamp( + boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp( + boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + N = masks.shape[0] + + img_y = torch.arange( + y0_int, y1_int, device=device, dtype=torch.float32) + 0.5 + img_x = torch.arange( + x0_int, x1_int, device=device, dtype=torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + if torch.isinf(img_x).any(): + inds = torch.where(torch.isinf(img_x)) + img_x[inds] = 0 + if torch.isinf(img_y).any(): + inds = torch.where(torch.isinf(img_y)) + img_y[inds] = 0 + + gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = grid_sample( + masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + else: + return img_masks[:, 0], () diff --git a/mmdet/models/mask_heads/grid_head.py b/mmdet/models/mask_heads/grid_head.py index ec669cb2a36b305818620807d29467601f8b0e47..3449bca1d991a9325789b52e6da584652fec8073 100644 --- a/mmdet/models/mask_heads/grid_head.py +++ b/mmdet/models/mask_heads/grid_head.py @@ -355,7 +355,7 @@ class GridHead(nn.Module): bbox_res = torch.cat( [bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2, cls_scores], dim=1) - bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1] - 1) - bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0] - 1) + bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1]) + bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0]) return bbox_res diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py index 6620d1d86a241e1884b8981963d4b4affe6c51cd..54b25cc8151b933eb099074255d4e89f37739375 100644 --- a/mmdet/models/roi_extractors/single_level.py +++ b/mmdet/models/roi_extractors/single_level.py @@ -67,7 +67,7 @@ class SingleRoIExtractor(nn.Module): Tensor: Level index (0-based) of each RoI, shape (k, ) """ scale = torch.sqrt( - (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1)) + (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2])) target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6)) target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() return target_lvls @@ -75,14 +75,14 @@ class SingleRoIExtractor(nn.Module): def roi_rescale(self, rois, scale_factor): cx = (rois[:, 1] + rois[:, 3]) * 0.5 cy = (rois[:, 2] + rois[:, 4]) * 0.5 - w = rois[:, 3] - rois[:, 1] + 1 - h = rois[:, 4] - rois[:, 2] + 1 + w = rois[:, 3] - rois[:, 1] + h = rois[:, 4] - rois[:, 2] new_w = w * scale_factor new_h = h * scale_factor - x1 = cx - new_w * 0.5 + 0.5 - x2 = cx + new_w * 0.5 - 0.5 - y1 = cy - new_h * 0.5 + 0.5 - y2 = cy + new_h * 0.5 - 0.5 + x1 = cx - new_w * 0.5 + x2 = cx + new_w * 0.5 + y1 = cy - new_h * 0.5 + y2 = cy + new_h * 0.5 new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1) return new_rois diff --git a/mmdet/ops/nms/nms_wrapper.py b/mmdet/ops/nms/nms_wrapper.py index 
a9ebac22ce98fab77cb305914fc7a5f7b792aef5..9f1a84c43ec13b088a8e33a5c2bb9ad373a86305 100644 --- a/mmdet/ops/nms/nms_wrapper.py +++ b/mmdet/ops/nms/nms_wrapper.py @@ -29,7 +29,7 @@ def nms(dets, iou_thr, device_id=None): >>> [35.6, 11.8, 39.3, 14.2, 0.5], >>> [35.3, 11.5, 39.9, 14.5, 0.4], >>> [35.2, 11.7, 39.7, 15.7, 0.3]], dtype=np.float32) - >>> iou_thr = 0.7 + >>> iou_thr = 0.6 >>> suppressed, inds = nms(dets, iou_thr) >>> assert len(inds) == len(suppressed) == 3 """ @@ -84,9 +84,9 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3): >>> [3., 1., 3., 1., 0.5], >>> [3., 1., 3., 1., 0.4], >>> [3., 1., 3., 1., 0.0]], dtype=np.float32) - >>> iou_thr = 0.7 + >>> iou_thr = 0.6 >>> new_dets, inds = soft_nms(dets, iou_thr, sigma=0.5) - >>> assert len(inds) == len(new_dets) == 3 + >>> assert len(inds) == len(new_dets) == 5 """ # convert dets (tensor or numpy array) to tensor if isinstance(dets, torch.Tensor): diff --git a/mmdet/ops/nms/src/cpu/nms_cpu.cpp b/mmdet/ops/nms/src/cpu/nms_cpu.cpp index 1fa589dcd840ddd09255df0454aa5255f2dbc70d..4d11abec7e69bf46711115a62daebebb95c54e9a 100644 --- a/mmdet/ops/nms/src/cpu/nms_cpu.cpp +++ b/mmdet/ops/nms/src/cpu/nms_cpu.cpp @@ -1,4 +1,6 @@ -// Modified from https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx, Soft-NMS is added +// Soft-NMS is added by MMDetection. +// Modified from +// https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. #include <torch/extension.h> @@ -16,7 +18,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) { auto y2_t = dets.select(1, 3).contiguous(); auto scores = dets.select(1, 4).contiguous(); - at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1); + at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); @@ -49,8 +51,8 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) { auto xx2 = std::min(ix2, x2[j]); auto yy2 = std::min(iy2, y2[j]); - auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1); - auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1); + auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1); + auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1); auto inter = w * h; auto ovr = inter / (iarea + areas[j] - inter); if (ovr >= threshold) suppressed[j] = 1; @@ -69,7 +71,8 @@ at::Tensor nms_cpu(const at::Tensor& dets, const float threshold) { template <typename scalar_t> at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold, - const unsigned char method, const float sigma, const float min_score) { + const unsigned char method, const float sigma, + const float min_score) { AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor"); if (dets.numel() == 0) { @@ -82,7 +85,7 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold, auto y2_t = dets.select(1, 3).contiguous(); auto scores_t = dets.select(1, 4).contiguous(); - at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1); + at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); auto ndets = dets.size(0); auto x1 = x1_t.data<scalar_t>(); @@ -110,12 +113,12 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold, pos = i + 1; // get max box - while (pos < ndets){ - if (max_score < scores[pos]) { - max_score = scores[pos]; - max_pos = pos; - } - pos = pos + 1; + while (pos < ndets) { + if (max_score < scores[pos]) { + max_score = 
scores[pos]; + max_pos = pos; + } + pos = pos + 1; } // add max box as a detection x1[i] = x1[max_pos]; @@ -127,10 +130,10 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold, inds[i] = inds[max_pos]; // swap ith box with position of max box - x1[max_pos] = ix1; - y1[max_pos] = iy1; - x2[max_pos] = ix2; - y2[max_pos] = iy2; + x1[max_pos] = ix1; + y1[max_pos] = iy1; + x2[max_pos] = ix2; + y2[max_pos] = iy2; scores[max_pos] = iscore; areas[max_pos] = iarea; inds[max_pos] = iind; @@ -143,32 +146,30 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold, iarea = areas[i]; pos = i + 1; - // NMS iterations, note that N changes if detection boxes fall below threshold + // NMS iterations, note that N changes if detection boxes fall below + // threshold while (pos < ndets) { auto xx1 = std::max(ix1, x1[pos]); auto yy1 = std::max(iy1, y1[pos]); auto xx2 = std::min(ix2, x2[pos]); auto yy2 = std::min(iy2, y2[pos]); - auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1); - auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1); + auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1); + auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1); auto inter = w * h; auto ovr = inter / (iarea + areas[pos] - inter); scalar_t weight = 1.; if (method == 1) { if (ovr > threshold) weight = 1 - ovr; - } - else if (method == 2) { + } else if (method == 2) { weight = std::exp(-(ovr * ovr) / sigma); - } - else { + } else { // original NMS if (ovr > threshold) { - weight = 0; - } - else { - weight = 1; + weight = 0; + } else { + weight = 1; } } scores[pos] = weight * scores[pos]; @@ -182,7 +183,7 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold, scores[pos] = scores[ndets - 1]; areas[pos] = areas[ndets - 1]; inds[pos] = inds[ndets - 1]; - ndets = ndets -1; + ndets = ndets - 1; pos = pos - 1; } pos = pos + 1; @@ -196,15 +197,17 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold, result[4] = scores_t.slice(0, 0, ndets); result[5] = inds_t.slice(0, 0, ndets); - result =result.t().contiguous(); + result = result.t().contiguous(); return result; } at::Tensor soft_nms_cpu(const at::Tensor& dets, const float threshold, - const unsigned char method, const float sigma, const float min_score) { + const unsigned char method, const float sigma, + const float min_score) { at::Tensor result; AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "soft_nms", [&] { - result = soft_nms_cpu_kernel<scalar_t>(dets, threshold, method, sigma, min_score); + result = soft_nms_cpu_kernel<scalar_t>(dets, threshold, method, sigma, + min_score); }); return result; } diff --git a/mmdet/ops/nms/src/cuda/nms_kernel.cu b/mmdet/ops/nms/src/cuda/nms_kernel.cu index 8dc98be1971464841440b52bd623637e907727c2..0c084f7a909130059a8fd830339cb7bc077fc273 100644 --- a/mmdet/ops/nms/src/cuda/nms_kernel.cu +++ b/mmdet/ops/nms/src/cuda/nms_kernel.cu @@ -14,10 +14,10 @@ int const threadsPerBlock = sizeof(unsigned long long) * 8; __device__ inline float devIoU(float const * const a, float const * const b) { float left = max(a[0], b[0]), right = min(a[2], b[2]); float top = max(a[1], b[1]), bottom = min(a[3], b[3]); - float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); + float width = max(right - left, 0.f), height = max(bottom - top, 0.f); float interS = width * height; - float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); - float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + float Sa = (a[2] - a[0]) * (a[3] - a[1]); + float Sb 
= (b[2] - b[0]) * (b[3] - b[1]); return interS / (Sa + Sb - interS); } diff --git a/mmdet/ops/roi_align/roi_align.py b/mmdet/ops/roi_align/roi_align.py index 203c11528388276cb854490ef70a4ef381e5b3fb..4f792fa9314c018dfe266b29845674431e454993 100644 --- a/mmdet/ops/roi_align/roi_align.py +++ b/mmdet/ops/roi_align/roi_align.py @@ -24,20 +24,18 @@ class RoIAlignFunction(Function): ctx.feature_size = features.size() ctx.aligned = aligned - if features.is_cuda: - if not aligned: - (batch_size, num_channels, data_height, - data_width) = features.size() - num_rois = rois.size(0) - - output = features.new_zeros(num_rois, num_channels, out_h, - out_w) - roi_align_ext.forward_v1(features, rois, out_h, out_w, - spatial_scale, sample_num, output) - else: - output = roi_align_ext.forward_v2(features, rois, - spatial_scale, out_h, out_w, - sample_num, aligned) + if aligned: + output = roi_align_ext.forward_v2(features, rois, spatial_scale, + out_h, out_w, sample_num, + aligned) + elif features.is_cuda: + (batch_size, num_channels, data_height, + data_width) = features.size() + num_rois = rois.size(0) + + output = features.new_zeros(num_rois, num_channels, out_h, out_w) + roi_align_ext.forward_v1(features, rois, out_h, out_w, + spatial_scale, sample_num, output) else: raise NotImplementedError @@ -85,7 +83,7 @@ class RoIAlign(nn.Module): spatial_scale, sample_num=0, use_torchvision=False, - aligned=False): + aligned=True): """ Args: out_size (tuple): h, w diff --git a/mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp b/mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c6b557da24eb19837c8ae8299f1da29dd0e8b80 --- /dev/null +++ b/mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp @@ -0,0 +1,404 @@ +// Modified from +// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlign +// Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> + +// implementation taken from Caffe2 +template <typename T> +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +template <typename T> +void pre_calc_for_bilinear_interpolate( + const int height, const int width, const int pooled_height, + const int pooled_width, const int iy_upper, const int ix_upper, + T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w, + int roi_bin_grid_h, int roi_bin_grid_w, std::vector<PreCalc<T>>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < iy_upper; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast<T>(iy + .5f) * bin_size_h / + static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast<T>(ix + .5f) * bin_size_w / + static_cast<T>(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc<T> pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc<T> pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template <typename T> +void ROIAlignForward(const int nthreads, const T* input, const T& spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, const T* rois, T* output, + bool aligned) { + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (aligned) { + AT_ASSERTM(roi_width >= 0 && roi_height >= 0, + "ROIs in ROIAlign cannot have negative size!"); + } else { // for backward-compatibility only + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros == 0/1, instead of NaN. + const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w * + pooled_width * pooled_height); + pre_calc_for_bilinear_interpolate( + height, width, pooled_height, pooled_width, roi_bin_grid_h, + roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w, + roi_bin_grid_h, roi_bin_grid_w, pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + PreCalc<T> pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_input[pc.pos1] + + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + + pc.w4 * offset_input[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + output[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} +
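For reference, the per-bin sampling pattern computed by the forward loop above, mirrored in plain Python (same arithmetic as the C++, not a binding to the extension):

```python
import math

def bin_sample_coords(roi_start, bin_size, p, sampling_ratio):
    # continuous coordinates sampled for output cell p along one axis,
    # mirroring: roi_start + p * bin_size + (i + .5) * bin_size / grid
    grid = sampling_ratio if sampling_ratio > 0 else int(math.ceil(bin_size))
    return [roi_start + p * bin_size + (i + 0.5) * bin_size / grid
            for i in range(grid)]

# a 6-pixel-tall ROI starting at y = 13.5, pooled to 2 rows, 2 samples per bin
print(bin_sample_coords(13.5, 3.0, 0, 2))  # [14.25, 15.75]
print(bin_sample_coords(13.5, 3.0, 1, 2))  # [17.25, 18.75]
```

With `aligned=true`, the half-pixel shift applied to `roi_start` makes these coordinates land on pixel centers, which keeps the bilinear weights symmetric.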
+template <typename T> +void bilinear_interpolate_gradient(const int height, const int width, T y, T x, + T& w1, T& w2, T& w3, T& w4, int& x_low, + int& x_high, int& y_low, int& y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template <class T> +inline void add(T* address, const T& val) { + *address += val; +} + +template <typename T> +void ROIAlignBackward(const int nthreads, const T* grad_output, + const T& spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, T* grad_input, const T* rois, + const int n_stride, const int c_stride, + const int h_stride, const int w_stride, bool aligned) { + for (int index = 0; index < nthreads; index++) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (aligned) { + AT_ASSERTM(roi_width >= 0 && roi_height >= 0, + "ROIs in ROIAlign cannot have negative size!"); + } else { // for backward-compatibility only + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n * n_stride + c * c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g.
+
+template <typename T>
+void ROIAlignBackward(const int nthreads, const T* grad_output,
+                      const T& spatial_scale, const int channels,
+                      const int height, const int width,
+                      const int pooled_height, const int pooled_width,
+                      const int sampling_ratio, T* grad_input, const T* rois,
+                      const int n_stride, const int c_stride,
+                      const int h_stride, const int w_stride, bool aligned) {
+  for (int index = 0; index < nthreads; index++) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T offset = aligned ? (T)0.5 : (T)0.0;
+    T roi_start_w = offset_rois[1] * spatial_scale - offset;
+    T roi_start_h = offset_rois[2] * spatial_scale - offset;
+    T roi_end_w = offset_rois[3] * spatial_scale - offset;
+    T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    if (aligned) {
+      AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
+                 "ROIs in ROIAlign cannot have negative size!");
+    } else {  // for backward-compatibility only
+      roi_width = std::max(roi_width, (T)1.);
+      roi_height = std::max(roi_height, (T)1.);
+    }
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_grad_input =
+        grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+    int output_offset = n * n_stride + c * c_stride;
+    const T* offset_grad_output = grad_output + output_offset;
+    const T grad_output_this_bin =
+        offset_grad_output[ph * h_stride + pw * w_stride];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+                             ? sampling_ratio
+                             : ceil(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = roi_start_h + ph * bin_size_h +
+                  static_cast<T>(iy + .5f) * bin_size_h /
+                      static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = roi_start_w + pw * bin_size_w +
+                    static_cast<T>(ix + .5f) * bin_size_w /
+                        static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+                                      x_low, x_high, y_low, y_high, index);
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          // atomic add is not needed for now since it is single threaded
+          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+        }  // if
+      }  // ix
+    }  // iy
+  }  // for
+}  // ROIAlignBackward
+
+at::Tensor ROIAlignForwardV2CPULaucher(const at::Tensor& input,
+                                       const at::Tensor& rois,
+                                       const float spatial_scale,
+                                       const int pooled_height,
+                                       const int pooled_width,
+                                       const int sampling_ratio, bool aligned) {
+  AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ROIAlignForwardV2CPULaucher";
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  at::Tensor output = at::zeros(
+      {num_rois, channels, pooled_height, pooled_width}, input.options());
+
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+
+  if (output.numel() == 0) return output;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
+    ROIAlignForward<scalar_t>(
+        output_size, input.contiguous().data<scalar_t>(), spatial_scale,
+        channels, height, width, pooled_height, pooled_width, sampling_ratio,
+        rois.contiguous().data<scalar_t>(), output.data<scalar_t>(), aligned);
+  });
+  return output;
+}
+
+at::Tensor ROIAlignBackwardV2CPULaucher(
+    const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int batch_size,
+    const int channels, const int height, const int width,
+    const int sampling_ratio, bool aligned) {
+  AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ROIAlignBackwardV2CPULaucher";
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  at::Tensor grad_input =
+      at::zeros({batch_size, channels, height, width}, grad.options());
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    return grad_input;
+  }
+
+  // get stride values to ensure indexing into gradients is correct
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
+    ROIAlignBackward<scalar_t>(
+        grad.numel(), grad.contiguous().data<scalar_t>(), spatial_scale,
+        channels, height, width, pooled_height, pooled_width, sampling_ratio,
+        grad_input.data<scalar_t>(), rois.contiguous().data<scalar_t>(),
+        n_stride, c_stride, h_stride, w_stride, aligned);
+  });
+  return grad_input;
+}
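These launchers are what allows ROIAlignV2 to run on CPU-only machines instead of raising "ROIAlignV2 is not implemented on CPU". Because the kernel follows the Detectron2 CPU implementation, its aligned=True output should closely match torchvision's roi_align with aligned=True (available since torchvision 0.6); a sanity check of the kind one might run, not part of the patch:

import torch
from torchvision.ops import roi_align

feat = torch.rand(1, 3, 16, 16)
# rois are (batch_idx, x1, y1, x2, y2) in input-image coordinates
rois = torch.tensor([[0., 2.3, 3.1, 10.7, 11.9]])
out = roi_align(feat, rois, output_size=(7, 7), spatial_scale=0.5,
                sampling_ratio=2, aligned=True)
print(out.shape)  # torch.Size([1, 3, 7, 7])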
diff --git a/mmdet/ops/roi_align/src/roi_align_ext.cpp b/mmdet/ops/roi_align/src/roi_align_ext.cpp
index 50454d2530b739c5040b526d33be9304377c9915..f01351a8f16c6989ff9916ba06ac5890dbb3fcc8 100644
--- a/mmdet/ops/roi_align/src/roi_align_ext.cpp
+++ b/mmdet/ops/roi_align/src/roi_align_ext.cpp
@@ -1,6 +1,5 @@
-#include <torch/extension.h>
-
 #include <ATen/ATen.h>
+#include <torch/extension.h>
 #include <cmath>
 #include <vector>
 
@@ -34,6 +33,19 @@ at::Tensor ROIAlignBackwardV2Laucher(
     const int sampling_ratio, bool aligned);
 #endif
 
+at::Tensor ROIAlignForwardV2CPULaucher(const at::Tensor& input,
+                                       const at::Tensor& rois,
+                                       const float spatial_scale,
+                                       const int pooled_height,
+                                       const int pooled_width,
+                                       const int sampling_ratio, bool aligned);
+
+at::Tensor ROIAlignBackwardV2CPULaucher(
+    const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int batch_size,
+    const int channels, const int height, const int width,
+    const int sampling_ratio, bool aligned);
+
 #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
 #define CHECK_CONTIGUOUS(x) \
   AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
@@ -125,7 +137,8 @@ inline at::Tensor ROIAlign_forwardV2(const at::Tensor& input,
     AT_ERROR("ROIAlignV2 is not compiled with GPU support");
 #endif
   }
-  AT_ERROR("ROIAlignV2 is not implemented on CPU");
+  return ROIAlignForwardV2CPULaucher(input, rois, spatial_scale, pooled_height,
+                                     pooled_width, sampling_ratio, aligned);
 }
 
 inline at::Tensor ROIAlign_backwardV2(
@@ -142,7 +155,9 @@ inline at::Tensor ROIAlign_backwardV2(
     AT_ERROR("ROIAlignV2 is not compiled with GPU support");
 #endif
   }
-  AT_ERROR("ROIAlignV2 is not implemented on CPU");
+  return ROIAlignBackwardV2CPULaucher(grad, rois, spatial_scale, pooled_height,
+                                      pooled_width, batch_size, channels,
+                                      height, width, sampling_ratio, aligned);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
diff --git a/setup.py b/setup.py
index 7c01e6be3bf808ab908d77888b6d4de310a318b6..14af9d1bce63322faebf28a5751f554b02eff009 100755
--- a/setup.py
+++ b/setup.py
@@ -240,7 +240,10 @@ if __name__ == '__main__':
             make_cuda_ext(
                 name='roi_align_ext',
                 module='mmdet.ops.roi_align',
-                sources=['src/roi_align_ext.cpp'],
+                sources=[
+                    'src/roi_align_ext.cpp',
+                    'src/cpu/roi_align_v2.cpp',
+                ],
                 sources_cuda=[
                     'src/cuda/roi_align_kernel.cu',
                     'src/cuda/roi_align_kernel_v2.cu'
diff --git a/tests/test_assigner.py b/tests/test_assigner.py
index cb783e83c002b042ee2f1e819776259f691cbcc6..a3904b38e109f2718201c9644bd5c51e9360edfa 100644
--- a/tests/test_assigner.py
+++ b/tests/test_assigner.py
@@ -49,7 +49,7 @@ def test_max_iou_assigner_with_ignore():
         [0, 0, 10, 10],
         [10, 10, 20, 20],
         [5, 5, 15, 15],
-        [32, 32, 38, 42],
+        [30, 32, 40, 42],
     ])
     gt_bboxes = torch.FloatTensor([
         [0, 0, 10, 9],
diff --git a/tests/test_forward.py b/tests/test_forward.py
index 2ace7bbd480cd3b88198b7f1fde3975a5fc6b6c8..5a1e894817cfc63f4e3c34c7feddbe10fdaddbd2 100644
--- a/tests/test_forward.py
+++ b/tests/test_forward.py
@@ -179,9 +179,6 @@ def test_cascade_forward():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
@@ -233,9 +230,6 @@ def test_faster_rcnn_forward():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
@@ -287,9 +281,6 @@ def test_faster_rcnn_ohem_forward():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'faster_rcnn/faster_rcnn_r50_fpn_ohem_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
diff --git a/tests/test_masks.py b/tests/test_masks.py
index 7c0c8c12587cb745ee7d8b0f326875ed4ae5d3c7..11b2e8e6a4f46af842f032fcd5c06cba11c98a61 100644
--- a/tests/test_masks.py
+++ b/tests/test_masks.py
@@ -36,7 +36,7 @@ def dummy_bboxes(num, max_height, max_width):
     x1y1 = np.random.randint(0, min(max_height // 2, max_width // 2), (num, 2))
     wh = np.random.randint(0, min(max_height // 2, max_width // 2), (num, 2))
     x2y2 = x1y1 + wh
-    return np.concatenate([x1y1, x2y2], axis=1).squeeze()
+    return np.concatenate([x1y1, x2y2], axis=1).squeeze().astype(np.float32)
 
 
 def test_bitmap_mask_init():
@@ -174,18 +174,18 @@ def test_bitmap_mask_crop():
     bitmap_masks = BitmapMasks(raw_masks, 28, 28)
     cropped_masks = bitmap_masks.crop(dummy_bbox)
     assert len(cropped_masks) == 0
-    assert cropped_masks.height == 18
-    assert cropped_masks.width == 11
+    assert cropped_masks.height == 17
+    assert cropped_masks.width == 10
 
     # crop with bitmap masks contain 3 instances
     raw_masks = dummy_raw_bitmap_masks((3, 28, 28))
     bitmap_masks = BitmapMasks(raw_masks, 28, 28)
     cropped_masks = bitmap_masks.crop(dummy_bbox)
     assert len(cropped_masks) == 3
-    assert cropped_masks.height == 18
-    assert cropped_masks.width == 11
+    assert cropped_masks.height == 17
+    assert cropped_masks.width == 10
     x1, y1, x2, y2 = dummy_bbox
-    assert (cropped_masks.masks == raw_masks[:, y1:y2 + 1, x1:x2 + 1]).all()
+    assert (cropped_masks.masks == raw_masks[:, y1:y2, x1:x2]).all()
 
     # crop with invalid bbox
     with pytest.raises(AssertionError):
@@ -453,9 +453,9 @@ def test_polygon_mask_crop():
     polygon_masks = PolygonMasks(raw_masks, 28, 28)
     cropped_masks = polygon_masks.crop(dummy_bbox)
     assert len(cropped_masks) == 0
-    assert cropped_masks.height == 18
-    assert cropped_masks.width == 11
-    assert cropped_masks.to_ndarray().shape == (0, 18, 11)
+    assert cropped_masks.height == 17
+    assert cropped_masks.width == 10
+    assert cropped_masks.to_ndarray().shape == (0, 17, 10)
 
     # crop with polygon masks contain 1 instances
     raw_masks = [[np.array([1., 3., 5., 1., 5., 6., 1, 6])]]
@@ -463,11 +463,10 @@ def test_polygon_mask_crop():
     bbox = np.array([0, 0, 3, 4])
    cropped_masks = polygon_masks.crop(bbox)
     assert len(cropped_masks) == 1
-    assert cropped_masks.height == 5
-    assert cropped_masks.width == 4
-    assert cropped_masks.to_ndarray().shape == (1, 5, 4)
-    truth = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1],
-                      [0, 1, 1, 1]])
+    assert cropped_masks.height == 4
+    assert cropped_masks.width == 3
+    assert cropped_masks.to_ndarray().shape == (1, 4, 3)
+    truth = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]])
     assert (cropped_masks.to_ndarray() == truth).all()
 
     # crop with invalid bbox
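The updated numbers in test_masks.py follow from the new half-open box convention: a crop box (x1, y1, x2, y2) now selects mask[y1:y2, x1:x2], so the crop size is (y2 - y1, x2 - x1) with no +1. A tiny illustration, not part of the patch (the box values are ours, chosen only to reproduce the 17 x 10 shape asserted above):

import numpy as np

mask = np.ones((28, 28), dtype=np.uint8)
x1, y1, x2, y2 = 3, 2, 13, 19   # any box with x2 - x1 = 10, y2 - y1 = 17
crop = mask[y1:y2, x1:x2]       # half-open slicing, no +1
print(crop.shape)               # (17, 10): was (18, 11) under the old rule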
diff --git a/tests/test_nms.py b/tests/test_nms.py
index 6861f1e592f846905859441f6366ae852101e997..e99af88e793f3b3c1621eb8ee1740949d54984c1 100644
--- a/tests/test_nms.py
+++ b/tests/test_nms.py
@@ -13,7 +13,7 @@ def test_nms_device_and_dtypes_cpu():
     CommandLine:
         xdoctest -m tests/test_nms.py test_nms_device_and_dtypes_cpu
     """
-    iou_thr = 0.7
+    iou_thr = 0.6
     base_dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
                           [49.3, 32.9, 51.0, 35.3, 0.9],
                           [35.3, 11.5, 39.9, 14.5, 0.4],
@@ -23,22 +23,22 @@ def test_nms_device_and_dtypes_cpu():
     dets = base_dets.astype(np.float32)
     supressed, inds = nms(dets, iou_thr)
     assert dets.dtype == supressed.dtype
-    assert len(inds) == len(supressed) == 3
+    assert len(inds) == len(supressed) == 2
 
     dets = torch.FloatTensor(base_dets)
     surpressed, inds = nms(dets, iou_thr)
     assert dets.dtype == surpressed.dtype
-    assert len(inds) == len(surpressed) == 3
+    assert len(inds) == len(surpressed) == 2
 
     dets = base_dets.astype(np.float64)
     supressed, inds = nms(dets, iou_thr)
     assert dets.dtype == supressed.dtype
-    assert len(inds) == len(supressed) == 3
+    assert len(inds) == len(supressed) == 2
 
     dets = torch.DoubleTensor(base_dets)
     surpressed, inds = nms(dets, iou_thr)
     assert dets.dtype == surpressed.dtype
-    assert len(inds) == len(surpressed) == 3
+    assert len(inds) == len(surpressed) == 2
 
 
 def test_nms_device_and_dtypes_gpu():
@@ -50,7 +50,7 @@ def test_nms_device_and_dtypes_gpu():
         import pytest
         pytest.skip('test requires GPU and torch+cuda')
 
-    iou_thr = 0.7
+    iou_thr = 0.6
     base_dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
                           [49.3, 32.9, 51.0, 35.3, 0.9],
                           [35.3, 11.5, 39.9, 14.5, 0.4],
@@ -62,9 +62,9 @@ def test_nms_device_and_dtypes_gpu():
         dets = base_dets.astype(np.float32)
         supressed, inds = nms(dets, iou_thr, device_id)
         assert dets.dtype == supressed.dtype
-        assert len(inds) == len(supressed) == 3
+        assert len(inds) == len(supressed) == 2
 
         dets = torch.FloatTensor(base_dets).to(device_id)
         surpressed, inds = nms(dets, iou_thr)
         assert dets.dtype == surpressed.dtype
-        assert len(inds) == len(surpressed) == 3
+        assert len(inds) == len(surpressed) == 2
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 53ef380feb05945bad832c8f41257525e5239dbd..5afa16a8409be0b1d8cc4e0c5d987f3290261a52 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -102,9 +102,7 @@ def _context_for_ohem():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'faster_rcnn/faster_rcnn_r50_fpn_ohem_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
+
     from mmdet.models import build_detector
     context = build_detector(
         model, train_cfg=train_cfg, test_cfg=test_cfg).roi_head
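A worked number makes the test_nms.py changes above concrete. The second highest-scoring det is fully contained in the first, so under the +1-free convention its IoU reduces to area_b / area_a = 4.08 / 6.65, about 0.61, down from about 0.70 under the old convention; lowering iou_thr from 0.7 to 0.6 keeps it suppressed. The third and fourth dets now overlap at about 0.63 > 0.6 as well, so only two boxes survive instead of three. A quick check, not part of the patch:

import numpy as np

def iou(a, b):
    lt = np.maximum(a[:2], b[:2])
    rb = np.minimum(a[2:], b[2:])
    wh = np.clip(rb - lt, 0, None)        # no +1 any more
    inter = wh[0] * wh[1]
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

a = np.array([49.1, 32.4, 51.0, 35.9])    # top-scoring det
b = np.array([49.3, 32.9, 51.0, 35.3])    # fully contained in a
print(round(iou(a, b), 4))                # 0.6135: above 0.6, below the old 0.7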