Unverified commit dab4ac1b authored by Wenwei Zhang, committed by GitHub

Support RegNetX models (#2710)

* Add RegNet, docs/refactor TBD

* Support RegNet

* Add RegNets

* Modify docs and refactor code

* fix pre-train bug

* Performance ready

* Use public model weights

* Performance & settings done

* Fix unittest bug

* Try another way of test config

* Add unittest and refactor

* Reformat

* Finish refactor regnet

* merge tests

* fix conflicts

* Fix resnext CI bug

* rename parameter

* Add ref link & clean _delete_

* tweaks of doc

* try resnext doctest

* add expansion in block

* remove doctest of resnext
parent 0367dd9c
Showing 682 additions and 27 deletions
repos:
- repo: https://gitlab.com/pycqa/flake8.git
rev: 3.7.9
rev: 3.8.0
hooks:
- id: flake8
- repo: https://github.com/asottile/seed-isort-config
@@ -49,31 +49,31 @@ A comparison between v1.x and v2.0 codebases can be found in [compatibility.md](
Supported methods and backbones are shown in the table below.
Results and models are available in the [model zoo](docs/model_zoo.md).
| | ResNet | ResNeXt | SENet | VGG | HRNet |
|--------------------|:--------:|:--------:|:--------:|:--------:|:-----:|
| RPN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Fast R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Faster R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Mask R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Cascade R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Cascade Mask R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| SSD | ✗ | ✗ | ✗ | ✓ | ✗ |
| RetinaNet | ✓ | ✓ | ☐ | ✗ | ✓ |
| GHM | ✓ | ✓ | ☐ | ✗ | ✓ |
| Mask Scoring R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Double-Head R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Grid R-CNN (Plus) | ✓ | ✓ | ☐ | ✗ | ✓ |
| Hybrid Task Cascade| ✓ | ✓ | ☐ | ✗ | ✓ |
| Libra R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ |
| Guided Anchoring | ✓ | ✓ | ☐ | ✗ | ✓ |
| FCOS | ✓ | ✓ | ☐ | ✗ | ✓ |
| RepPoints | ✓ | ✓ | ☐ | ✗ | ✓ |
| Foveabox | ✓ | ✓ | ☐ | ✗ | ✓ |
| FreeAnchor | ✓ | ✓ | ☐ | ✗ | ✓ |
| NAS-FPN | ✓ | ✓ | ☐ | ✗ | ✓ |
| ATSS | ✓ | ✓ | ☐ | ✗ | ✓ |
| FSAF | ✓ | ✓ | ☐ | ✗ | ✓ |
| PAFPN | ✓ | ✓ | ☐ | ✗ | ✓ |
| | ResNet | ResNeXt | SENet | VGG | HRNet | RegNetX | Res2Net |
|--------------------|:--------:|:--------:|:--------:|:--------:|:-----:|:--------:|:-----:|
| RPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Fast R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Faster R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ✓ | ✓ |
| Mask R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ✓ | ✓ |
| Cascade R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ✓ |
| Cascade Mask R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ✓ |
| SSD | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ |
| RetinaNet | ✓ | ✓ | ☐ | ✗ | ✓ | ✓ | ☐ |
| GHM | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Mask Scoring R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Double-Head R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Grid R-CNN (Plus) | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Hybrid Task Cascade| ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ✓ |
| Libra R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Guided Anchoring | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| FCOS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| RepPoints | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Foveabox | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| FreeAnchor | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| NAS-FPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| ATSS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| FSAF | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| PAFPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
Other features
- [x] [CARAFE](configs/carafe/README.md)
# Designing Network Design Spaces
## Introduction
We implement RegNetX and RegNetY models in detection systems and provide their first results on Mask R-CNN, Faster R-CNN and RetinaNet.
The pre-trained models are converted from the [model zoo of pycls](https://github.com/facebookresearch/pycls/blob/master/MODEL_ZOO.md).
```
@article{radosavovic2020designing,
title={Designing Network Design Spaces},
author={Ilija Radosavovic and Raj Prateek Kosaraju and Ross Girshick and Kaiming He and Piotr Dollár},
year={2020},
eprint={2003.13678},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Usage
To use a RegNet model, there are two steps:
1. Convert the model into the ResNet-style format supported by MMDetection
2. Modify the backbone and neck in the config accordingly
### Convert model
We already provide models with FLOPs ranging from 800M to 12G in our model zoo.
For more general usage, we also provide the script `regnet2mmdet.py` in the tools directory to convert the keys of models pretrained with [pycls](https://github.com/facebookresearch/pycls/) into
ResNet-style checkpoints used in MMDetection.
```bash
python -u tools/regnet2mmdet.py ${PRETRAIN_PATH} ${STORE_PATH}
```
This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
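As a quick sanity check after conversion, the minimal sketch below (assuming `mmdet` and `torch` are installed, and using a hypothetical output path `./regnetx_3.2gf_mmdet.pth`) loads the converted checkpoint into a standalone RegNet backbone and reports any mismatched keys:
```python
# Sketch only: verify that a converted checkpoint matches MMDetection's RegNet backbone.
import torch
from mmdet.models import build_backbone

backbone = build_backbone(
    dict(type='RegNet', arch='regnetx_3.2gf', out_indices=(0, 1, 2, 3)))
checkpoint = torch.load('./regnetx_3.2gf_mmdet.pth', map_location='cpu')
# Converted checkpoints may wrap the weights in a 'state_dict' field.
state_dict = checkpoint.get('state_dict', checkpoint)
missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
```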
### Modify config
Users can modify the backbone's `depth` and the corresponding keys in `arch` in the config according to the configs in the [pycls model zoo](https://github.com/facebookresearch/pycls/blob/master/MODEL_ZOO.md).
The parameter `in_channels` in FPN can be found in Figures 15 & 16 of the paper (`wi` in the legend).
This directory already provides some configs with their performance, using RegNetX models from the 800MF to 12GF level.
For other pre-trained models or self-implemented RegNet models, users are responsible for checking these parameters themselves.
**Note**: Although Figs. 15 & 16 also provide `w0`, `wa`, `wm`, `group_w`, and `bot_mul` for `arch`, they are quantized and thus inaccurate; using them sometimes produces a backbone whose keys do not match those in the pre-trained model.
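For example, a hypothetical RegNetX-1.6GF Mask R-CNN config (not provided in this directory; the channel widths below are the same ones used by the RegNetX-1.6GF RetinaNet config here, and the pre-trained weights are assumed to be available as `open-mmlab://regnetx_1.6gf`) would only need to override the pre-trained weights, `arch`, and the FPN `in_channels`:
```python
_base_ = './mask_rcnn_regnetx-3GF_fpn_1x_coco.py'
model = dict(
    pretrained='open-mmlab://regnetx_1.6gf',
    backbone=dict(
        type='RegNet',
        arch='regnetx_1.6gf',
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        # Per-stage output widths of RegNetX-1.6GF.
        in_channels=[72, 168, 408, 912],
        out_channels=256,
        num_outs=5))
```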
## Results
### Mask R-CNN
| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | mask AP | Download |
| :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :-----: | :------: |
| [R-50-FPN](../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py)| pytorch | 1x | 4.4 | 12.0 | 38.2 | 34.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_1x_coco/mask_rcnn_r50_fpn_1x_coco_20200205-d4b0c5d6.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_1x_coco/mask_rcnn_r50_fpn_1x_coco_20200205_050542.log.json) |
|[RegNetX-3.2GF-FPN](./mask_rcnn_regnetx-3GF_fpn_1x_coco.py)| pytorch | 1x |5.0 ||40.3|36.6| |
|[RegNetX-4.0GF-FPN](./mask_rcnn_regnetx-4GF_fpn_1x_coco.py)| pytorch | 1x |5.5||41.5|37.4| |
| [R-101-FPN](../mask_rcnn/mask_rcnn_r101_fpn_1x_coco.py)| pytorch | 1x | 6.4 | 10.3 | 40.0 | 36.1 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r101_fpn_1x_coco/mask_rcnn_r101_fpn_1x_coco_20200204-1efe0ed5.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r101_fpn_1x_coco/mask_rcnn_r101_fpn_1x_coco_20200204_144809.log.json) |
|[RegNetX-6.4GF-FPN](./mask_rcnn_regnetx-6GF_fpn_1x_coco.py)| pytorch | 1x |6.1 ||41.0|37.1| |
| [X-101-32x4d-FPN](../mask_rcnn/mask_rcnn_x101_32x4d_fpn_1x_coco.py) | pytorch | 1x | 7.6 | 9.4 | 41.9 | 37.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_x101_32x4d_fpn_1x_coco/mask_rcnn_x101_32x4d_fpn_1x_coco_20200205-478d0b67.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_x101_32x4d_fpn_1x_coco/mask_rcnn_x101_32x4d_fpn_1x_coco_20200205_034906.log.json) |
|[RegNetX-8.0GF-FPN](./mask_rcnn_regnetx-8GF_fpn_1x_coco.py)| pytorch | 1x |6.4 ||41.7|37.5| |
|[RegNetX-12GF-FPN](./mask_rcnn_regnetx-12GF_fpn_1x_coco.py)| pytorch | 1x |7.4 ||42.2|38.0| |
### Faster R-CNN
| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download |
| :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :------: |
| [R-50-FPN](../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py)| pytorch | 1x | 4.0 | 18.2 | 37.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130_204655.log.json) |
|[RegNetX-3.2GF-FPN*](./faster_rcnn_regnetx-3GF_fpn_mstrain_1x_coco.py)| pytorch | 1x | 4.5||39.9| |
### RetinaNet
| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download |
| :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :------: |
| [R-50-FPN](../retinanet/retinanet_r50_fpn_1x_coco.py) | pytorch | 1x | 3.8 | 16.6 | 36.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130_002941.log.json) |
|[RegNetX-800MF-FPN](./retinanet_regnetx-800MF_fpn_mstrain_1x_coco.py)| pytorch | 1x |2.5||35.6| |
|[RegNetX-1.6GF-FPN](./retinanet_regnetx-1GF_fpn_mstrain_1x_coco.py)| pytorch | 1x |3.3||37.3| |
|[RegNetX-3.2GF-FPN](./retinanet_regnetx-3GF_fpn_mstrain_1x_coco.py)| pytorch | 1x |4.2 ||39.1| |
### Notice
1. The models are trained with a different weight decay, i.e., `weight_decay=5e-5`, following the setting used for ImageNet training. This brings an improvement of at least 0.7 AP absolute but does not improve models using ResNet-50.
2. RetinaNets using RegNets are trained with a learning rate of 0.02 and gradient clipping. We find that a learning rate of 0.02 improves the results by at least 0.7 AP absolute, and that gradient clipping is necessary to stabilize training.
However, this does not improve the performance of the ResNet-50-FPN RetinaNet.
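Concretely, these settings appear in the configs in this directory as follows (the RetinaNet configs additionally override `optimizer_config` to enable gradient clipping):
```python
# Reduced weight decay, following the ImageNet training setting for RegNets.
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
# RetinaNet configs only: gradient clipping is needed to stabilize training.
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
```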
_base_ = [
'../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
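# _delete_=True discards the ResNet backbone settings inherited from the base
# config so that only the RegNet settings below take effect.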
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
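# Per-stage output widths of the RegNetX-3.2GF backbone (w_i in the RegNet paper).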
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
_base_ = [
'../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
_base_ = [
'../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
lr_config = dict(step=[28, 34])
total_epochs = 36
_base_ = [
'../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
lr_config = dict(step=[28, 34])
total_epochs = 36
_base_ = './mask_rcnn_regnetx-3GF_fpn_1x_coco.py'
model = dict(
pretrained='open-mmlab://regnetx_12gf',
backbone=dict(
type='RegNet',
arch='regnetx_12gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[224, 448, 896, 2240],
out_channels=256,
num_outs=5))
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
# PyCls converts images to float32 directly after loading when training RegNets
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
# PyCls converts images to float32 directly after loading when training RegNets
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
# PyCls converts images to float32 directly after loading when training RegNets
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='Resize',
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
lr_config = dict(step=[28, 34])
total_epochs = 36
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
_base_ = './mask_rcnn_regnetx-3GF_fpn_1x_coco.py'
model = dict(
pretrained='open-mmlab://regnetx_4.0gf',
backbone=dict(
type='RegNet',
arch='regnetx_4.0gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[80, 240, 560, 1360],
out_channels=256,
num_outs=5))
_base_ = './mask_rcnn_regnetx-3GF_fpn_1x_coco.py'
model = dict(
pretrained='open-mmlab://regnetx_6.4gf',
backbone=dict(
type='RegNet',
arch='regnetx_6.4gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[168, 392, 784, 1624],
out_channels=256,
num_outs=5))
_base_ = './mask_rcnn_regnetx-3GF_fpn_1x_coco.py'
model = dict(
pretrained='open-mmlab://regnetx_8.0gf',
backbone=dict(
type='RegNet',
arch='regnetx_8.0gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[80, 240, 720, 1920],
out_channels=256,
num_outs=5))
_base_ = './retinanet_r50_regnetx-3GF_fpn_1x_coco.py'
model = dict(
pretrained='open-mmlab://regnetx_1.6gf',
backbone=dict(
type='RegNet',
arch='regnetx_1.6gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[72, 168, 408, 912],
out_channels=256,
num_outs=5))
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='open-mmlab://regnetx_3.2gf',
backbone=dict(
_delete_=True,
type='RegNet',
arch='regnetx_3.2gf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[96, 192, 432, 1008],
out_channels=256,
num_outs=5))
img_norm_cfg = dict(
# The mean and std are those used in PyCls when training RegNets
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00005)
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
_base_ = './retinanet_r50_regnetx-3GF_fpn_1x_coco.py'
model = dict(
pretrained='open-mmlab://regnetx_800mf',
backbone=dict(
type='RegNet',
arch='regnetx_800mf',
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[64, 128, 288, 672],
out_channels=256,
num_outs=5))
@@ -14,6 +14,7 @@ Welcome to MMDetection's documentation!
tutorials/new_modules.md
compatibility.md
changelog.md
projects.md
api.rst
@@ -123,6 +123,12 @@ Please refer to [ATSS](https://github.com/open-mmlab/mmdetection/blob/master/con
### FSAF
Please refer to [FSAF](https://github.com/open-mmlab/mmdetection/blob/master/configs/fsaf) for details.
### RegNetX
Please refer to [RegNet](https://github.com/open-mmlab/mmdetection/blob/master/configs/regnet) for details.
### Res2Net
Please refer to [Res2Net](https://github.com/open-mmlab/mmdetection/blob/master/configs/res2net) for details.
### Other datasets
We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
@@ -25,4 +25,4 @@ Methods already supported and maintained by MMDetection are not listed.
- SOLOv2: Dynamic, Faster and Stronger. [[paper]](https://arxiv.org/abs/2003.10152)[[github]](https://github.com/WXinlong/SOLO)
- Dense RepPoints: Representing Visual Objects with Dense Point Sets. [[paper]](https://arxiv.org/abs/1912.11473)[[github]](https://github.com/justimyhxu/Dense-RepPoints)
- IterDet: Iterative Scheme for Object Detection in Crowded Environments. [[paper]](https://arxiv.org/abs/2005.05708)[[github]](https://github.com/saic-vul/iterdet)
- Cross-Iteration Batch Normalization [[paper]](https://arxiv.org/abs/2002.05712)[[github]](https://github.com/Howal/Cross-iterationBatchNorm)
- Cross-Iteration Batch Normalization. [[paper]](https://arxiv.org/abs/2002.05712)[[github]](https://github.com/Howal/Cross-iterationBatchNorm)