# configs/swin/mask_rcnn_swin_base_patch4_window7_mstrain_480-800_adamw_3x_coco.py
#
# Mask R-CNN with a Swin "base" backbone (patch 4, window 7), multi-scale
# training with short sides 480-800, AdamW, 36 epochs on COCO instance data.
_base_ = [
    '../_base_/models/mask_rcnn_swin_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]

model = dict(
    backbone=dict(
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=7,
        ape=False,
        drop_path_rate=0.3,
        patch_norm=True,
        use_checkpoint=False,
    ),
    # Neck input widths are embed_dim * (1, 2, 4, 8).
    neck=dict(in_channels=[128, 256, 512, 1024]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# Candidate (short side, long side) sizes for the final multi-scale resize.
# Both AutoAugment policies below end with the same list, so it is defined
# once here to keep them in sync.
multiscale_train_sizes = [
    (480, 1333), (512, 1333), (544, 1333), (576, 1333),
    (608, 1333), (640, 1333), (672, 1333), (704, 1333),
    (736, 1333), (768, 1333), (800, 1333),
]

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            # Policy 1: a single multi-scale resize.
            [
                dict(
                    type='Resize',
                    img_scale=multiscale_train_sizes,
                    multiscale_mode='value',
                    keep_ratio=True),
            ],
            # Policy 2: coarse resize -> random crop -> multi-scale resize.
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='Resize',
                    img_scale=multiscale_train_sizes,
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True),
            ],
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

# `_delete_=True` replaces the optimizer inherited from the base schedule
# instead of merging into it. Weight decay is switched off for position
# embeddings / bias tables and for parameters matched by the 'norm' key.
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        }))
lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type='DistOptimizerHook',
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
# configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco_fixhead.py
#
# Mask R-CNN with a Swin "tiny" backbone (patch 4, window 7), multi-scale
# training with short sides 480-800, AdamW, 12 epochs on COCO instance data.
# "fixhead" variant: the 'neck' and 'head' optimizer keys get lr_mult=0.
_base_ = [
    '../_base_/models/mask_rcnn_swin_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]

model = dict(
    backbone=dict(
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        ape=False,
        drop_path_rate=0.1,
        patch_norm=True,
        use_checkpoint=False,
    ),
    # Neck input widths are embed_dim * (1, 2, 4, 8).
    neck=dict(in_channels=[96, 192, 384, 768]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# Candidate (short side, long side) sizes for the final multi-scale resize.
# Both AutoAugment policies below end with the same list, so it is defined
# once here to keep them in sync.
multiscale_train_sizes = [
    (480, 1333), (512, 1333), (544, 1333), (576, 1333),
    (608, 1333), (640, 1333), (672, 1333), (704, 1333),
    (736, 1333), (768, 1333), (800, 1333),
]

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            # Policy 1: a single multi-scale resize.
            [
                dict(
                    type='Resize',
                    img_scale=multiscale_train_sizes,
                    multiscale_mode='value',
                    keep_ratio=True),
            ],
            # Policy 2: coarse resize -> random crop -> multi-scale resize.
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='Resize',
                    img_scale=multiscale_train_sizes,
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True),
            ],
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

# `_delete_=True` replaces the optimizer inherited from the base schedule
# instead of merging into it.
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            # Zero LR and zero decay multipliers for the 'neck' and 'head'
            # keys. NOTE(review): presumably this freezes every parameter
            # whose name contains these substrings -- confirm against the
            # optimizer constructor's key matching.
            'neck': dict(lr_mult=0., decay_mult=0.),
            'head': dict(lr_mult=0., decay_mult=0.),
        }))
lr_config = dict(step=[8, 11])

# This variant runs with the stock EpochBasedRunner; the EpochBasedRunnerAmp
# runner and the DistOptimizerHook fp16 `optimizer_config` used by the other
# Swin configs are not enabled here, so `fp16` / `optimizer_config` are left
# to whatever the base files define.
runner = dict(type='EpochBasedRunner', max_epochs=12)
def merge(backbone, head):
    """Combine backbone weights and head weights into one checkpoint.

    Args:
        backbone: checkpoint dict with a ``'state_dict'`` mapping; only keys
            containing ``'backbone'`` are taken from it.
        head: checkpoint dict with a ``'state_dict'`` mapping; only keys NOT
            containing ``'backbone'`` are taken from it.

    Returns:
        A new dict ``{'state_dict': {...}}`` with the backbone entries first,
        followed by the head entries. Neither input is modified and no other
        checkpoint fields (e.g. metadata) are carried over.
    """
    merged = {
        key: value
        for key, value in backbone['state_dict'].items()
        if 'backbone' in key
    }
    merged.update(
        (key, value)
        for key, value in head['state_dict'].items()
        if 'backbone' not in key
    )
    return {'state_dict': merged}


def gen_backbone(backbone, new_backbone):
    """Refresh a backbone-only checkpoint with weights from a trained model.

    For every key ``k`` in ``backbone['model']`` the value is replaced by
    ``new_backbone['state_dict']['backbone.' + k]`` when that entry exists;
    keys without a counterpart are reported on stdout and keep their value.

    Args:
        backbone: checkpoint dict with a ``'model'`` mapping.
        new_backbone: checkpoint dict with a ``'state_dict'`` mapping whose
            backbone weights are prefixed with ``'backbone.'``.

    Returns:
        A copy of ``backbone`` with an updated ``'model'`` mapping. The input
        checkpoint and its ``'model'`` mapping are left untouched.
    """
    trained = new_backbone['state_dict']
    # Copy the nested mapping too: a shallow copy of the outer dict alone
    # would make the writes below leak into the caller's checkpoint.
    weights = backbone['model'].copy()
    for key in weights:
        trained_key = 'backbone.{}'.format(key)
        if trained_key not in trained:
            print("{} not find".format(key))
            continue
        # Rebinding an existing key does not resize the dict, so updating
        # while iterating is safe.
        weights[key] = trained[trained_key]
    target = backbone.copy()
    target['model'] = weights
    return target


def main():
    """Build a checkpoint from the CLI arguments and save it to ``--out``.

    With ``--head`` the backbone and head checkpoints are merged; otherwise,
    with ``--new-backbone``, the backbone checkpoint is refreshed from the
    trained one.

    Raises:
        SystemExit: if ``--out`` is missing, or neither ``--head`` nor
            ``--new-backbone`` is given (there would be nothing to save).
    """
    import os  # local import keeps this block self-contained

    args = parse_args()
    if not args.out:
        raise SystemExit("--out is required")
    if not (args.head or args.new_backbone):
        raise SystemExit("one of --head or --new-backbone is required")

    print("generate checkpoint")
    backbone = get_sd(args.backbone, return_sd=False)
    if args.head:
        head = get_sd(args.head, return_sd=False)
        target = merge(backbone, head)
    else:
        trained = get_sd(args.new_backbone, return_sd=False)
        target = gen_backbone(backbone, trained)

    # Create the parent directory of the output file (dirname, not basename).
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(target, args.out)
    print("saved checkpoint in {}".format(args.out))