Unverified commit 206107ed, authored by Wang Xinjiang, committed by GitHub

Fixed bugs in transform.py (#2810)

* Add upsample_cfg support in FPN

* small fix

* Add multiple extra conv sources

* small logical fix

* Add neck tests for fpn

* Add neck tests for fpn

* fixed several typos

* resolved issues

* Removed extra_convs_source option

* added necks to apis.rst

* change according to comments

* add bbox_fields to expand and min iou crop

* reconfigured configs

* small fix

* added test for random crop and min iou random crop

* small change

* added valid_inds for masks in random crop according to #2802

* added config unit test

* deleted nas fpn config test

* revise according to comments

* add img_fields check in transform.py and test_transform.py

* Added notes for random crop and min iou random crop

* Added notes for random crop and min iou random crop

* Add bbox2label and bbox2mask key correspondence

* Update nas-fpn model

* Add fsaf log
parent e05e8583
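The unifying change across these commits is that the crop/expand transforms now iterate over the `*_fields` lists in the pipeline's `results` dict instead of hard-coding `gt_bboxes`/`gt_masks`. Before the hunks, here is a minimal sketch of such a dict; the field names follow the mmdet pipeline convention, while the concrete shapes and values are purely illustrative:

```python
import numpy as np

# Illustrative `results` dict as passed between mmdet pipeline transforms.
# Each `*_fields` list names the keys a transform should operate on, so a
# transform can handle gt_bboxes and gt_bboxes_ignore uniformly.
results = {
    'img': np.zeros((288, 512, 3), dtype=np.float32),
    'img_fields': ['img'],
    'gt_bboxes': np.array([[10., 10., 100., 120.]], dtype=np.float32),
    'gt_bboxes_ignore': np.zeros((0, 4), dtype=np.float32),
    'bbox_fields': ['gt_bboxes', 'gt_bboxes_ignore'],
    'gt_labels': np.array([3], dtype=np.int64),
    'mask_fields': [],  # e.g. ['gt_masks'] when instance masks are present
    'seg_fields': [],   # e.g. ['gt_semantic_seg'] for semantic segmentation
}
```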
@@ -13,10 +13,10 @@ it is empirically found that a hard threshold (0.2-0.2) gives a further gain on

 | Backbone | ignore range | ms-train| Lr schd |Train Mem (GB)| Train time (s/iter) | Inf time (fps) | box AP | Download |
 |:----------:| :-------: |:-------:|:-------:|:------------:|:---------------:|:--------------:|:-------------:|:--------:|
-| R-50 | 0.2-0.5 | N | 1x | 3.15 | 0.43 | 12.3 | 37.0 (35.9) | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_pscale0.2_nscale0.5_r50_fpn_1x_coco/fsaf_pscale0.2_nscale0.5_r50_fpn_1x_coco-9ad4c643.pth) |
-| R-50 | 0.2-0.2 | N | 1x | 3.15 | 0.43 | 13.0 | 37.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_r50_fpn_1x_coco/fsaf_r50_fpn_1x_coco-94ccc51f.pth) |
-| R-101 | 0.2-0.2 | N | 1x | 5.08 | 0.58 | 10.8 | 39.3 (37.9) | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_r101_fpn_1x_coco/fsaf_r101_fpn_1x_coco-9e71098f.pth) |
-| X-101 | 0.2-0.2 | N | 1x | 9.38 | 1.23 | 5.6 | 42.4 (41.0) | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_x101_64x4d_fpn_1x_coco/fsaf_x101_64x4d_fpn_1x_coco-e3f6e6fd.pth) |
+| R-50 | 0.2-0.5 | N | 1x | 3.15 | 0.43 | 12.3 | 37.0 (35.9) | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_pscale0.2_nscale0.5_r50_fpn_1x_coco/fsaf_pscale0.2_nscale0.5_r50_fpn_1x_coco-9ad4c643.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_pscale0.2_nscale0.5_r50_fpn_1x_coco/fsaf_pscale0.2_nscale0.5_r50_fpn_1x_coco_20200428_122907.log.json) |
+| R-50 | 0.2-0.2 | N | 1x | 3.15 | 0.43 | 13.0 | 37.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_r50_fpn_1x_coco/fsaf_r50_fpn_1x_coco-94ccc51f.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_r50_fpn_1x_coco/fsaf_r50_fpn_1x_coco_20200428_072327.log.json)|
+| R-101 | 0.2-0.2 | N | 1x | 5.08 | 0.58 | 10.8 | 39.3 (37.9) | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_r101_fpn_1x_coco/fsaf_r101_fpn_1x_coco-9e71098f.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_r101_fpn_1x_coco/fsaf_r101_fpn_1x_coco_20200428_160348.log.json)|
+| X-101 | 0.2-0.2 | N | 1x | 9.38 | 1.23 | 5.6 | 42.4 (41.0) | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_x101_64x4d_fpn_1x_coco/fsaf_x101_64x4d_fpn_1x_coco-e3f6e6fd.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/fsaf/fsaf_x101_64x4d_fpn_1x_coco/fsaf_x101_64x4d_fpn_1x_coco_20200428_160424.log.json)|

 **Notes:**
 - *1x means the model is trained for 12 epochs.*
@@ -18,8 +18,8 @@ We benchmark the new training schedule (crop training, large batch, unfrozen BN,

 | Backbone | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download |
 |:-----------:|:-------:|:--------:|:--------------:|:------:|:--------:|
-| R-50-FPN | 50e | 12.9 | 22.9 | 37.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_fpn_crop640_50e_coco/retinanet_r50_fpn_crop640_50e_coco_20200130-ad569db4.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_fpn_crop640_50e_coco/retinanet_r50_fpn_crop640_50e_coco_20200130_140229.log.json) |
-| R-50-NASFPN | 50e | 13.2 | 23.0 | 40.1 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_nasfpn_crop640_50e_coco/retinanet_r50_nasfpn_crop640_50e_coco_20200131-895d67cb.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_nasfpn_crop640_50e_coco/retinanet_r50_nasfpn_crop640_50e_coco_20200131_113434.log.json) |
+| R-50-FPN | 50e | 12.9 | 22.9 | 37.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_fpn_crop640_50e_coco/retinanet_r50_fpn_crop640_50e_coco-9b953d76.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_fpn_crop640_50e_coco/retinanet_r50_fpn_crop640_50e_coco_20200529_095329.log.json) |
+| R-50-NASFPN | 50e | 13.2 | 23.0 | 40.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_nasfpn_crop640_50e_coco/retinanet_r50_nasfpn_crop640_50e_coco-0ad1f644.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fpn/retinanet_r50_nasfpn_crop640_50e_coco/retinanet_r50_nasfpn_crop640_50e_coco_20200528_230008.log.json) |
 **Note**: We find that training NAS-FPN is unstable, and there is a small chance that results can be 3% mAP lower.
@@ -351,11 +351,29 @@ class RandomCrop(object):

     Args:
         crop_size (tuple): Expected size after cropping, (h, w).

+    Notes:
+        - If the image is smaller than the crop size, the original image is
+          returned.
+        - The keys for bboxes, labels and masks must be aligned. That is,
+          `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and
+          `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and
+          `gt_masks_ignore`.
+        - If there are gt bboxes in an image and the cropping area does not
+          intersect any gt bbox, this image is skipped.
     """

     def __init__(self, crop_size):
         assert crop_size[0] > 0 and crop_size[1] > 0
         self.crop_size = crop_size
+        # The key correspondence from bboxes to labels and masks.
+        self.bbox2label = {
+            'gt_bboxes': 'gt_labels',
+            'gt_bboxes_ignore': 'gt_labels_ignore'
+        }
+        self.bbox2mask = {
+            'gt_bboxes': 'gt_masks',
+            'gt_bboxes_ignore': 'gt_masks_ignore'
+        }

     def __call__(self, results):
         for key in results.get('img_fields', ['img']):
@@ -373,35 +391,43 @@ class RandomCrop(object):
             results[key] = img
         results['img_shape'] = img_shape

+        valid_flag = False
         # crop bboxes accordingly and clip to the image boundary
         for key in results.get('bbox_fields', []):
+            # e.g. gt_bboxes and gt_bboxes_ignore
             bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
                                    dtype=np.float32)
             bboxes = results[key] - bbox_offset
             bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
             bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
-            results[key] = bboxes
+            valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
+                bboxes[:, 3] > bboxes[:, 1])
+            # The crop is valid when there is no gt bbox at all, or when at
+            # least one gt bbox keeps a positive area after cropping.
+            if len(valid_inds) == 0 or valid_inds.any():
+                valid_flag = True
+            results[key] = bboxes[valid_inds, :]
+            # label fields, e.g. gt_labels and gt_labels_ignore
+            label_key = self.bbox2label.get(key)
+            if label_key in results:
+                results[label_key] = results[label_key][valid_inds]
+            # mask fields, e.g. gt_masks and gt_masks_ignore
+            mask_key = self.bbox2mask.get(key)
+            if mask_key in results:
+                results[mask_key] = results[mask_key][
+                    valid_inds.nonzero()[0]].crop(
+                        np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))

+        # if no gt bbox remains after cropping, just skip this image
+        # TODO: check whether we can keep the image regardless of the crop.
+        if 'bbox_fields' in results and not valid_flag:
+            return None

         # crop semantic seg
         for key in results.get('seg_fields', []):
             results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]
-
-        # filter out the gt bboxes that are completely cropped
-        if 'gt_bboxes' in results:
-            gt_bboxes = results['gt_bboxes']
-            valid_inds = (gt_bboxes[:, 2] > gt_bboxes[:, 0]) & (
-                gt_bboxes[:, 3] > gt_bboxes[:, 1])
-            # if no gt bbox remains after cropping, just skip this image
-            if not np.any(valid_inds):
-                return None
-            results['gt_bboxes'] = gt_bboxes[valid_inds, :]
-            if 'gt_labels' in results:
-                results['gt_labels'] = results['gt_labels'][valid_inds]
-
-            # filter and crop the masks
-            if 'gt_masks' in results:
-                results['gt_masks'] = results['gt_masks'].crop(
-                    np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
-
         return results

     def __repr__(self):
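As a quick illustration of what the new loop does for each `bbox_fields` entry, here is a numpy-only sketch; the coordinates, crop offset, and shapes are made up, and the paired-label filtering mirrors the `bbox2label` lookup above (real mmdet masks are BitmapMasks/PolygonMasks with a `.crop()` method, omitted here):

```python
import numpy as np

# Two boxes; the first lies entirely left of the crop window.
bboxes = np.array([[5., 5., 45., 40.], [90., 10., 120., 30.]], np.float32)
labels = np.array([1, 2])

offset_w, offset_h = 50, 0      # sampled crop origin (illustrative)
img_shape = (100, 50)           # (h, w) of the cropped image
bboxes -= np.array([offset_w, offset_h, offset_w, offset_h], np.float32)
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])

# A bbox survives only if it keeps positive width and height after clipping.
valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (bboxes[:, 3] > bboxes[:, 1])
bboxes, labels = bboxes[valid_inds], labels[valid_inds]  # keep keys aligned
print(bboxes)  # [[40. 10. 50. 30.]] -- the first box was cropped away
print(labels)  # [2]
```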
@@ -463,6 +489,9 @@ class PhotoMetricDistortion(object):
         self.hue_delta = hue_delta

     def __call__(self, results):
+        if 'img_fields' in results:
+            assert results['img_fields'] == ['img'], \
+                'Only single img_fields is allowed'
         img = results['img']
         assert img.dtype == np.float32, \
             'PhotoMetricDistortion needs the input image of dtype np.float32,'\
@@ -558,32 +587,38 @@ class Expand(object):
         if random.uniform(0, 1) > self.prob:
             return results

-        img, boxes = [results[k] for k in ('img', 'gt_bboxes')]
+        if 'img_fields' in results:
+            assert results['img_fields'] == ['img'], \
+                'Only single img_fields is allowed'
+        img = results['img']

         h, w, c = img.shape
         ratio = random.uniform(self.min_ratio, self.max_ratio)
         expand_img = np.full((int(h * ratio), int(w * ratio), c),
-                             self.mean).astype(img.dtype)
+                             self.mean,
+                             dtype=img.dtype)
         left = int(random.uniform(0, w * ratio - w))
         top = int(random.uniform(0, h * ratio - h))
         expand_img[top:top + h, left:left + w] = img
-        boxes = boxes + np.tile((left, top), 2).astype(boxes.dtype)

         results['img'] = expand_img
-        results['gt_bboxes'] = boxes
+        # expand bboxes
+        for key in results.get('bbox_fields', []):
+            results[key] += np.tile((left, top), 2).astype(results[key].dtype)

-        if 'gt_masks' in results:
-            results['gt_masks'] = results['gt_masks'].expand(
+        # expand masks
+        for key in results.get('mask_fields', []):
+            results[key] = results[key].expand(
                 int(h * ratio), int(w * ratio), top, left)

-        # not tested
-        if 'gt_semantic_seg' in results:
-            assert self.seg_ignore_label is not None
-            gt_seg = results['gt_semantic_seg']
+        # expand segs
+        for key in results.get('seg_fields', []):
+            gt_seg = results[key]
             expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
-                                    self.seg_ignore_label).astype(gt_seg.dtype)
+                                    self.seg_ignore_label,
+                                    dtype=gt_seg.dtype)
             expand_gt_seg[top:top + h, left:left + w] = gt_seg
-            results['gt_semantic_seg'] = expand_gt_seg
+            results[key] = expand_gt_seg
         return results

     def __repr__(self):
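The expand operation itself is simple geometry: paste the image into a mean-filled canvas up to `ratio` times larger, then shift every bbox by the paste offset. A minimal numpy sketch with illustrative values (not mmdet's actual API):

```python
import numpy as np

h, w, c = 100, 200, 3
img = np.ones((h, w, c), np.float32)
bboxes = np.array([[10., 20., 50., 60.]], np.float32)

ratio = 1.5                                  # sampled from [min_ratio, max_ratio]
canvas = np.full((int(h * ratio), int(w * ratio), c), 0.5, dtype=img.dtype)
left, top = 40, 25                           # sampled paste position
canvas[top:top + h, left:left + w] = img     # paste the original image

bboxes += np.tile((left, top), 2)            # shift boxes with the image
assert canvas.shape == (150, 300, 3)
assert (bboxes == [50., 45., 90., 85.]).all()
```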
@@ -605,6 +640,11 @@ class MinIoURandomCrop(object):
             bounding boxes
         min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
             where a >= min_crop_size).

+    Notes:
+        The keys for bboxes, labels and masks should be paired. That is,
+        `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and
+        `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
     """

     def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3):
@@ -612,14 +652,27 @@ class MinIoURandomCrop(object):
         self.min_ious = min_ious
         self.sample_mode = (1, *min_ious, 0)
         self.min_crop_size = min_crop_size
+        self.bbox2label = {
+            'gt_bboxes': 'gt_labels',
+            'gt_bboxes_ignore': 'gt_labels_ignore'
+        }
+        self.bbox2mask = {
+            'gt_bboxes': 'gt_masks',
+            'gt_bboxes_ignore': 'gt_masks_ignore'
+        }

     def __call__(self, results):
-        img, boxes, labels = [
-            results[k] for k in ('img', 'gt_bboxes', 'gt_labels')
-        ]
+        if 'img_fields' in results:
+            assert results['img_fields'] == ['img'], \
+                'Only single img_fields is allowed'
+        img = results['img']
+        assert 'bbox_fields' in results
+        boxes = [results[key] for key in results['bbox_fields']]
+        boxes = np.concatenate(boxes, 0)
         h, w, c = img.shape
         while True:
             mode = random.choice(self.sample_mode)
+            self.mode = mode
             if mode == 1:
                 return results
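For context, `self.sample_mode = (1, *min_ious, 0)` mixes two sentinels in with the IoU thresholds: mode 1 means "keep the image uncropped" (the early return above), and mode 0 imposes no IoU constraint at all. A small illustrative sketch of that sampling:

```python
import random

min_ious = (0.1, 0.3, 0.5, 0.7, 0.9)
sample_mode = (1, *min_ious, 0)

mode = random.choice(sample_mode)
if mode == 1:
    print('keep the original image, no crop')
else:
    # mode 0 makes the constraint vacuous, i.e. any sampled crop is accepted
    print(f'sample crops until IoU with every gt bbox is >= {mode}')
```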
@@ -649,36 +702,45 @@ class MinIoURandomCrop(object):
                 # only adjust boxes and instance masks when the gt is not empty
                 if len(overlaps) > 0:
                     # adjust boxes
-                    center = (boxes[:, :2] + boxes[:, 2:]) / 2
-                    mask = ((center[:, 0] > patch[0]) *
-                            (center[:, 1] > patch[1]) *
-                            (center[:, 0] < patch[2]) *
-                            (center[:, 1] < patch[3]))
+                    def is_center_of_bboxes_in_patch(boxes, patch):
+                        center = (boxes[:, :2] + boxes[:, 2:]) / 2
+                        mask = ((center[:, 0] > patch[0]) *
+                                (center[:, 1] > patch[1]) *
+                                (center[:, 0] < patch[2]) *
+                                (center[:, 1] < patch[3]))
+                        return mask

+                    mask = is_center_of_bboxes_in_patch(boxes, patch)
                     if not mask.any():
                         continue
-                    boxes = boxes[mask]
-                    labels = labels[mask]
-                    boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
-                    boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
-                    boxes -= np.tile(patch[:2], 2)
-
-                    results['gt_bboxes'] = boxes
-                    results['gt_labels'] = labels
-
-                    if 'gt_masks' in results:
-                        results['gt_masks'] = results['gt_masks'][
-                            mask.nonzero()[0]].crop(patch)
+                    for key in results.get('bbox_fields', []):
+                        boxes = results[key]
+                        mask = is_center_of_bboxes_in_patch(boxes, patch)
+                        boxes = boxes[mask]
+                        boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
+                        boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
+                        boxes -= np.tile(patch[:2], 2)
+                        results[key] = boxes
+                        # labels
+                        label_key = self.bbox2label.get(key)
+                        if label_key in results:
+                            results[label_key] = results[label_key][mask]
+                        # mask fields
+                        mask_key = self.bbox2mask.get(key)
+                        if mask_key in results:
+                            results[mask_key] = results[mask_key][
+                                mask.nonzero()[0]].crop(patch)

                 # adjust the img no matter whether the gt is empty before crop
                 img = img[patch[1]:patch[3], patch[0]:patch[2]]
                 results['img'] = img
                 results['img_shape'] = img.shape

-                # not tested
-                if 'gt_semantic_seg' in results:
-                    results['gt_semantic_seg'] = results['gt_semantic_seg'][
-                        patch[1]:patch[3], patch[0]:patch[2]]
+                # seg fields
+                for key in results.get('seg_fields', []):
+                    results[key] = results[key][patch[1]:patch[3],
+                                                patch[0]:patch[2]]
                 return results

     def __repr__(self):
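The extracted helper makes the center rule reusable for every `bbox_fields` entry: a box is kept only when its center lies strictly inside the sampled patch. A standalone numpy sketch with made-up coordinates:

```python
import numpy as np

def is_center_of_bboxes_in_patch(boxes, patch):
    # boxes: (N, 4) as x1, y1, x2, y2; patch: (4,) crop rectangle
    center = (boxes[:, :2] + boxes[:, 2:]) / 2
    return ((center[:, 0] > patch[0]) & (center[:, 1] > patch[1]) &
            (center[:, 0] < patch[2]) & (center[:, 1] < patch[3]))

patch = np.array([50, 50, 150, 150])
boxes = np.array([[60., 60., 100., 100.],    # center (80, 80): kept
                  [0., 0., 40., 40.]])       # center (20, 20): dropped
print(is_center_of_bboxes_in_patch(boxes, patch))  # [ True False]
```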
@@ -698,6 +760,9 @@ class Corrupt(object):

     def __call__(self, results):
         if corrupt is None:
             raise RuntimeError('imagecorruptions is not installed')
+        if 'img_fields' in results:
+            assert results['img_fields'] == ['img'], \
+                'Only single img_fields is allowed'
         results['img'] = corrupt(
             results['img'].astype(np.uint8),
             corruption_name=self.corruption,
@@ -814,7 +879,7 @@ class Albu(object):

     def __call__(self, results):
         # dict to albumentations format
         results = self.mapper(results, self.keymap_to_albu)
+        # TODO: add bbox_fields
         if 'bboxes' in results:
             # to list of boxes
             if isinstance(results['bboxes'], np.ndarray):
@@ -99,7 +99,7 @@ def test_config_data_pipeline():
         'foveabox/fovea_align_r50_fpn_gn-head_mstrain_640-800_4x4_2x_coco.py',
         'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_coco.py',
         'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain_1x_coco.py',
-        'fp16/mask_rcnn_r50_fpn_fp16_1x_coco.py',
+        'fp16/mask_rcnn_r50_fpn_fp16_1x_coco.py'
     ]

 def dummy_masks(h, w, num_obj=3, mode='bitmap'):
@@ -6,6 +6,7 @@ import numpy as np
 import pytest
 from mmcv.utils import build_from_cfg

+from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
 from mmdet.datasets.builder import PIPELINES
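The newly imported `bbox_overlaps` returns the pairwise IoU matrix between two box sets; the test below flattens the (1, M) result of comparing a single whole-image patch against the surviving gt bboxes. A rough numpy equivalent for intuition (not the mmdet implementation):

```python
import numpy as np

def iou_matrix(boxes1, boxes2):
    # boxes: (N, 4) / (M, 4) in x1, y1, x2, y2; returns (N, M) IoUs
    lt = np.maximum(boxes1[:, None, :2], boxes2[None, :, :2])
    rb = np.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])
    wh = np.clip(rb - lt, 0, None)
    inter = wh[..., 0] * wh[..., 1]
    area1 = np.prod(boxes1[:, 2:] - boxes1[:, :2], axis=1)
    area2 = np.prod(boxes2[:, 2:] - boxes2[:, :2], axis=1)
    return inter / (area1[:, None] + area2[None, :] - inter)

patch = np.array([[0., 0., 100., 100.]])
boxes = np.array([[0., 0., 50., 50.], [50., 50., 150., 150.]])
print(iou_matrix(patch, boxes))  # [[0.25 0.14285714]]
```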
@@ -100,6 +101,104 @@ def test_flip():
     assert np.equal(original_img, results['img']).all()

+def test_random_crop():
+    # test assertion for invalid random crop
+    with pytest.raises(AssertionError):
+        transform = dict(type='RandomCrop', crop_size=(-1, 0))
+        build_from_cfg(transform, PIPELINES)
+
+    results = dict()
+    img = mmcv.imread(
+        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
+    results['img'] = img
+
+    results['img_shape'] = img.shape
+    results['ori_shape'] = img.shape
+    # TODO: add img_fields test
+    results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
+    # Set initial values for default meta_keys
+    results['pad_shape'] = img.shape
+    results['scale_factor'] = 1.0
+
+    def create_random_bboxes(num_bboxes, img_w, img_h):
+        bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
+        bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
+        bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
+        bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
+            np.int)
+        return bboxes
+
+    h, w, _ = img.shape
+    gt_bboxes = create_random_bboxes(8, w, h)
+    gt_bboxes_ignore = create_random_bboxes(2, w, h)
+    results['gt_bboxes'] = gt_bboxes
+    results['gt_bboxes_ignore'] = gt_bboxes_ignore
+    transform = dict(type='RandomCrop', crop_size=(h - 20, w - 20))
+    crop_module = build_from_cfg(transform, PIPELINES)
+    results = crop_module(results)
+    assert results['img'].shape[:2] == (h - 20, w - 20)
+    # All bboxes should be reserved after crop: every random bbox spans the
+    # image center, so it always keeps a positive area inside the crop.
+    assert results['img_shape'][:2] == (h - 20, w - 20)
+    assert results['gt_bboxes'].shape[0] == 8
+    assert results['gt_bboxes_ignore'].shape[0] == 2
+
+    def area(bboxes):
+        return np.prod(bboxes[:, 2:4] - bboxes[:, 0:2], axis=1)
+
+    assert (area(results['gt_bboxes']) <= area(gt_bboxes)).all()
+    assert (area(results['gt_bboxes_ignore']) <= area(gt_bboxes_ignore)).all()
+
+
+def test_min_iou_random_crop():
+
+    def create_random_bboxes(num_bboxes, img_w, img_h):
+        bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
+        bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
+        bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
+        bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
+            np.int)
+        return bboxes
+
+    results = dict()
+    img = mmcv.imread(
+        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
+    results['img'] = img
+    results['img_shape'] = img.shape
+    results['ori_shape'] = img.shape
+    results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
+    # Set initial values for default meta_keys
+    results['pad_shape'] = img.shape
+    results['scale_factor'] = 1.0
+    h, w, _ = img.shape
+    gt_bboxes = create_random_bboxes(1, w, h)
+    gt_bboxes_ignore = create_random_bboxes(1, w, h)
+    results['gt_bboxes'] = gt_bboxes
+    results['gt_bboxes_ignore'] = gt_bboxes_ignore
+    transform = dict(type='MinIoURandomCrop')
+    crop_module = build_from_cfg(transform, PIPELINES)
+
+    # Test for img_fields: MinIoURandomCrop only supports a single image key.
+    results_test = copy.deepcopy(results)
+    results_test['img1'] = results_test['img']
+    results_test['img_fields'] = ['img', 'img1']
+    with pytest.raises(AssertionError):
+        crop_module(results_test)
+    results = crop_module(results)
+    patch = np.array([0, 0, results['img_shape'][1], results['img_shape'][0]])
+    ious = bbox_overlaps(patch.reshape(-1, 4),
+                         results['gt_bboxes']).reshape(-1)
+    ious_ignore = bbox_overlaps(
+        patch.reshape(-1, 4), results['gt_bboxes_ignore']).reshape(-1)
+    mode = crop_module.mode
+    if mode == 1:
+        assert np.equal(results['gt_bboxes'], gt_bboxes).all()
+        assert np.equal(results['gt_bboxes_ignore'], gt_bboxes_ignore).all()
+    else:
+        assert (ious >= mode).all()
+        assert (ious_ignore >= mode).all()


 def test_pad():
     # test assertion if both size_divisor and size is None
     with pytest.raises(AssertionError):