Skip to content
Snippets Groups Projects
Unverified Commit b347bf22 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

[Refactor] refactor get_subset_by_classes in dataloader for training with empty-GT images (#3695)


* Add regression test and test data

* Add fix for empty gt images

* Trigger CI build

* refactor get_subset_by_classes

* reformat and fix docstring

* [refactor]: move get_subset_by_classes to _filter_imgs

* make img_ids consistent

* resolve comments

* simplify logics

* add warning

* add warning

* add warning

Co-authored-by: default avatarmmeendez8 <miguelmndez@gmail.com>
parent b3f1e05c
No related branches found
No related tags found
No related merge requests found
......@@ -24,17 +24,29 @@ class CityscapesDataset(CocoDataset):
# NOTE(review): this span is a unified-diff hunk with the +/- markers stripped
# by the page scrape, so one removed line and its replacement appear back to
# back below; it is not valid Python as displayed.
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths.

Side effect: rewrites ``self.img_ids`` to the ids of the kept images so
that it stays consistent with the filtered ``self.data_infos``.
"""
valid_inds = []
# obtain images that contain annotation
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.coco.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
img_id = img_info['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
# Cityscapes-specific: drop images whose annotations are all iscrowd
all_iscrowd = all([_['iscrowd'] for _ in ann_info])
# NOTE(review): the next line is presumably the pre-refactor (removed)
# condition, and the line after it the post-refactor replacement that
# checks membership in the category-aware set instead — confirm against
# the upstream diff.
if self.filter_empty_gt and (self.img_ids[i] not in ids_with_ann
if self.filter_empty_gt and (self.img_ids[i] not in ids_in_cat
or all_iscrowd):
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds
def _parse_ann_info(self, img_info, ann_info):
......
......@@ -96,39 +96,27 @@ class CocoDataset(CustomDataset):
# NOTE(review): diff hunk with +/- markers stripped by the scrape — an old
# line and its two-line replacement appear interleaved below; not valid
# Python as displayed.
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths.

Side effect: rewrites ``self.img_ids`` to the ids of the kept images so
that it stays consistent with the filtered ``self.data_infos``.
"""
valid_inds = []
# obtain images that contain annotation
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.coco.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
# NOTE(review): the next line is presumably the pre-refactor (removed)
# check against ids_with_ann; the two lines after it are the
# post-refactor replacement using the category-aware set — confirm
# against the upstream diff.
if self.filter_empty_gt and self.img_ids[i] not in ids_with_ann:
img_id = self.img_ids[i]
if self.filter_empty_gt and img_id not in ids_in_cat:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds
def get_subset_by_classes(self):
"""Get img infos for images that contain any category in ``self.cat_ids``.

Different from ``coco.getImgIds()``, this keeps an image if it contains
at least one of the categories rather than all of them.

Side effect: overwrites ``self.img_ids`` with the filtered image ids.

Return:
list[dict]: img infos of the kept images, with ``filename`` populated
from the COCO ``file_name`` field.
"""
ids = set()
for i, class_id in enumerate(self.cat_ids):
ids |= set(self.coco.cat_img_map[class_id])
self.img_ids = list(ids)
data_infos = []
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
return data_infos
def _parse_ann_info(self, img_info, ann_info):
"""Parse bbox and mask annotation.
......
import os.path as osp
import warnings
import mmcv
import numpy as np
......@@ -42,7 +43,9 @@ class CustomDataset(Dataset):
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
test_mode (bool, optional): If set True, annotation will not be loaded.
filter_empty_gt (bool, optional): If set true, images without bounding
boxes will be filtered out.
boxes of the dataset's classes will be filtered out. This option
only works when `test_mode=False`, i.e., we never filter images
during tests.
"""
CLASSES = None
......@@ -80,23 +83,21 @@ class CustomDataset(Dataset):
self.proposal_file)
# load annotations (and proposals)
self.data_infos = self.load_annotations(self.ann_file)
# filter data infos if classes are customized
if self.custom_classes:
self.data_infos = self.get_subset_by_classes()
if self.proposal_file is not None:
self.proposals = self.load_proposals(self.proposal_file)
else:
self.proposals = None
# filter images too small
# filter images too small and containing no annotations
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds]
# set group flag for the sampler
if not self.test_mode:
# set group flag for the sampler
self._set_group_flag()
# processing pipeline
self.pipeline = Compose(pipeline)
......@@ -147,6 +148,9 @@ class CustomDataset(Dataset):
def _filter_imgs(self, min_size=32):
"""Filter images too small."""
if self.filter_empty_gt:
warnings.warn(
'CustomDataset does not support filtering empty gt images.')
valid_inds = []
for i, img_info in enumerate(self.data_infos):
if min(img_info['width'], img_info['height']) >= min_size:
......@@ -237,12 +241,13 @@ class CustomDataset(Dataset):
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
Returns:
tuple[str] or list[str]: Names of categories of the dataset.
"""
if classes is None:
cls.custom_classes = False
return cls.CLASSES
cls.custom_classes = True
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
......@@ -253,9 +258,6 @@ class CustomDataset(Dataset):
return class_names
def get_subset_by_classes(self):
"""Return ``self.data_infos`` unfiltered; base-class hook that dataset
subclasses override to filter images by the customized classes."""
return self.data_infos
def format_results(self, results, **kwargs):
"""Placeholder to format results into dataset-specific output; does
nothing here and is meant to be overridden by subclasses."""
pass
......
......@@ -58,22 +58,26 @@ class XMLDataset(CustomDataset):
return data_infos
def get_subset_by_classes(self):
"""Filter imgs by user-defined categories."""
subset_data_infos = []
for data_info in self.data_infos:
img_id = data_info['id']
xml_path = osp.join(self.img_prefix, 'Annotations',
f'{img_id}.xml')
tree = ET.parse(xml_path)
root = tree.getroot()
for obj in root.findall('object'):
name = obj.find('name').text
if name in self.CLASSES:
subset_data_infos.append(data_info)
break
return subset_data_infos
def _filter_imgs(self, min_size=32):
"""Filter images too small or without annotation."""
valid_inds = []
for i, img_info in enumerate(self.data_infos):
if min(img_info['width'], img_info['height']) < min_size:
continue
if self.filter_empty_gt:
img_id = img_info['id']
xml_path = osp.join(self.img_prefix, 'Annotations',
f'{img_id}.xml')
tree = ET.parse(xml_path)
root = tree.getroot()
for obj in root.findall('object'):
name = obj.find('name').text
if name in self.CLASSES:
valid_inds.append(i)
break
else:
valid_inds.append(i)
return valid_inds
def get_ann_info(self, idx):
"""Get annotation from XML file by index.
......
{
"images": [
{
"file_name": "fake1.jpg",
"height": 800,
"width": 800,
"id": 0
},
{
"file_name": "fake2.jpg",
"height": 800,
"width": 800,
"id": 1
},
{
"file_name": "fake3.jpg",
"height": 800,
"width": 800,
"id": 2
}
],
"annotations": [
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 1,
"image_id": 0
},
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 2,
"id": 2,
"image_id": 0
},
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 3,
"image_id": 1
}
],
"categories": [
{
"id": 1,
"name": "bus",
"supercategory": "none"
},
{
"id": 2,
"name": "car",
"supercategory": "none"
}
],
"licenses": [],
"info": null
}
......@@ -204,11 +204,18 @@ def test_dataset_evaluation():
tmp_dir.cleanup()
@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock)
@pytest.mark.parametrize('dataset',
['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
dataset_class = DATASETS.get(dataset)
dataset_class.load_annotations = MagicMock()
if dataset in ['CocoDataset', 'CityscapesDataset']:
dataset_class.coco = MagicMock()
dataset_class.cat_ids = MagicMock()
......@@ -225,7 +232,6 @@ def test_custom_classes_override_default(dataset):
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ('bus', 'car')
assert custom_dataset.custom_classes
# Test setting classes as a list
custom_dataset = dataset_class(
......@@ -237,7 +243,6 @@ def test_custom_classes_override_default(dataset):
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ['bus', 'car']
assert custom_dataset.custom_classes
# Test overriding not a subset
custom_dataset = dataset_class(
......@@ -249,7 +254,6 @@ def test_custom_classes_override_default(dataset):
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ['foo']
assert custom_dataset.custom_classes
# Test default behavior
custom_dataset = dataset_class(
......@@ -260,7 +264,6 @@ def test_custom_classes_override_default(dataset):
img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
assert custom_dataset.CLASSES == original_classes
assert not custom_dataset.custom_classes
# Test sending file path
import tempfile
......@@ -277,7 +280,6 @@ def test_custom_classes_override_default(dataset):
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ['bus', 'car']
assert custom_dataset.custom_classes
def test_dataset_wrapper():
......@@ -460,3 +462,32 @@ def _build_demo_runner():
runner = EpochBasedRunner(
model=model, work_dir=tmp_dir, logger=logging.getLogger())
return runner
@pytest.mark.parametrize('classes, expected_length', [(['bus'], 2),
                                                      (['car'], 1),
                                                      (['bus', 'car'], 2)])
def test_allow_empty_images(classes, expected_length):
    """Check that filter_empty_gt=True drops images without GT of the
    requested classes, while filter_empty_gt=False keeps all images and
    both keep ``img_ids`` consistent with the dataset length."""
    coco_dataset_cls = DATASETS.get('CocoDataset')
    # identical construction kwargs except for the empty-GT switch
    common_kwargs = dict(
        ann_file='tests/data/coco_sample.json',
        img_prefix='tests/data',
        pipeline=[],
        classes=classes)
    # Filter empty images
    filtered_dataset = coco_dataset_cls(filter_empty_gt=True, **common_kwargs)
    # Get all
    full_dataset = coco_dataset_cls(filter_empty_gt=False, **common_kwargs)

    assert len(filtered_dataset) == expected_length
    assert len(filtered_dataset.img_ids) == expected_length
    assert len(full_dataset) == 3
    assert len(full_dataset.img_ids) == 3
    assert filtered_dataset.CLASSES == classes
    assert full_dataset.CLASSES == classes
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment