diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 594de8dcc99b9e4fc0208f327a05910a95a1793c..55097c5b242da66c9735c0b45cd84beefab487b1 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -7,6 +7,8 @@ log_config = dict( # dict(type='TensorboardLoggerHook') ]) # yapf:enable +custom_hooks = [dict(type='NumClassCheckHook')] + dist_params = dict(backend='nccl') log_level = 'INFO' load_from = None diff --git a/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py b/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py index f6c487bf18fe6bcee9a9b7d62ca99a4d98cafa17..0267a81921cffeee2af465564227670d6741b87d 100644 --- a/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py +++ b/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py @@ -20,7 +20,7 @@ model = dict( num_outs=5), bbox_head=dict( type='GARetinaHead', - num_classes=81, + num_classes=80, in_channels=256, stacked_convs=4, feat_channels=256, @@ -144,8 +144,7 @@ data = dict( evaluation = dict(interval=1, metric='bbox') # optimizer optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) -optimizer_config = dict( - _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) # learning policy lr_config = dict( policy='step', diff --git a/docs/tutorials/customize_dataset.md b/docs/tutorials/customize_dataset.md index 7c4f845e3369b14a8c908f6e1aa9ed4849fc71e0..1887ad53163b852b302118b04b43f71a737e7276 100644 --- a/docs/tutorials/customize_dataset.md +++ b/docs/tutorials/customize_dataset.md @@ -483,4 +483,5 @@ data = dict( - Before MMDetection v2.5.0, the dataset will filter out the empty GT images automatically if the classes are set and there is no way to disable that through config. This is an undesirable behavior and introduces confusion because if the classes are not set, the dataset only filter the empty GT images when `filter_empty_gt=True` and `test_mode=False`. After MMDetection v2.5.0, we decouple the image filtering process and the classes modification, i.e., the dataset will only filter empty GT images when `filter_empty_gt=True` and `test_mode=False`, no matter whether the classes are set. Thus, setting the classes only influences the annotations of classes used for training and users could decide whether to filter empty GT images by themselves. - Since the middle format only has box labels and does not contain the class names, when using `CustomDataset`, users cannot filter out the empty GT images through configs but only do this offline. +- Please remember to modify the `num_classes` in the head when specifying `classes` in dataset. We implemented [NumClassCheckHook](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/utils.py) to check whether the numbers are consistent since v2.9.0(after PR#4508). - The features for setting dataset classes and dataset filtering will be refactored to be more user-friendly in v2.8.0 or v2.9.0 (depends on the progress). diff --git a/docs/tutorials/customize_runtime.md b/docs/tutorials/customize_runtime.md index 3f52b2732e23e5c5104b7f1e46088d9ad9accb19..616ce508aa2bec61c7fc23cd381b64e670b3b96e 100644 --- a/docs/tutorials/customize_runtime.md +++ b/docs/tutorials/customize_runtime.md @@ -264,10 +264,14 @@ By default the hook's priority is set as `NORMAL` during registration. If the hook is already implemented in MMCV, you can directly modify the config to use the hook as below +#### 4. Example: `NumClassCheckHook` + +We implement a customized hook named [NumClassCheckHook](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/utils.py) to check whether the `num_classes` in head matches the length of `CLASSSES` in `dataset`. + +We set it in [default_runtime.py](https://github.com/open-mmlab/mmdetection/blob/master/configs/_base_/default_runtime.py). + ```python -custom_hooks = [ - dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL') -] +custom_hooks = [dict(type='NumClassCheckHook')] ``` ### Modify default runtime hooks diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index 9430ee99c5cc8ce1499ef4e2fb68ca1bb06b824a..996253e27e5e29d7f2d02c222606696b6bfa5ae1 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -1,6 +1,12 @@ import copy import warnings +from mmcv.cnn import VGG +from mmcv.runner.hooks import HOOKS, Hook + +from mmdet.models.dense_heads import GARPNHead, RPNHead +from mmdet.models.roi_heads.mask_heads import FusedSemanticHead + def replace_ImageToTensor(pipelines): """Replace the ImageToTensor transform in a data pipeline to @@ -98,3 +104,50 @@ def get_loading_pipeline(pipeline): 'The data pipeline in your config file must include ' \ 'loading image and annotations related pipeline.' return loading_pipeline_cfg + + +@HOOKS.register_module() +class NumClassCheckHook(Hook): + + def _check_head(self, runner): + """Check whether the `num_classes` in head matches the length of + `CLASSSES` in `dataset`. + + Args: + runner (obj:`EpochBasedRunner`): Epoch based Runner. + """ + model = runner.model + dataset = runner.data_loader.dataset + if dataset.CLASSES is None: + runner.logger.warning( + f'Please set `CLASSES` ' + f'in the {dataset.__class__.__name__} and' + f'check if it is consistent with the `num_classes` ' + f'of head') + else: + for name, module in model.named_modules(): + if hasattr(module, 'num_classes') and not isinstance( + module, (RPNHead, VGG, FusedSemanticHead, GARPNHead)): + assert module.num_classes == len(dataset.CLASSES), \ + (f'The `num_classes` ({module.num_classes}) in ' + f'{module.__class__.__name__} of ' + f'{model.__class__.__name__} does not matches ' + f'the length of `CLASSES` ' + f'{len(dataset.CLASSES)}) in ' + f'{dataset.__class__.__name__}') + + def before_train_epoch(self, runner): + """Check whether the training dataset is compatible with head. + + Args: + runner (obj:`EpochBasedRunner`): Epoch based Runner. + """ + self._check_head(runner) + + def before_val_epoch(self, runner): + """Check whether the dataset in val epoch is compatible with head. + + Args: + runner (obj:`EpochBasedRunner`): Epoch based Runner. + """ + self._check_head(runner) diff --git a/tests/test_config.py b/tests/test_config.py index c747b79639cee34474482342a3ba9777cd2d4214..8891103ad2edfc1b9ab2e19198ecd3501b35aa0a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,13 @@ from os.path import dirname, exists, join, relpath +from unittest.mock import Mock import pytest import torch from mmcv.runner import build_optimizer from mmdet.core import BitmapMasks, PolygonMasks +from mmdet.datasets.builder import DATASETS +from mmdet.datasets.utils import NumClassCheckHook def _get_config_directory(): @@ -22,6 +25,42 @@ def _get_config_directory(): return config_dpath +def _check_numclasscheckhook(detector, config_mod): + + dummy_runner = Mock() + dummy_runner.model = detector + + def get_dataset_name_classes(dataset): + # deal with `RepeatDataset`,`ConcatDataset`,`ClassBalancedDataset`.. + if isinstance(dataset, (list, tuple)): + dataset = dataset[0] + while ('dataset' in dataset): + dataset = dataset['dataset'] + # ConcatDataset + if isinstance(dataset, (list, tuple)): + dataset = dataset[0] + return dataset['type'], dataset.get('classes', None) + + compatible_check = NumClassCheckHook() + dataset_name, CLASSES = get_dataset_name_classes( + config_mod['data']['train']) + if CLASSES is None: + CLASSES = DATASETS.get(dataset_name).CLASSES + dummy_runner.data_loader.dataset.CLASSES = CLASSES + compatible_check.before_train_epoch(dummy_runner) + + dummy_runner.data_loader.dataset.CLASSES = None + compatible_check.before_train_epoch(dummy_runner) + + dataset_name, CLASSES = get_dataset_name_classes(config_mod['data']['val']) + if CLASSES is None: + CLASSES = DATASETS.get(dataset_name).CLASSES + dummy_runner.data_loader.dataset.CLASSES = CLASSES + compatible_check.before_val_epoch(dummy_runner) + dummy_runner.data_loader.dataset.CLASSES = None + compatible_check.before_val_epoch(dummy_runner) + + def test_config_build_detector(): """Test that all detection models defined in the configs can be initialized.""" @@ -51,6 +90,8 @@ def test_config_build_detector(): detector = build_detector(config_mod.model) assert detector is not None + _check_numclasscheckhook(detector, config_mod) + optimizer = build_optimizer(detector, config_mod.optimizer) assert isinstance(optimizer, torch.optim.Optimizer) @@ -62,6 +103,7 @@ def test_config_build_detector(): head_config = config_mod.model['roi_head'] _check_roi_head(head_config, detector.roi_head) + # else: # # for single stage detector # # detectors must have bbox head