From 5db9b2e34a882d385440b8706e7a95319b6f99d2 Mon Sep 17 00:00:00 2001 From: Cao Yuhang <yhcao6@gmail.com> Date: Mon, 6 Apr 2020 16:32:57 +0800 Subject: [PATCH] Rename BitMap to Bitmap, perfect unit test. (#2391) * rename BitMap to Bitmap, add input check for polygon * fix test mask * fix test config * complete test * add mask contest test for bitmap resize * update with np.diag * perfect test polygon resize * perfect test polygon crop --- mmdet/core/mask/__init__.py | 4 +- mmdet/core/mask/structures.py | 93 +++++++-- mmdet/datasets/pipelines/loading.py | 4 +- tests/test_config.py | 4 +- tests/test_masks.py | 304 ++++++++++++++++++++-------- 5 files changed, 300 insertions(+), 109 deletions(-) diff --git a/mmdet/core/mask/__init__.py b/mmdet/core/mask/__init__.py index 64a6cbf9..ed729211 100644 --- a/mmdet/core/mask/__init__.py +++ b/mmdet/core/mask/__init__.py @@ -1,7 +1,7 @@ from .mask_target import mask_target -from .structures import BitMapMasks, PolygonMasks +from .structures import BitmapMasks, PolygonMasks from .utils import split_combined_polys __all__ = [ - 'split_combined_polys', 'mask_target', 'BitMapMasks', 'PolygonMasks' + 'split_combined_polys', 'mask_target', 'BitmapMasks', 'PolygonMasks' ] diff --git a/mmdet/core/mask/structures.py b/mmdet/core/mask/structures.py index eb43aef0..7975f502 100644 --- a/mmdet/core/mask/structures.py +++ b/mmdet/core/mask/structures.py @@ -40,6 +40,11 @@ class BaseInstanceMasks(metaclass=ABCMeta): def expand(self, expanded_h, expanded_w, top, left): pass + @property + @abstractmethod + def areas(self): + pass + @abstractmethod def to_ndarray(self): pass @@ -49,7 +54,7 @@ class BaseInstanceMasks(metaclass=ABCMeta): pass -class BitMapMasks(BaseInstanceMasks): +class BitmapMasks(BaseInstanceMasks): """This class represents masks in the form of bitmaps. Args: @@ -78,7 +83,7 @@ class BitMapMasks(BaseInstanceMasks): def __getitem__(self, index): masks = self.masks[index].reshape(-1, self.height, self.width) - return BitMapMasks(masks, self.height, self.width) + return BitmapMasks(masks, self.height, self.width) def __iter__(self): return iter(self.masks) @@ -95,7 +100,7 @@ class BitMapMasks(BaseInstanceMasks): interpolation (str): same as :func:`mmcv.imrescale` Returns: - BitMapMasks: the rescaled masks + BitmapMasks: the rescaled masks """ if len(self.masks) == 0: new_w, new_h = mmcv.rescale_size((self.width, self.height), scale) @@ -106,7 +111,7 @@ class BitMapMasks(BaseInstanceMasks): for mask in self.masks ]) height, width = rescaled_masks.shape[1:] - return BitMapMasks(rescaled_masks, height, width) + return BitmapMasks(rescaled_masks, height, width) def resize(self, out_shape, interpolation='nearest'): """Resize masks to the given out_shape. @@ -116,7 +121,7 @@ class BitMapMasks(BaseInstanceMasks): interpolation (str): see `mmcv.imresize` Returns: - BitMapMasks: the resized masks + BitmapMasks: the resized masks """ if len(self.masks) == 0: resized_masks = np.empty((0, *out_shape), dtype=np.uint8) @@ -125,7 +130,7 @@ class BitMapMasks(BaseInstanceMasks): mmcv.imresize(mask, out_shape, interpolation=interpolation) for mask in self.masks ]) - return BitMapMasks(resized_masks, *out_shape) + return BitmapMasks(resized_masks, *out_shape) def flip(self, flip_direction='horizontal'): """flip masks alone the given direction. @@ -134,7 +139,7 @@ class BitMapMasks(BaseInstanceMasks): flip_direction (str): either 'horizontal' or 'vertical' Returns: - BitMapMasks: the flipped masks + BitmapMasks: the flipped masks """ assert flip_direction in ('horizontal', 'vertical') @@ -145,7 +150,7 @@ class BitMapMasks(BaseInstanceMasks): mmcv.imflip(mask, direction=flip_direction) for mask in self.masks ]) - return BitMapMasks(flipped_masks, self.height, self.width) + return BitmapMasks(flipped_masks, self.height, self.width) def pad(self, out_shape, pad_val=0): """Pad masks to the given size of (h, w). @@ -155,7 +160,7 @@ class BitMapMasks(BaseInstanceMasks): pad_val (int): the padded value Returns: - BitMapMasks: the padded masks + BitmapMasks: the padded masks """ if len(self.masks) == 0: padded_masks = np.empty((0, *out_shape), dtype=np.uint8) @@ -164,7 +169,7 @@ class BitMapMasks(BaseInstanceMasks): mmcv.impad(mask, out_shape, pad_val=pad_val) for mask in self.masks ]) - return BitMapMasks(padded_masks, *out_shape) + return BitmapMasks(padded_masks, *out_shape) def crop(self, bbox): """Crop each mask by the given bbox. @@ -173,7 +178,7 @@ class BitMapMasks(BaseInstanceMasks): bbox (ndarray): bbox in format [x1, y1, x2, y2], shape (4, ) Return: - BitMapMasks: the cropped masks. + BitmapMasks: the cropped masks. """ assert isinstance(bbox, np.ndarray) assert bbox.ndim == 1 @@ -190,7 +195,7 @@ class BitMapMasks(BaseInstanceMasks): cropped_masks = np.empty((0, h, w), dtype=np.uint8) else: cropped_masks = self.masks[:, y1:y1 + h, x1:x1 + w] - return BitMapMasks(cropped_masks, h, w) + return BitmapMasks(cropped_masks, h, w) def crop_and_resize(self, bboxes, @@ -214,7 +219,7 @@ class BitMapMasks(BaseInstanceMasks): """ if len(self.masks) == 0: empty_masks = np.empty((0, *out_shape), dtype=np.uint8) - return BitMapMasks(empty_masks, *out_shape) + return BitmapMasks(empty_masks, *out_shape) resized_masks = [] for i in range(len(bboxes)): @@ -228,7 +233,7 @@ class BitMapMasks(BaseInstanceMasks): mask[y1:y1 + h, x1:x1 + w], out_shape, interpolation=interpolation)) - return BitMapMasks(np.stack(resized_masks), *out_shape) + return BitmapMasks(np.stack(resized_masks), *out_shape) def expand(self, expanded_h, expanded_w, top, left): """see `transforms.Expand`.""" @@ -240,7 +245,16 @@ class BitMapMasks(BaseInstanceMasks): dtype=np.uint8) expanded_mask[:, top:top + self.height, left:left + self.width] = self.masks - return BitMapMasks(expanded_mask, expanded_h, expanded_w) + return BitmapMasks(expanded_mask, expanded_h, expanded_w) + + @property + def areas(self): + """Compute area of each instance + + Return: + ndarray: areas of each instance + """ + return self.masks.sum((1, 2)) def to_ndarray(self): return self.masks @@ -297,7 +311,7 @@ class PolygonMasks(BaseInstanceMasks): return len(self.masks) def rescale(self, scale, interpolation=None): - """see BitMapMasks.rescale""" + """see BitmapMasks.rescale""" new_w, new_h = mmcv.rescale_size((self.width, self.height), scale) if len(self.masks) == 0: rescaled_masks = PolygonMasks([], new_h, new_w) @@ -306,7 +320,7 @@ class PolygonMasks(BaseInstanceMasks): return rescaled_masks def resize(self, out_shape, interpolation=None): - """see BitMapMasks.resize""" + """see BitmapMasks.resize""" if len(self.masks) == 0: resized_masks = PolygonMasks([], *out_shape) else: @@ -325,7 +339,7 @@ class PolygonMasks(BaseInstanceMasks): return resized_masks def flip(self, flip_direction='horizontal'): - """see BitMapMasks.flip""" + """see BitmapMasks.flip""" assert flip_direction in ('horizontal', 'vertical') if len(self.masks) == 0: flipped_masks = PolygonMasks([], self.height, self.width) @@ -349,7 +363,7 @@ class PolygonMasks(BaseInstanceMasks): return flipped_masks def crop(self, bbox): - """see BitMapMasks.crop""" + """see BitmapMasks.crop""" assert isinstance(bbox, np.ndarray) assert bbox.ndim == 1 @@ -368,6 +382,7 @@ class PolygonMasks(BaseInstanceMasks): for poly_per_obj in self.masks: cropped_poly_per_obj = [] for p in poly_per_obj: + # pycocotools will clip the boundary p = p.copy() p[0::2] -= bbox[0] p[1::2] -= bbox[1] @@ -388,7 +403,7 @@ class PolygonMasks(BaseInstanceMasks): out_shape, inds, interpolation='bilinear'): - """see BitMapMasks.crop_and_resize""" + """see BitmapMasks.crop_and_resize""" out_h, out_w = out_shape if len(self.masks) == 0: return PolygonMasks([], out_h, out_w) @@ -407,6 +422,7 @@ class PolygonMasks(BaseInstanceMasks): for p in mask: p = p.copy() # crop + # pycocotools will clip the boundary p[0::2] -= bbox[0] p[1::2] -= bbox[1] @@ -420,7 +436,42 @@ class PolygonMasks(BaseInstanceMasks): def to_bitmap(self): """convert polygon masks to bitmap masks""" bitmap_masks = self.to_ndarray() - return BitMapMasks(bitmap_masks, self.height, self.width) + return BitmapMasks(bitmap_masks, self.height, self.width) + + @property + def areas(self): + """Compute areas of masks. + + This func is modified from + https://github.com/facebookresearch/detectron2/blob/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9/detectron2/structures/masks.py#L387 + Only works with Polygons, using the shoelace formula + + Return: + ndarray: areas of each instance + """ # noqa: W501 + area = [] + for polygons_per_obj in self.masks: + area_per_obj = 0 + for p in polygons_per_obj: + area_per_obj += self._polygon_area(p[0::2], p[1::2]) + area.append(area_per_obj) + return np.asarray(area) + + def _polygon_area(self, x, y): + """Compute the area of a component of a polygon. + + Using the shoelace formula: + https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates + + Args: + x (ndarray): x coordinates of the component + y (ndarray): y coordinates of the component + + Return: + float: the are of the component + """ # noqa: 501 + return 0.5 * np.abs( + np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) def to_ndarray(self): if len(self.masks) == 0: diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py index 5acab478..84e49cf8 100644 --- a/mmdet/datasets/pipelines/loading.py +++ b/mmdet/datasets/pipelines/loading.py @@ -4,7 +4,7 @@ import mmcv import numpy as np import pycocotools.mask as maskUtils -from mmdet.core import BitMapMasks, PolygonMasks +from mmdet.core import BitmapMasks, PolygonMasks from ..registry import PIPELINES @@ -149,7 +149,7 @@ class LoadAnnotations(object): h, w = results['img_info']['height'], results['img_info']['width'] gt_masks = results['ann_info']['masks'] if self.poly2mask: - gt_masks = BitMapMasks( + gt_masks = BitmapMasks( [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) else: gt_masks = PolygonMasks( diff --git a/tests/test_config.py b/tests/test_config.py index 6c780b27..0e451296 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ from os.path import dirname, exists, join, relpath -from mmdet.core import BitMapMasks, PolygonMasks +from mmdet.core import BitmapMasks, PolygonMasks def _get_config_directory(): @@ -98,7 +98,7 @@ def test_config_data_pipeline(): assert mode in ('polygon', 'bitmap') if mode == 'bitmap': masks = np.random.randint(0, 2, (num_obj, h, w), dtype=np.uint8) - masks = BitMapMasks(masks, h, w) + masks = BitmapMasks(masks, h, w) else: masks = [] for i in range(num_obj): diff --git a/tests/test_masks.py b/tests/test_masks.py index a41d3e69..7c0c8c12 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -2,7 +2,7 @@ import numpy as np import pytest import torch -from mmdet.core import BitMapMasks, PolygonMasks +from mmdet.core import BitmapMasks, PolygonMasks def dummy_raw_bitmap_masks(size): @@ -42,96 +42,116 @@ def dummy_bboxes(num, max_height, max_width): def test_bitmap_mask_init(): # init with empty ndarray masks raw_masks = np.empty((0, 28, 28), dtype=np.uint8) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - assert bitmap_masks.masks.shape[0] == 0 - assert bitmap_masks.height == bitmap_masks.masks.shape[1] - assert bitmap_masks.width == bitmap_masks.masks.shape[2] + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + assert len(bitmap_masks) == 0 + assert bitmap_masks.height == 28 + assert bitmap_masks.width == 28 # init with empty list masks raw_masks = [] - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - assert bitmap_masks.masks.shape[0] == 0 - assert bitmap_masks.height == bitmap_masks.masks.shape[1] - assert bitmap_masks.width == bitmap_masks.masks.shape[2] + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + assert len(bitmap_masks) == 0 + assert bitmap_masks.height == 28 + assert bitmap_masks.width == 28 # init with ndarray masks contain 3 instances raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - assert bitmap_masks.masks.shape[0] == 3 - assert bitmap_masks.height == bitmap_masks.masks.shape[1] - assert bitmap_masks.width == bitmap_masks.masks.shape[2] + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + assert len(bitmap_masks) == 3 + assert bitmap_masks.height == 28 + assert bitmap_masks.width == 28 # init with list masks contain 3 instances raw_masks = [dummy_raw_bitmap_masks((28, 28)) for _ in range(3)] - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - assert bitmap_masks.masks.shape[0] == 3 - assert bitmap_masks.height == bitmap_masks.masks.shape[1] - assert bitmap_masks.width == bitmap_masks.masks.shape[2] + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + assert len(bitmap_masks) == 3 + assert bitmap_masks.height == 28 + assert bitmap_masks.width == 28 # init with raw masks of unsupported type with pytest.raises(AssertionError): raw_masks = [[dummy_raw_bitmap_masks((28, 28))]] - BitMapMasks(raw_masks, 28, 28) + BitmapMasks(raw_masks, 28, 28) def test_bitmap_mask_rescale(): # rescale with empty bitmap masks raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) rescaled_masks = bitmap_masks.rescale((56, 72)) + assert len(rescaled_masks) == 0 assert rescaled_masks.height == 56 assert rescaled_masks.width == 56 - # rescale with bitmap masks contain 3 instances - raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - rescaled_masks = bitmap_masks.rescale((56, 72)) - assert rescaled_masks.height == 56 - assert rescaled_masks.width == 56 + # rescale with bitmap masks contain 1 instances + raw_masks = np.array([[[1, 0, 0, 0], [0, 1, 0, 1]]]) + bitmap_masks = BitmapMasks(raw_masks, 2, 4) + rescaled_masks = bitmap_masks.rescale((8, 8)) + assert len(rescaled_masks) == 1 + assert rescaled_masks.height == 4 + assert rescaled_masks.width == 8 + truth = np.array([[[1, 1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 1, 1], [0, 0, 1, 1, 0, 0, 1, 1]]]) + assert (rescaled_masks.masks == truth).all() def test_bitmap_mask_resize(): # resize with empty bitmap masks raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - rescaled_masks = bitmap_masks.resize((56, 72)) - assert rescaled_masks.height == 56 - assert rescaled_masks.width == 72 - - # resize with bitmap masks contain 3 instances - raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - rescaled_masks = bitmap_masks.resize((56, 72)) - assert rescaled_masks.height == 56 - assert rescaled_masks.width == 72 + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + resized_masks = bitmap_masks.resize((56, 72)) + assert len(resized_masks) == 0 + assert resized_masks.height == 56 + assert resized_masks.width == 72 + + # resize with bitmap masks contain 1 instances + raw_masks = np.diag(np.ones(4, dtype=np.uint8))[np.newaxis, ...] + bitmap_masks = BitmapMasks(raw_masks, 4, 4) + resized_masks = bitmap_masks.resize((8, 8)) + assert len(resized_masks) == 1 + assert resized_masks.height == 8 + assert resized_masks.width == 8 + truth = np.array([[[1, 1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 1, 1]]]) + assert (resized_masks.masks == truth).all() def test_bitmap_mask_flip(): # flip with empty bitmap masks raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) flipped_masks = bitmap_masks.flip(flip_direction='horizontal') assert len(flipped_masks) == 0 + assert flipped_masks.height == 28 + assert flipped_masks.width == 28 # horizontally flip with bitmap masks contain 3 instances raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) flipped_masks = bitmap_masks.flip(flip_direction='horizontal') flipped_flipped_masks = flipped_masks.flip(flip_direction='horizontal') + assert flipped_masks.masks.shape == (3, 28, 28) assert (bitmap_masks.masks == flipped_flipped_masks.masks).all() + assert (flipped_masks.masks == raw_masks[:, :, ::-1]).all() # vertically flip with bitmap masks contain 3 instances raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) flipped_masks = bitmap_masks.flip(flip_direction='vertical') flipped_flipped_masks = flipped_masks.flip(flip_direction='vertical') + assert len(flipped_masks) == 3 + assert flipped_masks.height == 28 + assert flipped_masks.width == 28 assert (bitmap_masks.masks == flipped_flipped_masks.masks).all() + assert (flipped_masks.masks == raw_masks[:, ::-1, :]).all() def test_bitmap_mask_pad(): # pad with empty bitmap masks raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) padded_masks = bitmap_masks.pad((56, 56)) assert len(padded_masks) == 0 assert padded_masks.height == 56 @@ -139,8 +159,9 @@ def test_bitmap_mask_pad(): # pad with bitmap masks contain 3 instances raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) padded_masks = bitmap_masks.pad((56, 56)) + assert len(padded_masks) == 3 assert padded_masks.height == 56 assert padded_masks.width == 56 assert (padded_masks.masks[:, 28:, 28:] == 0).all() @@ -150,7 +171,7 @@ def test_bitmap_mask_crop(): # crop with empty bitmap masks dummy_bbox = np.array([0, 10, 10, 27], dtype=np.int) raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) cropped_masks = bitmap_masks.crop(dummy_bbox) assert len(cropped_masks) == 0 assert cropped_masks.height == 18 @@ -158,10 +179,13 @@ def test_bitmap_mask_crop(): # crop with bitmap masks contain 3 instances raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) cropped_masks = bitmap_masks.crop(dummy_bbox) + assert len(cropped_masks) == 3 assert cropped_masks.height == 18 assert cropped_masks.width == 11 + x1, y1, x2, y2 = dummy_bbox + assert (cropped_masks.masks == raw_masks[:, y1:y2 + 1, x1:x2 + 1]).all() # crop with invalid bbox with pytest.raises(AssertionError): @@ -175,7 +199,7 @@ def test_bitmap_mask_crop_and_resize(): # crop and resize with empty bitmap masks raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) cropped_resized_masks = bitmap_masks.crop_and_resize( dummy_bbox, (56, 56), inds) assert len(cropped_resized_masks) == 0 @@ -184,9 +208,10 @@ def test_bitmap_mask_crop_and_resize(): # crop and resize with bitmap masks contain 3 instances raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) cropped_resized_masks = bitmap_masks.crop_and_resize( dummy_bbox, (56, 56), inds) + assert len(cropped_resized_masks) == 5 assert cropped_resized_masks.height == 56 assert cropped_resized_masks.width == 56 @@ -194,31 +219,48 @@ def test_bitmap_mask_crop_and_resize(): def test_bitmap_mask_expand(): # expand with empty bitmap masks raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - cropped_masks = bitmap_masks.expand(56, 56, 12, 14) - assert len(cropped_masks) == 0 - assert cropped_masks.height == 56 - assert cropped_masks.width == 56 + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + expanded_masks = bitmap_masks.expand(56, 56, 12, 14) + assert len(expanded_masks) == 0 + assert expanded_masks.height == 56 + assert expanded_masks.width == 56 # expand with bitmap masks contain 3 instances raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) - cropped_masks = bitmap_masks.expand(56, 56, 12, 14) - assert cropped_masks.height == 56 - assert cropped_masks.width == 56 + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + expanded_masks = bitmap_masks.expand(56, 56, 12, 14) + assert len(expanded_masks) == 3 + assert expanded_masks.height == 56 + assert expanded_masks.width == 56 + assert (expanded_masks.masks[:, :12, :14] == 0).all() + assert (expanded_masks.masks[:, 12 + 28:, 14 + 28:] == 0).all() + + +def test_bitmap_mask_area(): + # area of empty bitmap mask + raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + assert bitmap_masks.areas.sum() == 0 + + # area of bitmap masks contain 3 instances + raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) + areas = bitmap_masks.areas + assert len(areas) == 3 + assert (areas == raw_masks.sum((1, 2))).all() def test_bitmap_mask_to_ndarray(): # empty bitmap masks to ndarray raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) ndarray_masks = bitmap_masks.to_ndarray() assert isinstance(ndarray_masks, np.ndarray) assert ndarray_masks.shape == (0, 28, 28) # bitmap masks contain 3 instances to ndarray raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) ndarray_masks = bitmap_masks.to_ndarray() assert isinstance(ndarray_masks, np.ndarray) assert ndarray_masks.shape == (3, 28, 28) @@ -228,14 +270,14 @@ def test_bitmap_mask_to_ndarray(): def test_bitmap_mask_to_tensor(): # empty bitmap masks to tensor raw_masks = dummy_raw_bitmap_masks((0, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) tensor_masks = bitmap_masks.to_tensor(dtype=torch.uint8, device='cpu') assert isinstance(tensor_masks, torch.Tensor) assert tensor_masks.shape == (0, 28, 28) # bitmap masks contain 3 instances to tensor raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) tensor_masks = bitmap_masks.to_tensor(dtype=torch.uint8, device='cpu') assert isinstance(tensor_masks, torch.Tensor) assert tensor_masks.shape == (3, 28, 28) @@ -244,14 +286,14 @@ def test_bitmap_mask_to_tensor(): def test_bitmap_mask_index(): raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) assert (bitmap_masks[0].masks == raw_masks[0]).all() assert (bitmap_masks[range(2)].masks == raw_masks[range(2)]).all() def test_bitmap_mask_iter(): raw_masks = dummy_raw_bitmap_masks((3, 28, 28)) - bitmap_masks = BitMapMasks(raw_masks, 28, 28) + bitmap_masks = BitmapMasks(raw_masks, 28, 28) for i, bitmap_mask in enumerate(bitmap_masks): assert bitmap_mask.shape == (28, 28) assert (bitmap_mask == raw_masks[i]).all() @@ -260,7 +302,7 @@ def test_bitmap_mask_iter(): def test_polygon_mask_init(): # init with empty masks raw_masks = [] - polygon_masks = BitMapMasks(raw_masks, 28, 28) + polygon_masks = BitmapMasks(raw_masks, 28, 28) assert len(polygon_masks) == 0 assert polygon_masks.height == 28 assert polygon_masks.width == 28 @@ -274,6 +316,7 @@ def test_polygon_mask_init(): assert len(polygon_masks) == 3 assert polygon_masks.height == 28 assert polygon_masks.width == 28 + assert polygon_masks.to_ndarray().shape == (3, 28, 28) # init with raw masks of unsupported type with pytest.raises(AssertionError): @@ -289,31 +332,82 @@ def test_polygon_mask_rescale(): raw_masks = dummy_raw_polygon_masks((0, 28, 28)) polygon_masks = PolygonMasks(raw_masks, 28, 28) rescaled_masks = polygon_masks.rescale((56, 72)) + assert len(rescaled_masks) == 0 assert rescaled_masks.height == 56 assert rescaled_masks.width == 56 + assert rescaled_masks.to_ndarray().shape == (0, 56, 56) # rescale with polygon masks contain 3 instances - raw_masks = dummy_raw_polygon_masks((3, 28, 28)) - polygon_masks = PolygonMasks(raw_masks, 28, 28) - rescaled_masks = polygon_masks.rescale((56, 72)) - assert rescaled_masks.height == 56 - assert rescaled_masks.width == 56 + raw_masks = [[np.array([1, 1, 3, 1, 4, 3, 2, 4, 1, 3], dtype=np.float)]] + polygon_masks = PolygonMasks(raw_masks, 5, 5) + rescaled_masks = polygon_masks.rescale((12, 10)) + assert len(rescaled_masks) == 1 + assert rescaled_masks.height == 10 + assert rescaled_masks.width == 10 + assert rescaled_masks.to_ndarray().shape == (1, 10, 10) + truth = np.array( + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], [0, 0, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0], [0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + np.uint8) + assert (rescaled_masks.to_ndarray() == truth).all() def test_polygon_mask_resize(): # resize with empty polygon masks raw_masks = dummy_raw_polygon_masks((0, 28, 28)) polygon_masks = PolygonMasks(raw_masks, 28, 28) - rescaled_masks = polygon_masks.resize((56, 72)) - assert rescaled_masks.height == 56 - assert rescaled_masks.width == 72 - - # resize with polygon masks contain 3 instances - raw_masks = dummy_raw_polygon_masks((3, 28, 28)) - polygon_masks = PolygonMasks(raw_masks, 28, 28) - rescaled_masks = polygon_masks.resize((56, 72)) - assert rescaled_masks.height == 56 - assert rescaled_masks.width == 72 + resized_masks = polygon_masks.resize((56, 72)) + assert len(resized_masks) == 0 + assert resized_masks.height == 56 + assert resized_masks.width == 72 + assert resized_masks.to_ndarray().shape == (0, 56, 72) + + # resize with polygon masks contain 1 instance 1 part + raw_masks1 = [[np.array([1, 1, 3, 1, 4, 3, 2, 4, 1, 3], dtype=np.float)]] + polygon_masks1 = PolygonMasks(raw_masks1, 5, 5) + resized_masks1 = polygon_masks1.resize((10, 10)) + assert len(resized_masks1) == 1 + assert resized_masks1.height == 10 + assert resized_masks1.width == 10 + assert resized_masks1.to_ndarray().shape == (1, 10, 10) + truth1 = np.array( + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], [0, 0, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0], [0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + np.uint8) + assert (resized_masks1.to_ndarray() == truth1).all() + + # resize with polygon masks contain 1 instance 2 part + raw_masks2 = [[ + np.array([0., 0., 1., 0., 1., 1.]), + np.array([1., 1., 2., 1., 2., 2., 1., 2.]) + ]] + polygon_masks2 = PolygonMasks(raw_masks2, 3, 3) + resized_masks2 = polygon_masks2.resize((6, 6)) + assert len(resized_masks2) == 1 + assert resized_masks2.height == 6 + assert resized_masks2.width == 6 + assert resized_masks2.to_ndarray().shape == (1, 6, 6) + truth2 = np.array( + [[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], np.uint8) + assert (resized_masks2.to_ndarray() == truth2).all() + + # resize with polygon masks contain 2 instances + raw_masks3 = [raw_masks1[0], raw_masks2[0]] + polygon_masks3 = PolygonMasks(raw_masks3, 5, 5) + resized_masks3 = polygon_masks3.resize((10, 10)) + assert len(resized_masks3) == 2 + assert resized_masks3.height == 10 + assert resized_masks3.width == 10 + assert resized_masks3.to_ndarray().shape == (2, 10, 10) + truth3 = np.stack([truth1, np.pad(truth2, ((0, 4), (0, 4)), 'constant')]) + assert (resized_masks3.to_ndarray() == truth3).all() def test_polygon_mask_flip(): @@ -322,12 +416,20 @@ def test_polygon_mask_flip(): polygon_masks = PolygonMasks(raw_masks, 28, 28) flipped_masks = polygon_masks.flip(flip_direction='horizontal') assert len(flipped_masks) == 0 + assert flipped_masks.height == 28 + assert flipped_masks.width == 28 + assert flipped_masks.to_ndarray().shape == (0, 28, 28) + # TODO: fixed flip correctness checking after v2.0_coord is merged # horizontally flip with polygon masks contain 3 instances raw_masks = dummy_raw_polygon_masks((3, 28, 28)) polygon_masks = PolygonMasks(raw_masks, 28, 28) flipped_masks = polygon_masks.flip(flip_direction='horizontal') flipped_flipped_masks = flipped_masks.flip(flip_direction='horizontal') + assert len(flipped_masks) == 3 + assert flipped_masks.height == 28 + assert flipped_masks.width == 28 + assert flipped_masks.to_ndarray().shape == (3, 28, 28) assert (polygon_masks.to_ndarray() == flipped_flipped_masks.to_ndarray() ).all() @@ -336,6 +438,10 @@ def test_polygon_mask_flip(): polygon_masks = PolygonMasks(raw_masks, 28, 28) flipped_masks = polygon_masks.flip(flip_direction='vertical') flipped_flipped_masks = flipped_masks.flip(flip_direction='vertical') + assert len(flipped_masks) == 3 + assert flipped_masks.height == 28 + assert flipped_masks.width == 28 + assert flipped_masks.to_ndarray().shape == (3, 28, 28) assert (polygon_masks.to_ndarray() == flipped_flipped_masks.to_ndarray() ).all() @@ -347,13 +453,22 @@ def test_polygon_mask_crop(): polygon_masks = PolygonMasks(raw_masks, 28, 28) cropped_masks = polygon_masks.crop(dummy_bbox) assert len(cropped_masks) == 0 - - # crop with polygon masks contain 3 instances - raw_masks = dummy_raw_polygon_masks((3, 28, 28)) - polygon_masks = PolygonMasks(raw_masks, 28, 28) - cropped_masks = polygon_masks.crop(dummy_bbox) assert cropped_masks.height == 18 assert cropped_masks.width == 11 + assert cropped_masks.to_ndarray().shape == (0, 18, 11) + + # crop with polygon masks contain 1 instances + raw_masks = [[np.array([1., 3., 5., 1., 5., 6., 1, 6])]] + polygon_masks = PolygonMasks(raw_masks, 7, 7) + bbox = np.array([0, 0, 3, 4]) + cropped_masks = polygon_masks.crop(bbox) + assert len(cropped_masks) == 1 + assert cropped_masks.height == 5 + assert cropped_masks.width == 4 + assert cropped_masks.to_ndarray().shape == (1, 5, 4) + truth = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1], + [0, 1, 1, 1]]) + assert (cropped_masks.to_ndarray() == truth).all() # crop with invalid bbox with pytest.raises(AssertionError): @@ -367,13 +482,19 @@ def test_polygon_mask_pad(): polygon_masks = PolygonMasks(raw_masks, 28, 28) padded_masks = polygon_masks.pad((56, 56)) assert len(padded_masks) == 0 + assert padded_masks.height == 56 + assert padded_masks.width == 56 + assert padded_masks.to_ndarray().shape == (0, 56, 56) # pad with polygon masks contain 3 instances - raw_masks = dummy_raw_polygon_masks((0, 28, 28)) + raw_masks = dummy_raw_polygon_masks((3, 28, 28)) polygon_masks = PolygonMasks(raw_masks, 28, 28) padded_masks = polygon_masks.pad((56, 56)) + assert len(padded_masks) == 3 assert padded_masks.height == 56 assert padded_masks.width == 56 + assert padded_masks.to_ndarray().shape == (3, 56, 56) + assert (padded_masks.to_ndarray()[:, 28:, 28:] == 0).all() def test_polygon_mask_expand(): @@ -395,6 +516,7 @@ def test_polygon_mask_crop_and_resize(): assert len(cropped_resized_masks) == 0 assert cropped_resized_masks.height == 56 assert cropped_resized_masks.width == 56 + assert cropped_resized_masks.to_ndarray().shape == (0, 56, 56) # crop and resize with polygon masks contain 3 instances raw_masks = dummy_raw_polygon_masks((3, 28, 28)) @@ -404,6 +526,24 @@ def test_polygon_mask_crop_and_resize(): assert len(cropped_resized_masks) == 5 assert cropped_resized_masks.height == 56 assert cropped_resized_masks.width == 56 + assert cropped_resized_masks.to_ndarray().shape == (5, 56, 56) + + +def test_polygon_mask_area(): + # area of empty polygon masks + raw_masks = dummy_raw_polygon_masks((0, 28, 28)) + polygon_masks = PolygonMasks(raw_masks, 28, 28) + assert polygon_masks.areas.sum() == 0 + + # area of polygon masks contain 1 instance + # here we hack a case that the gap between the area of bitmap and polygon + # is minor + raw_masks = [[np.array([1, 1, 5, 1, 3, 4])]] + polygon_masks = PolygonMasks(raw_masks, 6, 6) + polygon_area = polygon_masks.areas + bitmap_area = polygon_masks.to_bitmap().areas + assert len(polygon_area) == 1 + assert np.isclose(polygon_area, bitmap_area).all() def test_polygon_mask_to_bitmap(): -- GitLab