Skip to content
Snippets Groups Projects
Unverified Commit 9ee13ab6 authored by Qiaofei Li's avatar Qiaofei Li Committed by GitHub
Browse files

Add Rotate Augmentation. (#3619)

* add Rotate augmentation init

* add coco_dummy data for unit test of augmentations

* add coco_dummy data for unit test of augmentations

* add Rotate augmentation

* remove duplicated warpAffine declaration.

* remove unnecessary coco_dummy folder

* re-implements rotate augmentation upon mmcv.imrotate

* fix uint test

* reformat

* add supports for PolygonMasks

* handle __init__ conflicts

* unchange

* pull from master and merge rotate.py into auto_augment.py

* add unit test for autoaugment equipped with rotate

* move random_negative_prob to self and reformat assertion message
parent bf01bdd5
No related branches found
No related tags found
No related merge requests found
from abc import ABCMeta, abstractmethod
import cv2
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
......@@ -155,6 +156,25 @@ class BaseInstanceMasks(metaclass=ABCMeta):
"""
pass
@abstractmethod
def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
"""Rotate the masks.
Args:
out_shape (tuple[int]): Shape for output mask, format (h, w).
angle (int | float): Rotation angle in degrees. Positive values
mean counter-clockwise rotation.
center (tuple[float], optional): Center point (w, h) of the
rotation in source image. If not specified, the center of
the image will be used.
scale (int | float): Isotropic scale factor.
fill_val (int | float): Border value. Default 0 for masks.
Returns:
Rotated masks.
"""
pass
class BitmapMasks(BaseInstanceMasks):
"""This class represents masks in the form of bitmaps.
......@@ -355,6 +375,38 @@ class BitmapMasks(BaseInstanceMasks):
(2, 0, 1)).astype(self.masks.dtype)
return BitmapMasks(sheared_masks, *out_shape)
def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
"""Rotate the BitmapMasks.
Args:
out_shape (tuple[int]): Shape for output mask, format (h, w).
angle (int | float): Rotation angle in degrees. Positive values
mean counter-clockwise rotation.
center (tuple[float], optional): Center point (w, h) of the
rotation in source image. If not specified, the center of
the image will be used.
scale (int | float): Isotropic scale factor.
fill_val (int | float): Border value. Default 0 for masks.
Returns:
BitmapMasks: Rotated BitmapMasks.
"""
if len(self.masks) == 0:
rotated_masks = np.empty((0, *out_shape), dtype=self.masks.dtype)
else:
rotated_masks = mmcv.imrotate(
self.masks.transpose((1, 2, 0)),
angle,
center=center,
scale=scale,
border_value=fill_val)
if rotated_masks.ndim == 2:
# case when only one mask, (h, w)
rotated_masks = rotated_masks[:, :, None] # (h, w, 1)
rotated_masks = rotated_masks.transpose(
(2, 0, 1)).astype(self.masks.dtype)
return BitmapMasks(rotated_masks, *out_shape)
@property
def areas(self):
"""See :py:attr:`BaseInstanceMasks.areas`."""
......@@ -588,6 +640,35 @@ class PolygonMasks(BaseInstanceMasks):
sheared_masks = PolygonMasks(sheared_masks, *out_shape)
return sheared_masks
def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
"""See :func:`BaseInstanceMasks.rotate`."""
if len(self.masks) == 0:
rotated_masks = PolygonMasks([], *out_shape)
else:
rotated_masks = []
rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale)
for poly_per_obj in self.masks:
rotated_poly = []
for p in poly_per_obj:
p = p.copy()
coords = np.stack([p[0::2], p[1::2]], axis=1) # [n, 2]
# pad 1 to convert from format [x, y] to homogeneous
# coordinates format [x, y, 1]
coords = np.concatenate(
(coords, np.ones((coords.shape[0], 1), coords.dtype)),
axis=1) # [n, 3]
rotated_coords = np.matmul(
rotate_matrix[None, :, :],
coords[:, :, None])[..., 0] # [n, 2, 1] -> [n, 2]
rotated_coords[:, 0] = np.clip(rotated_coords[:, 0], 0,
out_shape[1])
rotated_coords[:, 1] = np.clip(rotated_coords[:, 1], 0,
out_shape[0])
rotated_poly.append(rotated_coords.reshape(-1))
rotated_masks.append(rotated_poly)
rotated_masks = PolygonMasks(rotated_masks, *out_shape)
return rotated_masks
def to_bitmap(self):
"""convert polygon masks to bitmap masks."""
bitmap_masks = self.to_ndarray()
......
from .auto_augment import AutoAugment, Shear
from .auto_augment import AutoAugment, Rotate, Shear
from .compose import Compose
from .formating import (Collect, DefaultFormatBundle, ImageToTensor,
ToDataContainer, ToTensor, Transpose, to_tensor)
......@@ -17,5 +17,6 @@ __all__ = [
'LoadMultiChannelImageFromFiles', 'LoadProposals', 'MultiScaleFlipAug',
'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale',
'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu',
'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear'
'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear',
'Rotate'
]
import copy
import cv2
import mmcv
import numpy as np
......@@ -317,3 +318,217 @@ class Shear(object):
repr_str += f'random_negative_prob={self.random_negative_prob}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str
@PIPELINES.register_module()
class Rotate(object):
"""Apply Rotate Transformation to image (and its corresponding bbox, mask,
segmentation).
Args:
level (int | float): The level should be in range (0,_MAX_LEVEL].
scale (int | float): Isotropic scale factor. Same in
``mmcv.imrotate``.
center (int | float | tuple[float]): Center point (w, h) of the
rotation in the source image. If None, the center of the
image will be used. Same in ``mmcv.imrotate``.
img_fill_val (int | float | tuple): The fill value for image border.
If float, the same value will be used for all the three
channels of image. If tuple, the should be 3 elements (e.g.
equals the number of channels for image).
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Default 255.
prob (float): The probability for perform transformation and
should be in range 0 to 1.
max_rotate_angle (int | float): The maximum angles for rotate
transformation.
random_negative_prob (float): The probability that turns the
offset negative.
"""
def __init__(self,
level,
scale=1,
center=None,
img_fill_val=128,
seg_ignore_label=255,
prob=0.5,
max_rotate_angle=30,
random_negative_prob=0.5):
assert isinstance(level, (int, float)), \
f'The level must be type int or float. got {type(level)}.'
assert 0 <= level <= _MAX_LEVEL, \
f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.'
assert isinstance(scale, (int, float)), \
f'The scale must be type int or float. got type {type(scale)}.'
if isinstance(center, (int, float)):
center = (center, center)
elif isinstance(center, tuple):
assert len(center) == 2, 'center with type tuple must have '\
f'2 elements. got {len(center)} elements.'
else:
assert center is None, 'center must be None or type int, '\
f'float or tuple, got type {type(center)}.'
if isinstance(img_fill_val, (float, int)):
img_fill_val = tuple([float(img_fill_val)] * 3)
elif isinstance(img_fill_val, tuple):
assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\
f'have 3 elements. got {len(img_fill_val)}.'
img_fill_val = tuple([float(val) for val in img_fill_val])
else:
raise ValueError(
'img_fill_val must be float or tuple with 3 elements.')
assert np.all([0 <= val <= 255 for val in img_fill_val]), \
'all elements of img_fill_val should between range [0,255]. '\
f'got {img_fill_val}.'
assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
'got {prob}.'
assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\
f'should be type int or float. got type {type(max_rotate_angle)}.'
self.level = level
self.scale = scale
# Rotation angle in degrees. Positive values mean
# clockwise rotation.
self.angle = level_to_value(level, max_rotate_angle)
self.center = center
self.img_fill_val = img_fill_val
self.seg_ignore_label = seg_ignore_label
self.prob = prob
self.max_rotate_angle = max_rotate_angle
self.random_negative_prob = random_negative_prob
def _rotate_img(self, results, angle, center=None, scale=1.0):
"""Rotate the image.
Args:
results (dict): Result dict from loading pipeline.
angle (float): Rotation angle in degrees, positive values
mean clockwise rotation. Same in ``mmcv.imrotate``.
center (tuple[float], optional): Center point (w, h) of the
rotation. Same in ``mmcv.imrotate``.
scale (int | float): Isotropic scale factor. Same in
``mmcv.imrotate``.
"""
for key in results.get('img_fields', ['img']):
img = results[key].copy()
img_rotated = mmcv.imrotate(
img, angle, center, scale, border_value=self.img_fill_val)
results[key] = img_rotated.astype(img.dtype)
def _rotate_bboxes(self, results, rotate_matrix):
"""Rotate the bboxes."""
h, w, c = results['img_shape']
for key in results.get('bbox_fields', []):
min_x, min_y, max_x, max_y = np.split(
results[key], results[key].shape[-1], axis=-1)
coordinates = np.stack([[min_x, min_y], [max_x, min_y],
[min_x, max_y],
[max_x, max_y]]) # [4, 2, nb_bbox, 1]
# pad 1 to convert from format [x, y] to homogeneous
# coordinates format [x, y, 1]
coordinates = np.concatenate(
(coordinates,
np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
axis=1) # [4, 3, nb_bbox, 1]
coordinates = coordinates.transpose(
(2, 0, 1, 3)) # [nb_bbox, 4, 3, 1]
rotated_coords = np.matmul(rotate_matrix,
coordinates) # [nb_bbox, 4, 2, 1]
rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2]
min_x, min_y = np.min(
rotated_coords[:, :, 0], axis=1), np.min(
rotated_coords[:, :, 1], axis=1)
max_x, max_y = np.max(
rotated_coords[:, :, 0], axis=1), np.max(
rotated_coords[:, :, 1], axis=1)
min_x, min_y = np.clip(
min_x, a_min=0, a_max=w), np.clip(
min_y, a_min=0, a_max=h)
max_x, max_y = np.clip(
max_x, a_min=min_x, a_max=w), np.clip(
max_y, a_min=min_y, a_max=h)
results[key] = np.stack([min_x, min_y, max_x, max_y],
axis=-1).astype(results[key].dtype)
def _rotate_masks(self,
results,
angle,
center=None,
scale=1.0,
fill_val=0):
"""Rotate the masks."""
h, w, c = results['img_shape']
for key in results.get('mask_fields', []):
masks = results[key]
results[key] = masks.rotate((h, w), angle, center, scale, fill_val)
def _rotate_seg(self,
results,
angle,
center=None,
scale=1.0,
fill_val=255):
"""Rotate the segmentation map."""
for key in results.get('seg_fields', []):
seg = results[key].copy()
results[key] = mmcv.imrotate(
seg, angle, center, scale,
border_value=fill_val).astype(seg.dtype)
def _filter_invalid(self, results, min_bbox_size=0):
"""Filter bboxes and corresponding masks too small after rotate
augmentation."""
bbox2label, bbox2mask, _ = bbox2fields()
for key in results.get('bbox_fields', []):
bbox_w = results[key][:, 2] - results[key][:, 0]
bbox_h = results[key][:, 3] - results[key][:, 1]
valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
valid_inds = np.nonzero(valid_inds)[0]
results[key] = results[key][valid_inds]
# label fields. e.g. gt_labels and gt_labels_ignore
label_key = bbox2label.get(key)
if label_key in results:
results[label_key] = results[label_key][valid_inds]
# mask fields, e.g. gt_masks and gt_masks_ignore
mask_key = bbox2mask.get(key)
if mask_key in results:
results[mask_key] = results[mask_key][valid_inds]
def __call__(self, results):
"""Call function to rotate images, bounding boxes, masks and semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Rotated results.
"""
if np.random.rand() > self.prob:
return results
h, w = results['img'].shape[:2]
center = self.center
if center is None:
center = ((w - 1) * 0.5, (h - 1) * 0.5)
angle = random_negative(self.angle, self.random_negative_prob)
self._rotate_img(results, angle, center, self.scale)
rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
self._rotate_bboxes(results, rotate_matrix)
self._rotate_masks(results, angle, center, self.scale, fill_val=0)
self._rotate_seg(
results, angle, center, self.scale, fill_val=self.seg_ignore_label)
self._filter_invalid(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(level={self.level}, '
repr_str += f'scale={self.scale}, '
repr_str += f'center={self.center}, '
repr_str += f'img_fill_val={self.img_fill_val}, '
repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
repr_str += f'prob={self.prob}, '
repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
repr_str += f'random_negative_prob={self.random_negative_prob})'
return repr_str
import copy
import numpy as np
import pytest
from mmcv.utils import build_from_cfg
from mmdet.core.mask import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES
def construct_toy_data(poly2mask=True):
img = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)
img = np.stack([img, img, img], axis=-1)
results = dict()
# image
results['img'] = img
results['img_shape'] = img.shape
results['img_fields'] = ['img']
# bboxes
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
results['gt_bboxes'] = np.array([[0., 0., 2., 1.]], dtype=np.float32)
results['gt_bboxes_ignore'] = np.array([[2., 0., 3., 1.]],
dtype=np.float32)
# labels
results['gt_labels'] = np.array([1], dtype=np.int64)
# masks
results['mask_fields'] = ['gt_masks']
if poly2mask:
gt_masks = np.array([[0, 1, 1, 0], [0, 1, 0, 0]],
dtype=np.uint8)[None, :, :]
results['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
else:
raw_masks = [[np.array([0, 0, 2, 0, 2, 1, 0, 1], dtype=np.float)]]
results['gt_masks'] = PolygonMasks(raw_masks, 2, 4)
# segmentations
results['seg_fields'] = ['gt_semantic_seg']
results['gt_semantic_seg'] = img[..., 0]
return results
def _check_fields(results, results_rotated, keys):
for key in keys:
if isinstance(results[key], (BitmapMasks, PolygonMasks)):
assert np.equal(results[key].to_ndarray(),
results_rotated[key].to_ndarray()).all()
else:
assert np.equal(results[key], results_rotated[key]).all()
def check_rotate(results, results_rotated):
# check image
_check_fields(results, results_rotated, results.get('img_fields', ['img']))
# check bboxes
_check_fields(results, results_rotated, results.get('bbox_fields', []))
# check masks
_check_fields(results, results_rotated, results.get('mask_fields', []))
# check segmentations
_check_fields(results, results_rotated, results.get('seg_fields', []))
# _check gt_labels
if 'gt_labels' in results:
assert np.equal(results['gt_labels'],
results_rotated['gt_labels']).all()
def test_rotate():
# test assertion for invalid type of max_rotate_angle
with pytest.raises(AssertionError):
transform = dict(type='Rotate', level=1, max_rotate_angle=(30, ))
build_from_cfg(transform, PIPELINES)
# test assertion for invalid type of scale
with pytest.raises(AssertionError):
transform = dict(type='Rotate', level=2, scale=(1.2, ))
build_from_cfg(transform, PIPELINES)
# test ValueError for invalid type of img_fill_val
with pytest.raises(ValueError):
transform = dict(
type='Rotate', level=2, img_fill_val=[
128,
])
build_from_cfg(transform, PIPELINES)
# test assertion for invalid number of elements in center
with pytest.raises(AssertionError):
transform = dict(type='Rotate', level=2, center=(0.5, ))
build_from_cfg(transform, PIPELINES)
# test assertion for invalid type of center
with pytest.raises(AssertionError):
transform = dict(type='Rotate', level=2, center=[0, 0])
build_from_cfg(transform, PIPELINES)
# test case when no rotate aug (level=0)
results = construct_toy_data()
img_fill_val = (104, 116, 124)
seg_ignore_label = 255
transform = dict(
type='Rotate',
level=0,
prob=1.,
img_fill_val=img_fill_val,
seg_ignore_label=seg_ignore_label,
)
rotate_module = build_from_cfg(transform, PIPELINES)
results_wo_rotate = rotate_module(copy.deepcopy(results))
check_rotate(results, results_wo_rotate)
# test case when no rotate aug (prob<=0)
transform = dict(
type='Rotate', level=10, prob=0., img_fill_val=img_fill_val, scale=0.6)
rotate_module = build_from_cfg(transform, PIPELINES)
results_wo_rotate = rotate_module(copy.deepcopy(results))
check_rotate(results, results_wo_rotate)
# test clockwise rotation with angle 90
results = construct_toy_data()
img_fill_val = 128
transform = dict(
type='Rotate',
level=10,
max_rotate_angle=90,
img_fill_val=img_fill_val,
# set random_negative_prob to 0 for clockwise rotation
random_negative_prob=0.,
prob=1.)
rotate_module = build_from_cfg(transform, PIPELINES)
results_rotated = rotate_module(copy.deepcopy(results))
img_r = np.array([[img_fill_val, 6, 2, img_fill_val],
[img_fill_val, 7, 3, img_fill_val]]).astype(np.uint8)
img_r = np.stack([img_r, img_r, img_r], axis=-1)
results_gt = copy.deepcopy(results)
results_gt['img'] = img_r
results_gt['gt_bboxes'] = np.array([[1., 0., 2., 1.]], dtype=np.float32)
results_gt['gt_bboxes_ignore'] = np.empty((0, 4), dtype=np.float32)
gt_masks = np.array([[0, 1, 1, 0], [0, 0, 1, 0]],
dtype=np.uint8)[None, :, :]
results_gt['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
results_gt['gt_semantic_seg'] = np.array(
[[255, 6, 2, 255], [255, 7, 3,
255]]).astype(results['gt_semantic_seg'].dtype)
check_rotate(results_gt, results_rotated)
# test clockwise rotation with angle 90, PolygonMasks
results = construct_toy_data(poly2mask=False)
results_rotated = rotate_module(copy.deepcopy(results))
gt_masks = [[np.array([2, 0, 2, 1, 1, 1, 1, 0], dtype=np.float)]]
results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
check_rotate(results_gt, results_rotated)
# test counter-clockwise roatation with angle 90,
# and specify the ratation center
img_fill_val = (104, 116, 124)
transform = dict(
type='Rotate',
level=10,
max_rotate_angle=90,
center=(0, 0),
img_fill_val=img_fill_val,
# set random_negative_prob to 1 for counter-clockwise rotation
random_negative_prob=1.,
prob=1.)
results = construct_toy_data()
rotate_module = build_from_cfg(transform, PIPELINES)
results_rotated = rotate_module(copy.deepcopy(results))
results_gt = copy.deepcopy(results)
h, w = results['img'].shape[:2]
img_r = np.stack([
np.ones((h, w)) * img_fill_val[0],
np.ones((h, w)) * img_fill_val[1],
np.ones((h, w)) * img_fill_val[2]
],
axis=-1).astype(np.uint8)
img_r[0, 0, :] = 1
img_r[0, 1, :] = 5
results_gt['img'] = img_r
results_gt['gt_bboxes'] = np.empty((0, 4), dtype=np.float32)
results_gt['gt_bboxes_ignore'] = np.empty((0, 4), dtype=np.float32)
results_gt['gt_labels'] = np.empty((0, ), dtype=np.int64)
gt_masks = np.empty((0, h, w), dtype=np.uint8)
results_gt['gt_masks'] = BitmapMasks(gt_masks, h, w)
gt_seg = (np.ones((h, w)) * 255).astype(results['gt_semantic_seg'].dtype)
gt_seg[0, 0], gt_seg[0, 1] = 1, 5
results_gt['gt_semantic_seg'] = gt_seg
check_rotate(results_gt, results_rotated)
transform = dict(
type='Rotate',
level=10,
max_rotate_angle=90,
center=(0),
img_fill_val=img_fill_val,
random_negative_prob=1.,
prob=1.)
rotate_module = build_from_cfg(transform, PIPELINES)
results_rotated = rotate_module(copy.deepcopy(results))
check_rotate(results_gt, results_rotated)
# test counter-clockwise roatation with angle 90,
# and specify the ratation center, PolygonMasks
results = construct_toy_data(poly2mask=False)
results_rotated = rotate_module(copy.deepcopy(results))
gt_masks = [[np.array([0, 0, 0, 0, 1, 0, 1, 0], dtype=np.float)]]
results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
check_rotate(results_gt, results_rotated)
# test AutoAugment equipped with Rotate
policies = [[dict(type='Rotate', level=10, prob=1.)]]
autoaug = dict(type='AutoAugment', policies=policies)
autoaug_module = build_from_cfg(autoaug, PIPELINES)
autoaug_module(copy.deepcopy(results))
policies = [[
dict(type='Rotate', level=10, prob=1.),
dict(
type='Rotate',
level=8,
max_rotate_angle=90,
center=(0),
img_fill_val=img_fill_val)
]]
autoaug = dict(type='AutoAugment', policies=policies)
autoaug_module = build_from_cfg(autoaug, PIPELINES)
autoaug_module(copy.deepcopy(results))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment