diff --git a/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py b/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..2676c37634957f9d82c75925cd05ea5f2486e55b
--- /dev/null
+++ b/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py
@@ -0,0 +1,55 @@
+_base_ = [
+    '../../_base_/models/faster_rcnn_r50_fpn.py',
+    '../../_base_/datasets/coco_detection.py',
+    '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
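+# find_unused_parameters: parts of the student (e.g. the untrained head) receive no gradient under DDP
+# weight: common MSE loss weight applied at every feature level below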
+find_unused_parameters=True
+weight=1
+distiller = dict(
+    type='BackboneDistiller',
+    teacher_pretrained = '/data/wanggh/project/pytorch/Swin-Transformer-Object-Detection/work_dirs/faster_rcnn_r152_fpn_1x_coco/latest.pth',
+    init_student = 'neck_head',
+    train_head = False,
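+    # init_student='neck_head' copies the teacher's neck and head weights into the
+    # student (see BackboneDistiller.check_key); train_head=False leaves the head untrained.
+    # The four feature levels below are the ResNet stage outputs; r50 and r152 share the
+    # same stage widths (256/512/1024/2048), so student and teacher channels match directly.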
+    distill_cfg = [ dict(feature_level = 0,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_0',
+                                       student_channels = 256,
+                                       teacher_channels = 256,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 1,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_1',
+                                       student_channels = 512,
+                                       teacher_channels = 512,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 2,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_2',
+                                       student_channels = 1024,
+                                       teacher_channels = 1024,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 3,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_3',
+                                       student_channels = 2048,
+                                       teacher_channels = 2048,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                   ]
+    )
+
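+# ResNet-50 student, ResNet-152 teacher (consistent with teacher_pretrained above)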
+student_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
+teacher_cfg = 'configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py'
diff --git a/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py b/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py
index 104c815351dd3e37fd634c65c6b6d85f28afa0f2..b53c2b2389af73b80ba241223508b8be8827f5bb 100644
--- a/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py
+++ b/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py
@@ -10,7 +10,7 @@ weight=1
 distiller = dict(
     type='BackboneDistiller',
     teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth',
-    init_student = True,
+    init_student = '',
     distill_cfg = [ dict(feature_level = 0,
                          methods=[dict(type='MSELoss',
                                        name='loss_mb_0',
diff --git a/configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py b/configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..58269670a8e99e6c79c8ab5e3d13448251ba8890
--- /dev/null
+++ b/configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py
@@ -0,0 +1,6 @@
+_base_ = './faster_rcnn_r50_fpn_1x_coco.py'
+model = dict(
+    backbone=dict(
+        depth=152,
+        init_cfg=dict(type='Pretrained',
+                      checkpoint='torchvision://resnet152')))
diff --git a/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py b/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
index 009bd93d06b3284c7b31f33f82d636f774e86b74..6517e7a1ee1fadb7f3569b47038a18e75fa8a327 100644
--- a/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
+++ b/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
@@ -2,4 +2,4 @@ _base_ = [
     '../_base_/models/faster_rcnn_r50_fpn.py',
     '../_base_/datasets/coco_detection.py',
     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
-]
+]
\ No newline at end of file
diff --git a/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py b/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py
index 86278d8a074ede03726f78c9112a97bb17bb7f57..062f3c1a838bfa7340aa6090cfc16e5bfdc85f9d 100644
--- a/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py
+++ b/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py
@@ -12,6 +12,9 @@ model = dict(
 optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
 #optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
 optimizer_config = dict(grad_clip=None)
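+# per-GPU batch size and dataloader worker count for VOC training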
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4)
 # learning policy
 # actual epoch = 3 * 3 = 9
 lr_config = dict(
diff --git a/mmdet/distillation/distillers/backbone_distiller.py b/mmdet/distillation/distillers/backbone_distiller.py
index c14b1c4eea75226ed24ed91897372fc24cfbf1ed..f6ff660af56f04755235702af1fe2c8569b66f41 100644
--- a/mmdet/distillation/distillers/backbone_distiller.py
+++ b/mmdet/distillation/distillers/backbone_distiller.py
@@ -38,9 +38,9 @@ class BackboneDistiller(BaseDetector):
         if init_student:
             assert init_student in ['neck', 'head', 'neck_head']
             def check_key(key, init_student):
-                if key.startswith('neck.') and 'neck' in init_student:
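+                # plain mmdet checkpoints use keys like 'neck.*', 'rpn_head.*' and 'roi_head.*',
+                # which never start with 'head.', so match by substring rather than by prefix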
+                if 'neck' in key and 'neck' in init_student:
                     return True
-                elif key.startswith('head.') and 'head' in init_student:
+                elif 'head' in key and 'head' in init_student:
                     return True
                 else:
                     return False
diff --git a/submit_work.sh b/submit_work.sh
index 9e35e75f75813a0d128b28013f1dd267cb09db28..4bfde547f981c45429ac1f90d3ae87d185204a71 100644
--- a/submit_work.sh
+++ b/submit_work.sh
@@ -11,5 +11,6 @@ while [[ $num > 0 ]]; do
 done
 sleep 2
 # when $pid finished, run these 
-PORT=29504 CUDA_VISIBLE_DEVICES=0,1,2,3 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH3_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
-#PORT=29502 CUDA_VISIBLE_DEVICES=4,5,6,7 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH2_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
\ No newline at end of file
+#PORT=29504 CUDA_VISIBLE_DEVICES=0,1,2,3 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH3_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
+#PORT=29502 CUDA_VISIBLE_DEVICES=4,5,6,7 tools/dist_train.sh configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py 4
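+# backbone-mimic distillation: ResNet-152 teacher -> ResNet-50 student on 8 GPUs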
+PORT=29502 tools/dist_train.sh configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py 8
\ No newline at end of file
diff --git a/tools/gen_checkpoint.py b/tools/gen_checkpoint.py
index 59f138dc2db645d5ce0a66b92a804872843cc315..b3ba6d6441f9f28c6a16510471a12fd7d201ed09 100644
--- a/tools/gen_checkpoint.py
+++ b/tools/gen_checkpoint.py
@@ -8,9 +8,9 @@ def parse_args():
     parser = argparse.ArgumentParser(
         description='generate model')
     parser.add_argument('--backbone', help='the backbone checkpoint file')
-    parser.add_argument('--backbone-neck', help='the backbone-neck checkpoint file')
+    parser.add_argument('--neck', help='the neck checkpoint file')
     parser.add_argument('--head', help='the head checkpoint file')
-    parser.add_argument('--new-backbone', help='the trained checkpoint file')
+    parser.add_argument('--distill', help='the distilled model checkpoint file')
     parser.add_argument('--out', help='output result file in pickle format')
     args = parser.parse_args()
     return args
@@ -23,87 +23,47 @@ def get_sd(filename, return_sd=True):
     else:
         return ck
 
-def merge(backbone, head):
+def merge(backbone, neck, head):
     target = dict()
     target['state_dict'] = dict()
     tsd = target['state_dict']
     bsd = backbone['state_dict']
+    nsd = neck['state_dict']
     hsd = head['state_dict']
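+    # keep only the backbone / neck / head weights from their respective checkpoints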
     for key in bsd.keys():
         if 'backbone' in key:
             tsd[key] = bsd[key]
-    for key in hsd.keys():
-        if 'backbone' not in key:
-            tsd[key] = hsd[key]
-    return target
-
-def merge_bn_h(backbone, head):
-    target = dict()
-    target['state_dict'] = dict()
-    tsd = target['state_dict']
-    bsd = backbone['state_dict']
-    hsd = head['state_dict']
-
-    for key in bsd.keys():
-        if 'backbone' in key or 'neck' in key:
-            tsd[key] = bsd[key]
-        else:
-            assert 'head' in key
+    for key in nsd.keys():
+        if 'neck' in key:
+            tsd[key] = nsd[key]
     for key in hsd.keys():
         if 'head' in key:
             tsd[key] = hsd[key]
     return target
 
-def gen_backbone(backbone, new_backbone):
-    target = backbone.copy()
-    tsd = target['model']
-    nbsd = new_backbone['state_dict']
-    for key in tsd.keys():
-        nk = 'backbone.{}'.format(key)
-        if nk not in nbsd:
-            print("{} not find".format(key))
-            continue
-        tsd[key] = nbsd[nk]
-    return target
-
-def gen_imagenet_h(backbone, head):
+def gen_student(distill):
     target = dict()
     target['state_dict'] = dict()
     tsd = target['state_dict']
-    bsd = backbone['model']
-    hsd = head['state_dict']
-    for key in hsd.keys():
-        if 'backbone' not in key:
-            tsd[key] = hsd[key]
-        else:
-            bkey = key[9:]
-            if bkey not in bsd:
-                print("{} not load".format(key))
-                continue
-            tsd[key] = bsd[bkey]
+    distill_sd = distill['state_dict']
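+    # the distiller prefixes student weights with 'student.'; strip the 8-char prefix
+    # so the result loads into a standalone detector config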
+    for key in distill_sd.keys():
+        if key.startswith('student.'):
+            tsd[key[8:]] = distill_sd[key]
     return target
 
+
 def main():
     args = parse_args()
     print("generate checkpoint")
 
-    if args.backbone and args.head:
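+    # --distill extracts the student sub-model saved by the distiller; otherwise merge
+    # separate backbone / neck / head checkpoints into a single detector checkpoint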
+    if args.distill:
+        distill = get_sd(args.distill, return_sd=False)
+        target = gen_student(distill)
+    else:
         backbone = get_sd(args.backbone, return_sd=False)
+        neck = get_sd(args.neck, return_sd=False)
         head = get_sd(args.head, return_sd=False)
-        target = merge(backbone, head)
-    elif args.backbone_neck and args.head:
-        backbone = get_sd(args.backbone_neck, return_sd=False)
-        head = get_sd(args.head, return_sd=False)
-        print("backbone+neck:{} head:{}".format(args.backbone_neck, args.head))
-        target = merge_bn_h(backbone, head)
-    elif args.new_backbone and args.head:
-        backbone = get_sd(args.new_backbone, return_sd=False)
-        head = get_sd(args.head, return_sd=False)
-        target = gen_imagenet_h(backbone, head)
-    elif args.new_backbone:
-        backbone = get_sd(args.backbone, return_sd=False)
-        nb = get_sd(args.new_backbone, return_sd=False)
-        target = gen_backbone(backbone, nb)
+        target = merge(backbone, neck, head)
     #os.makedirs(os.path.basename(args.out), exist_ok=True)
     torch.save(target, args.out)
     print("saved checkpoint in {}".format(args.out))