Commit 3653775f authored by wanggh

Merge branch 'master' of git.nju.edu.cn:wanggh/Swin-Transformer-Object-Detection

parents a002b205 9e4746ea
# First config in this commit: Mask R-CNN with a Swin-Base backbone,
# 36-epoch (3x) schedule, mixed precision handled by the repo's
# DistOptimizerHook instead of mmdet's fp16 support.
_base_ = [
    '../_base_/models/mask_rcnn_swin_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    backbone=dict(
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=7,
        ape=False,
        drop_path_rate=0.3,
        patch_norm=True,
        use_checkpoint=False),
    neck=dict(in_channels=[128, 256, 512, 1024]))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            [
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True)
            ],
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))
# AdamW replaces (not updates) the SGD optimizer inherited from
# schedule_1x.py; position embeddings and norm layers get no weight decay.
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
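Both configs rely on two mmcv Config mechanics that are easy to miss: the `_base_` files are merged recursively into this one, and `_delete_=True` makes the AdamW dict replace, rather than update, the SGD optimizer inherited from `schedule_1x.py`. A minimal sketch of inspecting the resolved config follows; the file path is a placeholder, not a name from this commit:

from mmcv import Config

# _base_ entries are resolved relative to the config file itself.
cfg = Config.fromfile('configs/swin/mask_rcnn_swin_base_3x.py')  # hypothetical path
print(cfg.optimizer.type)            # 'AdamW': _delete_=True dropped the base SGD
print(cfg.model.backbone.embed_dim)  # 128, merged over the _base_ model file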
# Second config in this commit: Mask R-CNN with a Swin-Tiny backbone,
# 12-epoch (1x) schedule. The extra 'neck' and 'head' custom_keys freeze
# everything except the backbone (lr_mult=0.).
_base_ = [
    '../_base_/models/mask_rcnn_swin_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    backbone=dict(
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        ape=False,
        drop_path_rate=0.1,
        patch_norm=True,
        use_checkpoint=False),
    neck=dict(in_channels=[96, 192, 384, 768]))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            [
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True)
            ],
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'neck': dict(lr_mult=0., decay_mult=0.),
            'head': dict(lr_mult=0., decay_mult=0.)
        }))
lr_config = dict(step=[8, 11])
# The AMP runner and optimizer hook are disabled here; the plain mmdet
# EpochBasedRunner is used instead.
# runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )
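What actually freezes the neck and head here is the `custom_keys` substring matching in mmcv's optimizer constructor: each parameter name is checked against the keys, and a hit scales that parameter's learning rate and weight decay by `lr_mult`/`decay_mult`. Below is a simplified sketch of that rule, an approximation for illustration rather than code from this repo:

base_lr, base_wd = 0.0001, 0.05
custom_keys = {
    'absolute_pos_embed': dict(decay_mult=0.),
    'relative_position_bias_table': dict(decay_mult=0.),
    'norm': dict(decay_mult=0.),
    'neck': dict(lr_mult=0., decay_mult=0.),
    'head': dict(lr_mult=0., decay_mult=0.),
}

def param_hyperparams(name):
    """Return (lr, weight_decay) for a parameter name; the first substring
    hit wins in this sketch (mmcv resolves overlaps by sorted key length)."""
    for key, mults in custom_keys.items():
        if key in name:
            return (base_lr * mults.get('lr_mult', 1.0),
                    base_wd * mults.get('decay_mult', 1.0))
    return base_lr, base_wd

print(param_hyperparams('neck.lateral_convs.0.conv.weight'))           # (0.0, 0.0): frozen
print(param_hyperparams('backbone.layers.0.blocks.0.mlp.fc1.weight'))  # full lr and wd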
@@ -9,6 +9,7 @@ def parse_args():
         description='generate model')
     parser.add_argument('--backbone', help='the backbone checkpoint file')
     parser.add_argument('--head', help='the head checkpoint file')
+    parser.add_argument('--new-backbone', help='the trained checkpoint file')
     parser.add_argument('--out', help='output result file in pickle format')
     args = parser.parse_args()
     return args
@@ -21,30 +22,43 @@ def get_sd(filename, return_sd=True):
     else:
         return ck
 
-def merge(target, backbone, head):
+def merge(backbone, head):
+    target = dict()
+    target['state_dict'] = dict()
     tsd = target['state_dict']
     bsd = backbone['state_dict']
     hsd = head['state_dict']
-    for key in tsd.keys():
+    for key in bsd.keys():
         if 'backbone' in key:
-            assert key in bsd
             tsd[key] = bsd[key]
-        else:
-            assert key in hsd
+    for key in hsd.keys():
+        if 'backbone' not in key:
             tsd[key] = hsd[key]
     return target
 
+def gen_backbone(backbone, new_backbone):
+    target = backbone.copy()
+    tsd = target['model']
+    nbsd = new_backbone['state_dict']
+    for key in tsd.keys():
+        nk = 'backbone.{}'.format(key)
+        if nk not in nbsd:
+            print("{} not find".format(key))
+            continue
+        tsd[key] = nbsd[nk]
+    return target
+
 def main():
     args = parse_args()
     print("generate checkpoint")
     backbone = get_sd(args.backbone, return_sd=False)
-    head = get_sd(args.head, return_sd=False)
-
-    #target = backbone.copy()
-    target = head.copy()
-
-    target = merge(target, backbone, head)
+    if args.head:
+        head = get_sd(args.head, return_sd=False)
+        target = merge(backbone, head)
+    elif args.new_backbone:
+        nb = get_sd(args.new_backbone, return_sd=False)
+        target = gen_backbone(backbone, nb)
     #os.makedirs(os.path.basename(args.out), exist_ok=True)
     torch.save(target, args.out)
     print("saved checkpoint in {}".format(args.out))
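The reworked script thus supports two directions: with `--head`, it builds a detection checkpoint whose `backbone.*` weights come from `--backbone` and whose remaining (neck/head) weights come from `--head`; with `--new-backbone`, it copies a trained detector's `backbone.*` weights back into a standalone backbone checkpoint under its `model` key. A hedged sanity check for the merge mode, with placeholder file names:

import torch

merged = torch.load('merged.pth', map_location='cpu')           # output of the script
backbone = torch.load('backbone_ckpt.pth', map_location='cpu')  # the --backbone input
sd = merged['state_dict']
# Every backbone weight in the merged checkpoint should be bit-identical
# to the corresponding weight in the backbone checkpoint.
for k, v in backbone['state_dict'].items():
    if 'backbone' in k:
        assert torch.equal(sd[k], v)
print('backbone keys carried over:', sum('backbone' in k for k in sd))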