From 1c28e66903b07358ddf239a6c1027ce0d8d5ed4e Mon Sep 17 00:00:00 2001
From: Claudio Michaelis <claudio.michaelis@uni-tuebingen.de>
Date: Sat, 27 Jul 2019 16:11:17 +0200
Subject: [PATCH] Add cityscapes dataset (#1037)

* added cityscapes

* updated configs

* removed wip configs

* Add initial dataset instructions

* Add cityscapes readme

* Add explanation for lr scaling

* Ensure pep8 conformity

* Add CityscapesDataset to the registry

* add benchmark

* rename config, modify README.md

* fix typo

* fix typo in config

* modify INSTALL.md

Update information how to arrange cityscapes data.

* Add cityscapes class names
---
 INSTALL.md                                    |  11 +
 configs/cityscapes/README.md                  |  28 +++
 .../faster_rcnn_r50_fpn_1x_cityscapes.py      | 175 ++++++++++++++++
 .../mask_rcnn_r50_fpn_1x_cityscapes.py        | 189 ++++++++++++++++++
 mmdet/core/evaluation/class_names.py          |  10 +-
 mmdet/datasets/__init__.py                    |   8 +-
 mmdet/datasets/cityscapes.py                  |   9 +
 7 files changed, 426 insertions(+), 4 deletions(-)
 create mode 100644 configs/cityscapes/README.md
 create mode 100644 configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
 create mode 100644 configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py
 create mode 100644 mmdet/datasets/cityscapes.py

diff --git a/INSTALL.md b/INSTALL.md
index e48b1312..cc390d2b 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -66,10 +66,21 @@ mmdetection
 鈹�   鈹�   鈹溾攢鈹€ train2017
 鈹�   鈹�   鈹溾攢鈹€ val2017
 鈹�   鈹�   鈹溾攢鈹€ test2017
+鈹�   鈹溾攢鈹€ cityscapes
+鈹�   鈹�   鈹溾攢鈹€ annotations
+鈹�   鈹�   鈹溾攢鈹€ train
+鈹�   鈹�   鈹溾攢鈹€ val
 鈹�   鈹溾攢鈹€ VOCdevkit
 鈹�   鈹�   鈹溾攢鈹€ VOC2007
 鈹�   鈹�   鈹溾攢鈹€ VOC2012
 
+```
+The cityscapes annotations have to be converted into the coco format using the [cityscapesScripts](https://github.com/mcordts/cityscapesScripts) toolbox.
+We plan to provide an easy to use conversion script. For the moment we recommend following the instructions provided in the 
+[maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark/tree/master/maskrcnn_benchmark/data) toolbox. When using this script all images have to be moved into the same folder. On linux systems this can e.g. be done for the train images with:
+```shell
+cd data/cityscapes/
+mv train/*/* train/
 ```
 
 ### Scripts
diff --git a/configs/cityscapes/README.md b/configs/cityscapes/README.md
new file mode 100644
index 00000000..f63f8e84
--- /dev/null
+++ b/configs/cityscapes/README.md
@@ -0,0 +1,28 @@
+## Common settings
+
+- All baselines were trained using 8 GPU with a batch size of 8 (1 images per GPU) using the [linear scaling rule](https://arxiv.org/abs/1706.02677) to scale the learning rate. 
+- All models were trained on `cityscapes_train`, and tested on `cityscapes_val`.
+- 1x training schedule indicates 64 epochs which corresponds to slightly less than the 24k iterations reported in the original schedule from the [Mask R-CNN paper](https://arxiv.org/abs/1703.06870)
+- All pytorch-style pretrained backbones on ImageNet are from PyTorch model zoo.
+
+
+## Baselines
+
+Download links and more models with different backbones and training schemes will be added to the model zoo.
+
+
+### Faster R-CNN
+
+|    Backbone     |  Style  | Lr schd | Scale    | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
+| :-------------: | :-----: | :-----: | :---:    | :------: | :-----------------: | :------------: | :----: | :------: |
+|    R-50-FPN     | pytorch |   1x    | 800-1024 | 4.9      | 0.345               | 8.8            | 36.0   | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/faster_rcnn_r50_fpn_1x_city_20190727-7b9c0534.pth) |
+
+### Mask R-CNN
+
+|    Backbone     |  Style  | Lr schd | Scale    | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | mask AP | Download |
+| :-------------: | :-----: | :-----: | :------: | :------: | :-----------------: | :------------: | :----: | :-----: | :------: |
+|    R-50-FPN     | pytorch |   1x    | 800-1024 | 4.9      | 0.609               | 2.5            | 37.4  |  32.5   | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/mask_rcnn_r50_fpn_1x_city_20190727-9b3c56a5.pth) |
+
+**Notes:**
+- In the original paper, the mask AP of Mask R-CNN R-50-FPN is 31.5.
+
diff --git a/configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py b/configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
new file mode 100644
index 00000000..0ccacd20
--- /dev/null
+++ b/configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
@@ -0,0 +1,175 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=9,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CityscapesDataset'
+data_root = 'data/cityscapes/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=1,
+    workers_per_gpu=2,
+    train=dict(
+        type='RepeatDataset',  # to avoid reloading datasets frequently
+        times=8,
+        dataset=dict(
+            type=dataset_type,
+            ann_file=data_root +
+            'annotations/instancesonly_filtered_gtFine_train.json',
+            img_prefix=data_root + 'train/',
+            img_scale=[(2048, 800), (2048, 1024)],
+            img_norm_cfg=img_norm_cfg,
+            multiscale_mode='range',
+            size_divisor=32,
+            flip_ratio=0.5,
+            with_mask=False,
+            with_crowd=True,
+            with_label=True)),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root +
+        'annotations/instancesonly_filtered_gtFine_val.json',
+        img_prefix=data_root + 'val/',
+        img_scale=(2048, 1024),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root +
+        'annotations/instancesonly_filtered_gtFine_val.json',
+        img_prefix=data_root + 'val/',
+        img_scale=(2048, 1024),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+# lr is set for a batch size of 8
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[6])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=100,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 8  # actual epoch = 8 * 8 = 64
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py b/configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py
new file mode 100644
index 00000000..85f32f7e
--- /dev/null
+++ b/configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py
@@ -0,0 +1,189 @@
+# model settings
+model = dict(
+    type='MaskRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=9,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=9,
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CityscapesDataset'
+data_root = 'data/cityscapes/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=1,
+    workers_per_gpu=2,
+    train=dict(
+        type='RepeatDataset',  # to avoid reloading datasets frequently
+        times=8,
+        dataset=dict(
+            type=dataset_type,
+            ann_file=data_root +
+            'annotations/instancesonly_filtered_gtFine_train.json',
+            img_prefix=data_root + 'train/',
+            img_scale=[(2048, 800), (2048, 1024)],
+            img_norm_cfg=img_norm_cfg,
+            multiscale_mode='range',
+            size_divisor=32,
+            flip_ratio=0.5,
+            with_mask=True,
+            with_crowd=True,
+            with_label=True)),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root +
+        'annotations/instancesonly_filtered_gtFine_val.json',
+        img_prefix=data_root + 'val/',
+        img_scale=(2048, 1024),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root +
+        'annotations/instancesonly_filtered_gtFine_val.json',
+        img_prefix=data_root + 'val/',
+        img_scale=(2048, 1024),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+# lr is set for a batch size of 8
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[6])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=100,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 8  # actual epoch = 8 * 8 = 64
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py
index 87fb2399..78427734 100644
--- a/mmdet/core/evaluation/class_names.py
+++ b/mmdet/core/evaluation/class_names.py
@@ -82,12 +82,20 @@ def coco_classes():
     ]
 
 
+def cityscapes_classes():
+    return [
+        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+        'bicycle'
+    ]
+
+
 dataset_aliases = {
     'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
     'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
     'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
     'coco': ['coco', 'mscoco', 'ms_coco'],
-    'wider_face': ['WIDERFaceDataset', 'wider_face', 'WDIERFace']
+    'wider_face': ['WIDERFaceDataset', 'wider_face', 'WDIERFace'],
+    'cityscapes': ['cityscapes']
 }
 
 
diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
index ab3e5490..786b0593 100644
--- a/mmdet/datasets/__init__.py
+++ b/mmdet/datasets/__init__.py
@@ -1,6 +1,7 @@
 from .custom import CustomDataset
 from .xml_style import XMLDataset
 from .coco import CocoDataset
+from .cityscapes import CityscapesDataset
 from .voc import VOCDataset
 from .wider_face import WIDERFaceDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
@@ -11,8 +12,9 @@ from .registry import DATASETS
 from .builder import build_dataset
 
 __all__ = [
-    'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
-    'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
-    'show_ann', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation',
+    'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
+    'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
+    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
+    'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation',
     'WIDERFaceDataset', 'DATASETS', 'build_dataset'
 ]
diff --git a/mmdet/datasets/cityscapes.py b/mmdet/datasets/cityscapes.py
new file mode 100644
index 00000000..fcbd43c1
--- /dev/null
+++ b/mmdet/datasets/cityscapes.py
@@ -0,0 +1,9 @@
+from .coco import CocoDataset
+from .registry import DATASETS
+
+
+@DATASETS.register_module
+class CityscapesDataset(CocoDataset):
+
+    CLASSES = ('person', 'rider', 'car', 'truck', 'bus',
+               'train', 'motorcycle', 'bicycle')
-- 
GitLab