Skip to content
Snippets Groups Projects
Unverified Commit 9e15a014 authored by Jiaqi Wang's avatar Jiaqi Wang Committed by GitHub
Browse files

add reppoints config and model without gn (#2058)

* add reppoints without gn

* fix
parent 93bed07b
No related branches found
No related tags found
No related merge requests found
...@@ -32,25 +32,26 @@ Another feature of this repo is the demonstration of an `anchor-free detector`, ...@@ -32,25 +32,26 @@ Another feature of this repo is the demonstration of an `anchor-free detector`,
The results on COCO 2017val are shown in the table below. The results on COCO 2017val are shown in the table below.
| Method | Backbone | Anchor | convert func | Lr schd | box AP | Download | | Method | Backbone | GN | Anchor | convert func | Lr schd | box AP | Download |
| :----: | :------: | :-------: | :------: | :-----: | :----: | :------: | | :----: | :------: | :-------: | :-------: | :------: | :-----: | :----: | :------: |
| BBox | R-50-FPN | single | - | 1x | 36.3|[model](https://drive.google.com/open?id=1TaVAFGZP2i7RwtlQjy3LBH1WI-YRH774) | | BBox | R-50-FPN | Y | single | - | 1x | 36.3|[model](https://drive.google.com/open?id=1TaVAFGZP2i7RwtlQjy3LBH1WI-YRH774) |
| BBox | R-50-FPN | none | - | 1x | 37.3| [model](https://drive.google.com/open?id=1hpfu-I7gtZnIb0NU2WvUvaZz_dm-THuZ) | | BBox | R-50-FPN | Y | none | - | 1x | 37.3| [model](https://drive.google.com/open?id=1hpfu-I7gtZnIb0NU2WvUvaZz_dm-THuZ) |
| RepPoints | R-50-FPN | none | partial MinMax | 1x | 38.1| [model](https://drive.google.com/open?id=11zFtdKH-QGz_zH7vlcIih6FQAjV84CWc) | | RepPoints | R-50-FPN | Y | none | partial MinMax | 1x | 38.1| [model](https://drive.google.com/open?id=11zFtdKH-QGz_zH7vlcIih6FQAjV84CWc) |
| RepPoints | R-50-FPN | none | MinMax | 1x | 38.2| [model](https://drive.google.com/open?id=1Cg9818dpkL-9qjmYdkhrY_BRiQFjV4xu) | | RepPoints | R-50-FPN | Y | none | MinMax | 1x | 38.2| [model](https://drive.google.com/open?id=1Cg9818dpkL-9qjmYdkhrY_BRiQFjV4xu) |
| RepPoints | R-50-FPN | none | moment | 1x | 38.2| [model](https://drive.google.com/open?id=1rQg-lE-5nuqO1bt6okeYkti4Q-EaBsu_) | | RepPoints | R-50-FPN | Y | none | moment | 1x | 38.2| [model](https://drive.google.com/open?id=1rQg-lE-5nuqO1bt6okeYkti4Q-EaBsu_) |
| RepPoints | R-50-FPN | none | moment | 2x | 38.6| [model](https://drive.google.com/open?id=1TfR-5geVviKhRoXL9JP6cG3fkN2itbBU) | | RepPoints | R-50-FPN | N | none | moment | 1x | 36.8| [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/reppoints/reppoints_moment_r50_no_gn_fpn_1x-66db098e.pth) |
| RepPoints | R-50-FPN | none | moment | 2x (ms train) | 40.8| [model](https://drive.google.com/open?id=1oaHTIaP51oB5HJ6GWV3WYK19lMm9iJO6) | | RepPoints | R-50-FPN | Y | none | moment | 2x | 38.6| [model](https://drive.google.com/open?id=1TfR-5geVviKhRoXL9JP6cG3fkN2itbBU) |
| RepPoints | R-50-FPN | none | moment | 2x (ms train&ms test) | 42.2| | | RepPoints | R-50-FPN | Y | none | moment | 2x (ms train) | 40.8| [model](https://drive.google.com/open?id=1oaHTIaP51oB5HJ6GWV3WYK19lMm9iJO6) |
| RepPoints | R-101-FPN | none | moment | 2x | 40.3| [model](https://drive.google.com/open?id=1BAmGeUQ_zVQi2u7rgOuPQem2EjXDLgWm) | | RepPoints | R-50-FPN | Y | none | moment | 2x (ms train&ms test) | 42.2| |
| RepPoints | R-101-FPN | none | moment | 2x (ms train) | 42.3| [model](https://drive.google.com/open?id=14Lf0p4fXElXaxFu8stk3hek3bY8tNENX) | | RepPoints | R-101-FPN | Y | none | moment | 2x | 40.3| [model](https://drive.google.com/open?id=1BAmGeUQ_zVQi2u7rgOuPQem2EjXDLgWm) |
| RepPoints | R-101-FPN | none | moment | 2x (ms train&ms test) | 44.1| | | RepPoints | R-101-FPN | Y | none | moment | 2x (ms train) | 42.3| [model](https://drive.google.com/open?id=14Lf0p4fXElXaxFu8stk3hek3bY8tNENX) |
| RepPoints | R-101-FPN-DCN | none | moment | 2x | 43.0| [model](https://drive.google.com/open?id=1hpptxpb4QtNuB-HnV5wHbDltPHhlYq4z) | | RepPoints | R-101-FPN | Y | none | moment | 2x (ms train&ms test) | 44.1| |
| RepPoints | R-101-FPN-DCN | none | moment | 2x (ms train) | 44.8| [model](https://drive.google.com/open?id=1fsTckK99HYjOURwcFeHfy5JRRtsCajfX) | | RepPoints | R-101-FPN-DCN | Y | none | moment | 2x | 43.0| [model](https://drive.google.com/open?id=1hpptxpb4QtNuB-HnV5wHbDltPHhlYq4z) |
| RepPoints | R-101-FPN-DCN | none | moment | 2x (ms train&ms test) | 46.4| | | RepPoints | R-101-FPN-DCN | Y | none | moment | 2x (ms train) | 44.8| [model](https://drive.google.com/open?id=1fsTckK99HYjOURwcFeHfy5JRRtsCajfX) |
| RepPoints | X-101-FPN-DCN | none | moment | 2x | 44.5| [model](https://drive.google.com/open?id=1Y8vqaqU88-FEqqwl6Zb9exD5O246yrMR) | | RepPoints | R-101-FPN-DCN | Y | none | moment | 2x (ms train&ms test) | 46.4| |
| RepPoints | X-101-FPN-DCN | none | moment | 2x (ms train) | 45.6| [model](https://drive.google.com/open?id=1nr9gcVWxzeakbfPC6ON9yvKOuLzj_RrJ) | | RepPoints | X-101-FPN-DCN | Y | none | moment | 2x | 44.5| [model](https://drive.google.com/open?id=1Y8vqaqU88-FEqqwl6Zb9exD5O246yrMR) |
| RepPoints | X-101-FPN-DCN | none | moment | 2x (ms train&ms test) | 46.8| | | RepPoints | X-101-FPN-DCN | Y | none | moment | 2x (ms train) | 45.6| [model](https://drive.google.com/open?id=1nr9gcVWxzeakbfPC6ON9yvKOuLzj_RrJ) |
| RepPoints | X-101-FPN-DCN | Y | none | moment | 2x (ms train&ms test) | 46.8| |
**Notes:** **Notes:**
......
# model settings
model = dict(
type='RepPointsDetector',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
num_outs=5),
bbox_head=dict(
type='RepPointsHead',
num_classes=81,
in_channels=256,
feat_channels=256,
point_feat_channels=256,
stacked_convs=3,
num_points=9,
gradient_mul=0.1,
point_strides=[8, 16, 32, 64, 128],
point_base_scale=4,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_init=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.5),
loss_bbox_refine=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0),
transform_method='moment'))
# training and testing settings
train_cfg = dict(
init=dict(
assigner=dict(type='PointAssigner', scale=4, pos_num=1),
allowed_border=-1,
pos_weight=-1,
debug=False),
refine=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False))
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/reppoints_moment_r50_no_gn_fpn_1x'
load_from = None
resume_from = None
auto_resume = True
workflow = [('train', 1)]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment