diff --git a/configs/swin/mask_rcnn_swin_base_patch4_window7_mstrain_480-800_adamw_3x_coco.py b/configs/swin/mask_rcnn_swin_base_patch4_window7_mstrain_480-800_adamw_3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a7d5e4982c4f57124708525dae9494fd2d74a8
--- /dev/null
+++ b/configs/swin/mask_rcnn_swin_base_patch4_window7_mstrain_480-800_adamw_3x_coco.py
@@ -0,0 +1,80 @@
+_base_ = [
+    '../_base_/models/mask_rcnn_swin_fpn.py',
+    '../_base_/datasets/coco_instance.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+    backbone=dict(
+        embed_dim=128,
+        depths=[2, 2, 18, 2],
+        num_heads=[4, 8, 16, 32],
+        window_size=7,
+        ape=False,
+        drop_path_rate=0.3,
+        patch_norm=True,
+        use_checkpoint=False
+    ),
+    neck=dict(in_channels=[128, 256, 512, 1024]))
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# augmentation strategy originates from DETR / Sparse RCNN
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='AutoAugment',
+         policies=[
+             [
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                 (736, 1333), (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True)
+             ],
+             [
+                 dict(type='Resize',
+                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True),
+                 dict(type='RandomCrop',
+                      crop_type='absolute_range',
+                      crop_size=(384, 600),
+                      allow_negative_crop=True),
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                 (576, 1333), (608, 1333), (640, 1333),
+                                 (672, 1333), (704, 1333), (736, 1333),
+                                 (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      override=True,
+                      keep_ratio=True)
+             ]
+         ]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+data = dict(train=dict(pipeline=train_pipeline))
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+lr_config = dict(step=[27, 33])
+runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
+
+# do not use mmdet version fp16
+fp16 = None
+optimizer_config = dict(
+    type="DistOptimizerHook",
+    update_interval=1,
+    grad_clip=None,
+    coalesce=True,
+    bucket_size_mb=-1,
+    use_fp16=True,
+)
diff --git a/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco_fixhead.py b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco_fixhead.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef90de7e07ac5102ea493b85fa9bc90c404e1ad6
--- /dev/null
+++ b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco_fixhead.py
@@ -0,0 +1,83 @@
+_base_ = [
+    '../_base_/models/mask_rcnn_swin_fpn.py',
+    '../_base_/datasets/coco_instance.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+    backbone=dict(
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        ape=False,
+        drop_path_rate=0.1,
+        patch_norm=True,
+        use_checkpoint=False
+    ),
+    neck=dict(in_channels=[96, 192, 384, 768]))
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# augmentation strategy originates from DETR / Sparse RCNN
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='AutoAugment',
+         policies=[
+             [
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                 (736, 1333), (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True)
+             ],
+             [
+                 dict(type='Resize',
+                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True),
+                 dict(type='RandomCrop',
+                      crop_type='absolute_range',
+                      crop_size=(384, 600),
+                      allow_negative_crop=True),
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                 (576, 1333), (608, 1333), (640, 1333),
+                                 (672, 1333), (704, 1333), (736, 1333),
+                                 (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      override=True,
+                      keep_ratio=True)
+             ]
+         ]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+data = dict(train=dict(pipeline=train_pipeline))
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.),
+                                                 'neck': dict(lr_mult=0., decay_mult=0.),
+                                                 'head': dict(lr_mult=0., decay_mult=0.)}))
+lr_config = dict(step=[8, 11])
+#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+runner = dict(type='EpochBasedRunner', max_epochs=12)
+ 
+# do not use mmdet version fp16
+# fp16 = None
+# optimizer_config = dict(
+#     type="DistOptimizerHook",
+#     update_interval=1,
+#     grad_clip=None,
+#     coalesce=True,
+#     bucket_size_mb=-1,
+#     use_fp16=True,
+# )
diff --git a/tools/gen_checkpoint.py b/tools/gen_checkpoint.py
index 6eaa3ac963a9c96828427f42bfe1ed8c0fc8f001..7510d9e13753604b1312bc499467a8e959aec56b 100644
--- a/tools/gen_checkpoint.py
+++ b/tools/gen_checkpoint.py
@@ -9,6 +9,7 @@ def parse_args():
         description='generate model')
     parser.add_argument('--backbone', help='the backbone checkpoint file')
     parser.add_argument('--head', help='the head checkpoint file')
+    parser.add_argument('--new-backbone', help='the trained checkpoint file')
     parser.add_argument('--out', help='output result file in pickle format')
     args = parser.parse_args()
     return args
@@ -21,31 +22,45 @@ def get_sd(filename, return_sd=True):
     else:
         return ck
 
-def merge(target, backbone, head):
+def merge(backbone, head):
+    target = dict()
+    target['state_dict'] = dict()
     tsd = target['state_dict']
-    bsd = target['state_dict']
-    hsd = target['state_dict']
-    for key in tsd.keys():
+    bsd = backbone['state_dict']
+    hsd = head['state_dict']
+
+    for key in bsd.keys():
         if 'backbone' in key:
-            assert key in bsd
             tsd[key] = bsd[key]
-        else:
-            assert key in hsd
+    for key in hsd.keys():
+        if 'backbone' not in key:
             tsd[key] = hsd[key]
     return target
 
+def gen_backbone(backbone, new_backbone):
+    target = backbone.copy()
+    tsd = target['model']
+    nbsd = new_backbone['state_dict']
+    for key in tsd.keys():
+        nk = 'backbone.{}'.format(key)
+        if nk not in nbsd:
+            print("{} not find".format(key))
+            continue
+        tsd[key] = nbsd[nk]
+    return target
+
 def main():
     args = parse_args()
     print("generate checkpoint")
 
     backbone = get_sd(args.backbone, return_sd=False)
-    head = get_sd(args.head, return_sd=False)
-
-    target = backbone.copy()
-    #target = head.copy()
-
-    target = merge(target, backbone, head)
-    os.makedirs(os.path.basename(args.out), exist_ok=True)
+    if args.head:
+        head = get_sd(args.head, return_sd=False)
+        target = merge(backbone, head)
+    elif args.new_backbone:
+        nb = get_sd(args.new_backbone, return_sd=False)
+        target = gen_backbone(backbone, nb)
+    #os.makedirs(os.path.basename(args.out), exist_ok=True)
     torch.save(target, args.out)
     print("saved checkpoint in {}".format(args.out))