Skip to content
Snippets Groups Projects
Unverified Commit 1b7fb930 authored by ZhangShilong's avatar ZhangShilong Committed by GitHub
Browse files

add CompatibleCheckHook (#4508)

* add CompatibleCheckHook

* check val

* add docstr

* move hook to default runtime

* add unitest

* fix a typo

* fix unitest

* update docs

* add check for CLASSES is None

* fix doc
parent e3e6ca76
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,8 @@ log_config = dict(
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
......
......@@ -20,7 +20,7 @@ model = dict(
num_outs=5),
bbox_head=dict(
type='GARetinaHead',
num_classes=81,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
......@@ -144,8 +144,7 @@ data = dict(
evaluation = dict(interval=1, metric='bbox')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
......
......@@ -483,4 +483,5 @@ data = dict(
- Before MMDetection v2.5.0, the dataset will filter out the empty GT images automatically if the classes are set and there is no way to disable that through config. This is an undesirable behavior and introduces confusion because if the classes are not set, the dataset only filters the empty GT images when `filter_empty_gt=True` and `test_mode=False`. After MMDetection v2.5.0, we decouple the image filtering process and the classes modification, i.e., the dataset will only filter empty GT images when `filter_empty_gt=True` and `test_mode=False`, no matter whether the classes are set. Thus, setting the classes only influences the annotations of classes used for training and users can decide whether to filter empty GT images by themselves.
- Since the middle format only has box labels and does not contain the class names, when using `CustomDataset`, users cannot filter out the empty GT images through configs but can only do this offline.
- Please remember to modify the `num_classes` in the head when specifying `classes` in dataset. We implemented [NumClassCheckHook](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/utils.py) to check whether the numbers are consistent since v2.9.0 (after PR #4508).
- The features for setting dataset classes and dataset filtering will be refactored to be more user-friendly in v2.8.0 or v2.9.0 (depends on the progress).
......@@ -264,10 +264,14 @@ By default the hook's priority is set as `NORMAL` during registration.
If the hook is already implemented in MMCV, you can directly modify the config to use the hook as below
#### 4. Example: `NumClassCheckHook`
We implement a customized hook named [NumClassCheckHook](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/utils.py) to check whether the `num_classes` in head matches the length of `CLASSES` in `dataset`.
We set it in [default_runtime.py](https://github.com/open-mmlab/mmdetection/blob/master/configs/_base_/default_runtime.py).
```python
custom_hooks = [
dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL')
]
custom_hooks = [dict(type='NumClassCheckHook')]
```
### Modify default runtime hooks
......
import copy
import warnings
from mmcv.cnn import VGG
from mmcv.runner.hooks import HOOKS, Hook
from mmdet.models.dense_heads import GARPNHead, RPNHead
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
def replace_ImageToTensor(pipelines):
"""Replace the ImageToTensor transform in a data pipeline to
......@@ -98,3 +104,50 @@ def get_loading_pipeline(pipeline):
'The data pipeline in your config file must include ' \
'loading image and annotations related pipeline.'
return loading_pipeline_cfg
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Hook that checks head ``num_classes`` against dataset ``CLASSES``.

    The check is performed before every training epoch and before every
    validation epoch.
    """

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        If the dataset's `CLASSES` is None, only a warning is logged and no
        comparison is performed.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.

        Raises:
            AssertionError: If a checked module's `num_classes` differs from
                ``len(dataset.CLASSES)``.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset.__class__.__name__} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        num_dataset_classes = len(dataset.CLASSES)
        for _name, module in model.named_modules():
            if not hasattr(module, 'num_classes'):
                continue
            # These module types are deliberately exempt from the check --
            # presumably their `num_classes` is not tied to the dataset
            # categories (TODO confirm).
            if isinstance(module,
                          (RPNHead, VGG, FusedSemanticHead, GARPNHead)):
                continue
            if module.num_classes != num_dataset_classes:
                # Raised explicitly (rather than via `assert`) so the check
                # is not stripped when Python runs with `-O`.
                raise AssertionError(
                    f'The `num_classes` ({module.num_classes}) in '
                    f'{module.__class__.__name__} of '
                    f'{model.__class__.__name__} does not match '
                    f'the length of `CLASSES` '
                    f'({num_dataset_classes}) in '
                    f'{dataset.__class__.__name__}')

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)
from os.path import dirname, exists, join, relpath
from unittest.mock import Mock
import pytest
import torch
from mmcv.runner import build_optimizer
from mmdet.core import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.utils import NumClassCheckHook
def _get_config_directory():
......@@ -22,6 +25,42 @@ def _get_config_directory():
return config_dpath
def _check_numclasscheckhook(detector, config_mod):
    """Run ``NumClassCheckHook`` against the train and val dataset configs.

    For each split the hook is invoked twice on a mocked runner: once with
    the resolved ``CLASSES`` and once with ``CLASSES`` set to ``None``.

    Args:
        detector: Built detector, used as the runner's model.
        config_mod: Loaded config providing ``data.train`` and ``data.val``.
    """

    def get_dataset_name_classes(dataset):
        # Unwrap `RepeatDataset`, `ConcatDataset`, `ClassBalancedDataset`,
        # ... until the innermost dataset config is reached; for a list or
        # tuple of configs, only the first entry is inspected.
        cfg = dataset
        if isinstance(cfg, (list, tuple)):
            cfg = cfg[0]
        while 'dataset' in cfg:
            cfg = cfg['dataset']
            # ConcatDataset
            if isinstance(cfg, (list, tuple)):
                cfg = cfg[0]
        return cfg['type'], cfg.get('classes', None)

    dummy_runner = Mock()
    dummy_runner.model = detector
    compatible_check = NumClassCheckHook()

    stages = (
        ('train', compatible_check.before_train_epoch),
        ('val', compatible_check.before_val_epoch),
    )
    for split, run_check in stages:
        dataset_name, classes = get_dataset_name_classes(
            config_mod['data'][split])
        if classes is None:
            # Fall back to the class-level default of the registered dataset.
            classes = DATASETS.get(dataset_name).CLASSES
        # First with the real classes, then with `CLASSES = None`.
        for candidate in (classes, None):
            dummy_runner.data_loader.dataset.CLASSES = candidate
            run_check(dummy_runner)
def test_config_build_detector():
"""Test that all detection models defined in the configs can be
initialized."""
......@@ -51,6 +90,8 @@ def test_config_build_detector():
detector = build_detector(config_mod.model)
assert detector is not None
_check_numclasscheckhook(detector, config_mod)
optimizer = build_optimizer(detector, config_mod.optimizer)
assert isinstance(optimizer, torch.optim.Optimizer)
......@@ -62,6 +103,7 @@ def test_config_build_detector():
head_config = config_mod.model['roi_head']
_check_roi_head(head_config, detector.roi_head)
# else:
# # for single stage detector
# # detectors must have bbox head
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment