Commit a9e21cf7 authored by myownskyW7, committed by Kai Chen

Support models without FPN (#133)

* add two stage w/o neck and w/ upperneck

* add rpn r50 c4

* update c4 configs

* fix

* config update

* update config

* minor update

* mask rcnn support c4 train and test

* lr fix

* cascade support upper_neck

* add cascade c4 config

* update config

* update

* update res_layer to new interface

* refactoring

* c4 configs update

* refactoring

* update rpn_c4 config

* rename upper_neck as shared_head

* update

* update configs

* update

* update c4 configs

* update according to commits

* update
parent 90096804
Showing 1179 additions and 50 deletions
# model settings
model = dict(
type='CascadeRCNN',
num_stages=3,
pretrained='open-mmlab://resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
normalize=dict(type='BN', frozen=True),
norm_eval=True,
style='caffe'),
shared_head=dict(
type='ResLayer',
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
normalize=dict(type='BN', frozen=True),
norm_eval=True),
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_scales=[2, 4, 8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[16],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=1024,
featmap_strides=[16]),
bbox_head=[
dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=True),
dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1],
reg_class_agnostic=True),
dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067],
reg_class_agnostic=True)
],
mask_roi_extractor=None,
mask_head=dict(
type='FCNMaskHead',
num_convs=0,
in_channels=2048,
conv_out_channels=256,
num_classes=81))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False),
rcnn=[
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.6,
neg_iou_thr=0.6,
min_pos_iou=0.6,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.7,
min_pos_iou=0.7,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False)
],
stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=12000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5),
keep_all_stages=False)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
data = dict(
imgs_per_gpu=1,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=True,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=True,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=True,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/cascade_mask_rcnn_r50_c4_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
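In this C4 layout the shapes line up as follows: RoIAlign pools each RoI to 14x14 at 1024 channels, the stride-2 ResLayer shared head maps that to 7x7 at 2048 channels, and the bbox heads accordingly set roi_feat_size=7, in_channels=2048 and with_avg_pool=True. A quick shape check of that flow, using placeholder tensors rather than the real modules:

# Shape check of the C4 RoI flow above (placeholder tensors, not the real modules):
# RoIAlign(out_size=14) yields (N, 1024, 14, 14) per RoI batch; the stride-2
# ResLayer shared head maps that to (N, 2048, 7, 7).
import torch
import torch.nn.functional as F

res5_out = torch.randn(8, 2048, 7, 7)          # shared-head output for 8 RoIs
pooled = F.avg_pool2d(res5_out, 7).flatten(1)  # with_avg_pool=True before the fc layers
assert pooled.shape == (8, 2048)               # matches in_channels=2048 in the bbox heads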
# model settings
model = dict(
type='CascadeRCNN',
num_stages=3,
pretrained='open-mmlab://resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
normalize=dict(type='BN', frozen=True),
norm_eval=True,
style='caffe'),
shared_head=dict(
type='ResLayer',
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
normalize=dict(type='BN', frozen=True),
norm_eval=True),
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_scales=[2, 4, 8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[16],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=1024,
featmap_strides=[16]),
bbox_head=[
dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=True),
dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1],
reg_class_agnostic=True),
dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067],
reg_class_agnostic=True)
])
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False),
rcnn=[
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.6,
neg_iou_thr=0.6,
min_pos_iou=0.6,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.7,
min_pos_iou=0.7,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False)
],
stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=12000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
keep_all_stages=False)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
data = dict(
imgs_per_gpu=1,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=True,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=True,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=True,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/cascade_rcnn_r50_c4_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='FastRCNN',
pretrained='open-mmlab://resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
normalize=dict(type='BN', frozen=True),
norm_eval=True,
style='caffe'),
shared_head=dict(
type='ResLayer',
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
normalize=dict(type='BN', frozen=True),
norm_eval=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=1024,
featmap_strides=[16]),
bbox_head=dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False))
# model training and testing settings
train_cfg = dict(
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False))
test_cfg = dict(
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
data = dict(
imgs_per_gpu=1,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
proposal_file=data_root + 'proposals/rpn_r50_c4_1x_train2017.pkl',
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
proposal_file=data_root + 'proposals/rpn_r50_c4_1x_val2017.pkl',
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
proposal_file=data_root + 'proposals/rpn_r50_c4_1x_val2017.pkl',
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fast_rcnn_r50_c4_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='FasterRCNN',
pretrained='open-mmlab://resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
normalize=dict(type='BN', frozen=True),
norm_eval=True,
style='caffe'),
shared_head=dict(
type='ResLayer',
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
normalize=dict(type='BN', frozen=True),
norm_eval=True),
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_scales=[2, 4, 8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[16],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=1024,
featmap_strides=[16]),
bbox_head=dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=12000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
data = dict(
imgs_per_gpu=1,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_r50_c4_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='MaskRCNN',
pretrained='open-mmlab://resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
normalize=dict(type='BN', frozen=True),
norm_eval=True,
style='caffe'),
shared_head=dict(
type='ResLayer',
depth=50,
stage=3,
stride=2,
dilation=1,
style='caffe',
normalize=dict(type='BN', frozen=True),
norm_eval=True),
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_scales=[2, 4, 8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[16],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
out_channels=1024,
featmap_strides=[16]),
bbox_head=dict(
type='BBoxHead',
with_avg_pool=True,
roi_feat_size=7,
in_channels=2048,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False),
mask_roi_extractor=None,
mask_head=dict(
type='FCNMaskHead',
num_convs=0,
in_channels=2048,
conv_out_channels=256,
num_classes=81))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
mask_size=14,
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=12000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
data = dict(
imgs_per_gpu=1,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=True,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=True,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mask_rcnn_r50_c4_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='RPN',
pretrained='open-mmlab://resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
num_stages=3,
strides=(1, 2, 2),
dilations=(1, 1, 1),
out_indices=(2, ),
frozen_stages=1,
normalize=dict(type='BN', frozen=True),
norm_eval=True,
style='caffe'),
neck=None,
rpn_head=dict(
type='RPNHead',
in_channels=1024,
feat_channels=1024,
anchor_scales=[2, 4, 8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[16],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=12000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=False,
with_label=False),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=False,
with_label=False),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
# runner configs
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/rpn_r50_c4_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
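Any of the configs above can be loaded with mmcv's Config and turned into a model through mmdet's detector builder. A minimal sketch; the config path is an assumption about where the file lives in the repository:

# Minimal sketch: build one of the C4 detectors from its config.
from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/faster_rcnn_r50_c4_1x.py')  # path is an assumption
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(model.with_shared_head)  # True for these C4 models, which carry a ResLayer shared head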
@@ -2,15 +2,17 @@ from .backbones import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .roi_extractors import * # noqa: F401,F403
from .anchor_heads import * # noqa: F401,F403
from .shared_heads import * # noqa: F401,F403
from .bbox_heads import * # noqa: F401,F403
from .mask_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
DETECTORS)
from .builder import (build_backbone, build_neck, build_roi_extractor,
build_head, build_detector)
build_shared_head, build_head, build_detector)
__all__ = [
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'HEADS', 'DETECTORS',
'build_backbone', 'build_neck', 'build_roi_extractor', 'build_head',
'build_detector'
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS',
'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
'build_shared_head', 'build_head', 'build_detector'
]
from .resnet import ResNet
from .resnet import ResNet, make_res_layer
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
__all__ = ['ResNet', 'ResNeXt', 'SSDVGG']
__all__ = ['ResNet', 'make_res_layer', 'ResNeXt', 'SSDVGG']
@@ -440,10 +440,7 @@ class ResNet(nn.Module):
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
return tuple(outs)
def train(self, mode=True):
super(ResNet, self).train(mode)
......
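After this change the backbone returns a tuple even when only one stage is requested, which is the case the C4 configs above rely on (out_indices=(2, )). A small sanity check of the new convention, assuming mmdet is importable:

# Illustrative check of the single-output-level return convention.
import torch
from mmdet.models.backbones import ResNet

backbone = ResNet(depth=50, num_stages=3, strides=(1, 2, 2),
                  dilations=(1, 1, 1), out_indices=(2, ))
feats = backbone(torch.randn(1, 3, 224, 224))
assert isinstance(feats, tuple) and len(feats) == 1
c4 = feats[0]  # stride-16 feature map with 1024 channels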
import mmcv
from torch import nn
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
DETECTORS)
def _build_module(cfg, registry, default_args):
@@ -43,6 +44,10 @@ def build_roi_extractor(cfg):
return build(cfg, ROI_EXTRACTORS)
def build_shared_head(cfg):
return build(cfg, SHARED_HEADS)
def build_head(cfg):
return build(cfg, HEADS)
......
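build_shared_head follows the same registry pattern as the other builders: a module registered into SHARED_HEADS (as the new ResLayer is) can be constructed from a config dict. A hedged sketch of that round trip with a toy module; TinyHead is made up and not part of this commit:

# Illustrative round trip through the new SHARED_HEADS registry.
import torch.nn as nn
from mmdet.models import SHARED_HEADS, build_shared_head


@SHARED_HEADS.register_module
class TinyHead(nn.Module):  # made-up module, for illustration only

    def __init__(self, channels=16):
        super(TinyHead, self).__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


head = build_shared_head(dict(type='TinyHead', channels=16))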
@@ -21,6 +21,10 @@ class BaseDetector(nn.Module):
def with_neck(self):
return hasattr(self, 'neck') and self.neck is not None
@property
def with_shared_head(self):
return hasattr(self, 'shared_head') and self.shared_head is not None
@property
def with_bbox(self):
return hasattr(self, 'bbox_head') and self.bbox_head is not None
......
@@ -18,6 +18,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
num_stages,
backbone,
neck=None,
shared_head=None,
rpn_head=None,
bbox_roi_extractor=None,
bbox_head=None,
@@ -35,12 +36,13 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
if neck is not None:
self.neck = builder.build_neck(neck)
else:
raise NotImplementedError
if rpn_head is not None:
self.rpn_head = builder.build_head(rpn_head)
if shared_head is not None:
self.shared_head = builder.build_shared_head(shared_head)
if bbox_head is not None:
self.bbox_roi_extractor = nn.ModuleList()
self.bbox_head = nn.ModuleList()
@@ -57,19 +59,26 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
self.bbox_head.append(builder.build_head(head))
if mask_head is not None:
self.mask_roi_extractor = nn.ModuleList()
self.mask_head = nn.ModuleList()
if not isinstance(mask_roi_extractor, list):
mask_roi_extractor = [
mask_roi_extractor for _ in range(num_stages)
]
if not isinstance(mask_head, list):
mask_head = [mask_head for _ in range(num_stages)]
assert len(mask_roi_extractor) == len(mask_head) == self.num_stages
for roi_extractor, head in zip(mask_roi_extractor, mask_head):
self.mask_roi_extractor.append(
builder.build_roi_extractor(roi_extractor))
assert len(mask_head) == self.num_stages
for head in mask_head:
self.mask_head.append(builder.build_head(head))
if mask_roi_extractor is not None:
self.share_roi_extractor = False
self.mask_roi_extractor = nn.ModuleList()
if not isinstance(mask_roi_extractor, list):
mask_roi_extractor = [
mask_roi_extractor for _ in range(num_stages)
]
assert len(mask_roi_extractor) == self.num_stages
for roi_extractor in mask_roi_extractor:
self.mask_roi_extractor.append(
builder.build_roi_extractor(roi_extractor))
else:
self.share_roi_extractor = True
self.mask_roi_extractor = self.bbox_roi_extractor
self.train_cfg = train_cfg
self.test_cfg = test_cfg
@@ -91,12 +100,15 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
self.neck.init_weights()
if self.with_rpn:
self.rpn_head.init_weights()
if self.with_shared_head:
self.shared_head.init_weights(pretrained=pretrained)
for i in range(self.num_stages):
if self.with_bbox:
self.bbox_roi_extractor[i].init_weights()
self.bbox_head[i].init_weights()
if self.with_mask:
self.mask_roi_extractor[i].init_weights()
if not self.share_roi_extractor:
self.mask_roi_extractor[i].init_weights()
self.mask_head[i].init_weights()
def extract_feat(self, img):
@@ -164,23 +176,45 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
rois = bbox2roi([res.bboxes for res in sampling_results])
bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = bbox_head(bbox_feats)
bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes,
gt_labels, rcnn_train_cfg)
loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
for name, value in loss_bbox.items():
losses['s{}.{}'.format(i, name)] = (value * lw if
'loss' in name else value)
losses['s{}.{}'.format(
i, name)] = (value * lw if 'loss' in name else value)
# mask head forward and loss
if self.with_mask:
mask_roi_extractor = self.mask_roi_extractor[i]
if not self.share_roi_extractor:
mask_roi_extractor = self.mask_roi_extractor[i]
pos_rois = bbox2roi(
[res.pos_bboxes for res in sampling_results])
mask_feats = mask_roi_extractor(
x[:mask_roi_extractor.num_inputs], pos_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
else:
# reuse positive bbox feats
pos_inds = []
device = bbox_feats.device
for res in sampling_results:
pos_inds.append(
torch.ones(
res.pos_bboxes.shape[0],
device=device,
dtype=torch.uint8))
pos_inds.append(
torch.zeros(
res.neg_bboxes.shape[0],
device=device,
dtype=torch.uint8))
pos_inds = torch.cat(pos_inds)
mask_feats = bbox_feats[pos_inds]
mask_head = self.mask_head[i]
pos_rois = bbox2roi(
[res.pos_bboxes for res in sampling_results])
mask_feats = mask_roi_extractor(
x[:mask_roi_extractor.num_inputs], pos_rois)
mask_pred = mask_head(mask_feats)
mask_targets = mask_head.get_target(sampling_results, gt_masks,
rcnn_train_cfg)
@@ -188,9 +222,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
[res.pos_gt_labels for res in sampling_results])
loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
for name, value in loss_mask.items():
losses['s{}.{}'.format(i, name)] = (value * lw
if 'loss' in name else
value)
losses['s{}.{}'.format(
i, name)] = (value * lw if 'loss' in name else value)
# refine bboxes
if i < self.num_stages - 1:
@@ -224,6 +257,9 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
bbox_feats = bbox_roi_extractor(
x[:len(bbox_roi_extractor.featmap_strides)], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = bbox_head(bbox_feats)
ms_scores.append(cls_score)
@@ -254,6 +290,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
mask_feats = mask_roi_extractor(
x[:len(mask_roi_extractor.featmap_strides)],
mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats, i)
mask_pred = mask_head(mask_feats)
segm_result = mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, rcnn_test_cfg,
@@ -292,6 +330,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
mask_roi_extractor = self.mask_roi_extractor[i]
mask_feats = mask_roi_extractor(
x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head[i](mask_feats)
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
merged_masks = merge_aug_masks(aug_masks,
......
@@ -7,17 +7,19 @@ class FastRCNN(TwoStageDetector):
def __init__(self,
backbone,
neck,
bbox_roi_extractor,
bbox_head,
train_cfg,
test_cfg,
neck=None,
shared_head=None,
mask_roi_extractor=None,
mask_head=None,
pretrained=None):
super(FastRCNN, self).__init__(
backbone=backbone,
neck=neck,
shared_head=shared_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
......
@@ -7,16 +7,18 @@ class FasterRCNN(TwoStageDetector):
def __init__(self,
backbone,
neck,
rpn_head,
bbox_roi_extractor,
bbox_head,
train_cfg,
test_cfg,
neck=None,
shared_head=None,
pretrained=None):
super(FasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
shared_head=shared_head,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
......
@@ -7,7 +7,6 @@ class MaskRCNN(TwoStageDetector):
def __init__(self,
backbone,
neck,
rpn_head,
bbox_roi_extractor,
bbox_head,
@@ -15,10 +14,13 @@ class MaskRCNN(TwoStageDetector):
mask_head,
train_cfg,
test_cfg,
neck=None,
shared_head=None,
pretrained=None):
super(MaskRCNN, self).__init__(
backbone=backbone,
neck=neck,
shared_head=shared_head,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
......
@@ -37,6 +37,8 @@ class BBoxTestMixin(object):
rois = bbox2roi(proposals)
roi_feats = self.bbox_roi_extractor(
x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
if self.with_shared_head:
roi_feats = self.shared_head(roi_feats)
cls_score, bbox_pred = self.bbox_head(roi_feats)
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
@@ -65,6 +67,8 @@ class BBoxTestMixin(object):
# recompute feature maps to save GPU memory
roi_feats = self.bbox_roi_extractor(
x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
if self.with_shared_head:
roi_feats = self.shared_head(roi_feats)
cls_score, bbox_pred = self.bbox_head(roi_feats)
bboxes, scores = self.bbox_head.get_det_bboxes(
rois,
@@ -106,6 +110,8 @@ class MaskTestMixin(object):
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor(
x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head(mask_feats)
segm_result = self.mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape,
@@ -127,6 +133,8 @@ class MaskTestMixin(object):
mask_feats = self.mask_roi_extractor(
x[:len(self.mask_roi_extractor.featmap_strides)],
mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head(mask_feats)
# convert to numpy array to save memory
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
......
@@ -15,6 +15,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
def __init__(self,
backbone,
neck=None,
shared_head=None,
rpn_head=None,
bbox_roi_extractor=None,
bbox_head=None,
@@ -28,8 +29,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
if neck is not None:
self.neck = builder.build_neck(neck)
else:
raise NotImplementedError
if shared_head is not None:
self.shared_head = builder.build_shared_head(shared_head)
if rpn_head is not None:
self.rpn_head = builder.build_head(rpn_head)
@@ -40,8 +42,13 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
self.bbox_head = builder.build_head(bbox_head)
if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
if mask_roi_extractor is not None:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.share_roi_extractor = False
else:
self.share_roi_extractor = True
self.mask_roi_extractor = self.bbox_roi_extractor
self.mask_head = builder.build_head(mask_head)
self.train_cfg = train_cfg
@@ -62,14 +69,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
m.init_weights()
else:
self.neck.init_weights()
if self.with_shared_head:
self.shared_head.init_weights(pretrained=pretrained)
if self.with_rpn:
self.rpn_head.init_weights()
if self.with_bbox:
self.bbox_roi_extractor.init_weights()
self.bbox_head.init_weights()
if self.with_mask:
self.mask_roi_extractor.init_weights()
self.mask_head.init_weights()
if not self.share_roi_extractor:
self.mask_roi_extractor.init_weights()
def extract_feat(self, img):
x = self.backbone(img)
@@ -130,6 +140,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# TODO: a more flexible way to decide which feature maps to use
bbox_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = self.bbox_head(bbox_feats)
bbox_targets = self.bbox_head.get_target(
@@ -140,9 +152,29 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# mask head forward and loss
if self.with_mask:
pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], pos_rois)
if not self.share_roi_extractor:
pos_rois = bbox2roi(
[res.pos_bboxes for res in sampling_results])
mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], pos_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
else:
pos_inds = []
device = bbox_feats.device
for res in sampling_results:
pos_inds.append(
torch.ones(
res.pos_bboxes.shape[0],
device=device,
dtype=torch.uint8))
pos_inds.append(
torch.zeros(
res.neg_bboxes.shape[0],
device=device,
dtype=torch.uint8))
pos_inds = torch.cat(pos_inds)
mask_feats = bbox_feats[pos_inds]
mask_pred = self.mask_head(mask_feats)
mask_targets = self.mask_head.get_target(
......
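When no mask_roi_extractor is configured (share_roi_extractor mode, as in the C4 Mask R-CNN and Cascade Mask R-CNN configs above), the mask branch reuses the already-pooled bbox features of the positive RoIs instead of running RoIAlign a second time. A toy illustration of that index bookkeeping, with made-up RoI counts; the commit itself builds uint8 masks, which predate torch.bool:

# Toy illustration of the share_roi_extractor path above: pick out the features
# of positive RoIs from the already-extracted bbox features.
import torch

bbox_feats = torch.randn(6, 2048, 7, 7)   # pooled features for 6 sampled RoIs
pos_counts, neg_counts = [2, 1], [1, 2]   # per-image positive/negative RoI counts (made up)
pos_inds = torch.cat([
    torch.cat([torch.ones(p, dtype=torch.bool),
               torch.zeros(n, dtype=torch.bool)])
    for p, n in zip(pos_counts, neg_counts)
])
mask_feats = bbox_feats[pos_inds]         # only positive RoIs reach the mask head
assert mask_feats.shape[0] == sum(pos_counts)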
@@ -53,11 +53,13 @@ class FCNMaskHead(nn.Module):
padding=padding,
normalize=normalize,
bias=self.with_bias))
upsample_in_channels = (self.conv_out_channels
if self.num_convs > 0 else in_channels)
if self.upsample_method is None:
self.upsample = None
elif self.upsample_method == 'deconv':
self.upsample = nn.ConvTranspose2d(
self.conv_out_channels,
upsample_in_channels,
self.conv_out_channels,
self.upsample_ratio,
stride=self.upsample_ratio)
@@ -66,7 +68,10 @@ class FCNMaskHead(nn.Module):
scale_factor=self.upsample_ratio, mode=self.upsample_method)
out_channels = 1 if self.class_agnostic else self.num_classes
self.conv_logits = nn.Conv2d(self.conv_out_channels, out_channels, 1)
logits_in_channel = (self.conv_out_channels
if self.upsample_method == 'deconv' else
upsample_in_channels)
self.conv_logits = nn.Conv2d(logits_in_channel, out_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.debug_imgs = None
......
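These two hunks matter for the C4 configs above, where the mask head is built with num_convs=0 and in_channels=2048: with no conv layers, the upsample layer (and, when upsampling is disabled, the final 1x1 logits conv) must read from in_channels rather than conv_out_channels. The channel bookkeeping, restated as a small illustrative helper:

# Illustrative restatement of the channel bookkeeping in the hunks above.
def fcn_mask_head_channels(num_convs, in_channels, conv_out_channels,
                           upsample_method):
    # With no conv layers, features enter the upsample stage at in_channels.
    upsample_in = conv_out_channels if num_convs > 0 else in_channels
    # A deconv upsample outputs conv_out_channels; otherwise the logits conv
    # sees whatever entered the upsample stage.
    logits_in = conv_out_channels if upsample_method == 'deconv' else upsample_in
    return upsample_in, logits_in


# The C4 mask heads above: num_convs=0, in_channels=2048, assuming the default deconv upsample.
assert fcn_mask_head_channels(0, 2048, 256, 'deconv') == (2048, 256)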
@@ -22,9 +22,8 @@ class Registry(object):
module (:obj:`nn.Module`): Module to be registered.
"""
if not issubclass(module_class, nn.Module):
raise TypeError(
'module must be a child of nn.Module, but got {}'.format(
module_class))
raise TypeError('module must be a child of nn.Module, but got {}'.
format(module_class))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
@@ -39,5 +38,6 @@ class Registry(object):
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
DETECTORS = Registry('detector')
from .res_layer import ResLayer
__all__ = ['ResLayer']
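res_layer.py itself is not shown in this excerpt. As rough orientation only, a shared head of this kind wraps the res5 stage built by the newly exported make_res_layer and applies it to each pooled RoI feature map. The sketch below is hypothetical and is not the file added in this commit; the make_res_layer signature and the Bottleneck block are assumed from the backbone code:

# Hypothetical sketch of a ResLayer-style shared head; NOT the res_layer.py from
# this commit. Assumes Bottleneck and make_res_layer from the ResNet backbone.
import torch.nn as nn
from mmdet.models import SHARED_HEADS
from mmdet.models.backbones.resnet import Bottleneck, make_res_layer


@SHARED_HEADS.register_module
class ToyResLayer(nn.Module):
    """Runs a res5-like stage (stage index 3 of ResNet-50) on RoI features."""

    def __init__(self, stage=3, stride=2, dilation=1):
        super(ToyResLayer, self).__init__()
        planes = 64 * 2**stage                                 # 512 for stage 3
        inplanes = 64 * 2**(stage - 1) * Bottleneck.expansion  # 1024 coming in
        self.res5 = make_res_layer(Bottleneck, inplanes, planes, 3,
                                   stride=stride, dilation=dilation)

    def init_weights(self, pretrained=None):
        pass  # the real ResLayer would load pretrained res5 weights here

    def forward(self, x):
        return self.res5(x)  # e.g. (N, 1024, 14, 14) -> (N, 2048, 7, 7)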