diff --git a/.gitignore b/.gitignore
index 77ca0d7c808c77d27777041e64cd8a01054433fc..a5d7e1f95306a5f04087418485f85cbb7b0531b4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -114,6 +114,7 @@ data
 *.pkl.json
 *.log.json
 work_dirs/
+results/
 
 # Pytorch
 *.pth
diff --git a/configs/_base_/datasets/voc0712_analyze.py b/configs/_base_/datasets/voc0712_analyze.py
new file mode 100644
index 0000000000000000000000000000000000000000..873ecf5857ff7d038d76cf1f70c73f884eb27deb
--- /dev/null
+++ b/configs/_base_/datasets/voc0712_analyze.py
@@ -0,0 +1,53 @@
+# dataset settings
+dataset_type = 'VOCDataset'
+#data_root = 'data/VOCdevkit/'
+data_root = '/opt/Dataset/VOCdevkit/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1000, 600),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=[
+            data_root + 'VOC2007/ImageSets/Main/trainval.txt',
+            data_root + 'VOC2012/ImageSets/Main/trainval.txt'
+        ],
+        img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
+        img_prefix=data_root + 'VOC2007/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
+        img_prefix=data_root + 'VOC2007/',
+        pipeline=test_pipeline))
+evaluation = dict(interval=1, metric='mAP')
diff --git a/configs/distillers/feature_mimicking/fm_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_voc0712.py b/configs/distillers/feature_mimicking/fm_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_voc0712.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbab6bf72302f596287442db686754d04c99bb21
--- /dev/null
+++ b/configs/distillers/feature_mimicking/fm_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_voc0712.py
@@ -0,0 +1,44 @@
+_base_ = [
+    '../../_base_/models/faster_rcnn_r50_fpn.py',
+    '../../_base_/datasets/voc0712.py',
+    '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
+find_unused_parameters=True
+weight=1
+distiller = dict(
+    type='FeatureMimicking',
+    teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/faster_rcnn/faster_rcnn_r152_fpn_1x_voc0712.pth',
+    init_student = 'neck_head',
+    distill_cfg = [ dict(student_module = 'roi_head.bbox_head.shared_fcs.0',
+                         teacher_module = 'roi_head.bbox_head.shared_fcs.0',
+                         output_hook = True,
+                         methods=[dict(type='MSELoss',
+                                       name='fm_loss',
+                                       student_channels = 1024,
+                                       teacher_channels = 1024,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                   ]
+    )
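+# Note: 'roi_head.bbox_head.shared_fcs.0' above is the first shared FC of the
+# bbox head (the Shared2FCBBoxHead from the base config), whose 1024-d output
+# is mimicked with a plain MSE loss. Teacher and student share the same head
+# width, so no channel adaptation layer is assumed or needed here.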
+
+student_cfg = 'configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py'
+teacher_cfg = 'configs/pascal_voc/faster_rcnn_r152_fpn_1x_voc0712.py'
+
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+# actual epoch = 3 * 3 = 9
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=200,
+    warmup_ratio=0.001,
+    step=[3])
+# runtime settings
+runner = dict(
+    type='EpochBasedRunner', max_epochs=4)  # actual epoch = 4 * 3 = 12
\ No newline at end of file
diff --git a/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_voc0712.py b/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_voc0712.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a798f1050f2ab3cd61e598005fc073ac59c1f01
--- /dev/null
+++ b/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_voc0712.py
@@ -0,0 +1,70 @@
+_base_ = [
+    '../../_base_/models/faster_rcnn_r50_fpn.py',
+    '../../_base_/datasets/voc0712.py',
+    '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
+find_unused_parameters=True
+weight=1
+distiller = dict(
+    type='BackboneDistiller',
+    teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/faster_rcnn/faster_rcnn_r152_fpn_1x_voc0712.pth',
+    init_student = 'neck_head',
+    train_head = False,
+    distill_cfg = [ dict(feature_level = 0,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_0',
+                                       student_channels = 256,
+                                       teacher_channels = 256,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 1,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_1',
+                                       student_channels = 512,
+                                       teacher_channels = 512,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 2,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_2',
+                                       student_channels = 1024,
+                                       teacher_channels = 1024,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 3,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_3',
+                                       student_channels = 2048,
+                                       teacher_channels = 2048,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                   ]
+    )
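+# feature_level 0-3 index the four backbone stage outputs (ResNet C2-C5).
+# ResNet-50 and ResNet-152 share the same stage widths (256/512/1024/2048),
+# so the MSE losses above need no channel adaptation.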
+
+student_cfg = 'configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py'
+teacher_cfg = 'configs/pascal_voc/faster_rcnn_r152_fpn_1x_voc0712.py'
+
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+# actual epoch = 3 * 3 = 9
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=200,
+    warmup_ratio=0.001,
+    step=[3])
+# runtime settings
+runner = dict(
+    type='EpochBasedRunner', max_epochs=4)  # actual epoch = 4 * 3 = 12
\ No newline at end of file
diff --git a/configs/pascal_voc/cascade_rcnn_r152_fpn_1x_voc0712.py b/configs/pascal_voc/cascade_rcnn_r152_fpn_1x_voc0712.py
new file mode 100644
index 0000000000000000000000000000000000000000..0576486bb66bbc9cfd61c8e5d46f0b78a19bbca1
--- /dev/null
+++ b/configs/pascal_voc/cascade_rcnn_r152_fpn_1x_voc0712.py
@@ -0,0 +1,7 @@
+_base_ = './cascade_rcnn_r50_fpn_1x_voc0712.py'
+model = dict(
+    backbone=dict(
+        depth=152,
+        init_cfg=dict(type='Pretrained',
+                      checkpoint='torchvision://resnet152')))
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
\ No newline at end of file
diff --git a/configs/pascal_voc/cascade_rcnn_r50_fpn_1x_voc0712.py b/configs/pascal_voc/cascade_rcnn_r50_fpn_1x_voc0712.py
index c516ea7a3e6b812b1d204dd90c01c5517ddbb6b4..7a582d484aac6a113f74e36a1180a705979389c0 100644
--- a/configs/pascal_voc/cascade_rcnn_r50_fpn_1x_voc0712.py
+++ b/configs/pascal_voc/cascade_rcnn_r50_fpn_1x_voc0712.py
@@ -58,11 +58,16 @@ model = dict(roi_head=dict(bbox_head=[
         ]))
 
 # optimizer
-optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(grad_clip=None)
 # learning policy
 # actual epoch = 3 * 3 = 9
-lr_config = dict(policy='step', step=[3])
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=200,
+    warmup_ratio=0.001,
+    step=[3])
 # runtime settings
 runner = dict(
     type='EpochBasedRunner', max_epochs=4)  # actual epoch = 4 * 3 = 12
diff --git a/configs/pascal_voc/faster_rcnn_r152_fpn_1x_voc0712.py b/configs/pascal_voc/faster_rcnn_r152_fpn_1x_voc0712.py
index d50b6594a7495615a98566adbe7097929bbe01da..2d6a6016c173c88972fcf4d4ff286f5d9f030cc0 100644
--- a/configs/pascal_voc/faster_rcnn_r152_fpn_1x_voc0712.py
+++ b/configs/pascal_voc/faster_rcnn_r152_fpn_1x_voc0712.py
@@ -6,4 +6,10 @@ model = dict(
                       checkpoint='torchvision://resnet152')))
 
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=200,
+    warmup_ratio=0.001,
+    step=[3])
 
diff --git a/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py b/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
index 7866acebea689e7a863a836c326b1407de733fe8..7ab15a6469056f8cc909f193b71e4772636802ce 100644
--- a/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
+++ b/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
@@ -8,7 +8,12 @@ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(grad_clip=None)
 # learning policy
 # actual epoch = 3 * 3 = 9
-lr_config = dict(policy='step', step=[3])
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=200,
+    warmup_ratio=0.001,
+    step=[3])
 # runtime settings
 runner = dict(
     type='EpochBasedRunner', max_epochs=4)  # actual epoch = 4 * 3 = 12
diff --git a/configs/rpn/rpn_r152_fpn_1x_coco.py b/configs/rpn/rpn_r152_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb1a9e608c467c633c1d1945389f8c68d96d8d47
--- /dev/null
+++ b/configs/rpn/rpn_r152_fpn_1x_coco.py
@@ -0,0 +1,6 @@
+_base_ = './rpn_r50_fpn_1x_coco.py'
+model = dict(
+    backbone=dict(
+        depth=152,
+        init_cfg=dict(type='Pretrained',
+                      checkpoint='torchvision://resnet152')))
diff --git a/mmdet/distillation/distillers/__init__.py b/mmdet/distillation/distillers/__init__.py
index 4dc0dd0dc8f8fef9e029c5aef3dc80f038ad9b0e..ee7a3ecda1cef71d79eb5c1f738c29be898e811e 100644
--- a/mmdet/distillation/distillers/__init__.py
+++ b/mmdet/distillation/distillers/__init__.py
@@ -2,10 +2,12 @@ from .detection_distiller import DetectionDistiller
 from .backbone_distiller import BackboneDistiller
 from .fpn_distiller import FPNDistiller
 from .head_distiller import HeadDistiller
+from .feature_mimicking import FeatureMimicking
 
 __all__ = [
     'DetectionDistiller',
     'BackboneDistiller',
     'FPNDistiller',
-    'HeadDistiller'
+    'HeadDistiller',
+    'FeatureMimicking',
 ]
\ No newline at end of file
diff --git a/mmdet/distillation/distillers/feature_mimicking.py b/mmdet/distillation/distillers/feature_mimicking.py
new file mode 100644
index 0000000000000000000000000000000000000000..32bc9dac388eaf2c993a2a814c53069166b79ba0
--- /dev/null
+++ b/mmdet/distillation/distillers/feature_mimicking.py
@@ -0,0 +1,201 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+from mmdet.models.detectors.base import BaseDetector
+from mmdet.models import build_detector
+from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
+from ..builder import DISTILLER, build_distill_loss
+from collections import OrderedDict
+
+
+
+
+@DISTILLER.register_module()
+class FeatureMimicking(BaseDetector):
+    """Feature mimicking for detectors.
+    It typically consists of teacher_model and student_model.
+    """
+    def __init__(self,
+                 teacher_cfg,
+                 student_cfg,
+                 distill_cfg=None,
+                 teacher_pretrained=None,
+                 init_student=''):
+
+        super(FeatureMimicking, self).__init__()
+        
+        self.teacher = build_detector(teacher_cfg.model,
+                                        train_cfg=teacher_cfg.get('train_cfg'),
+                                        test_cfg=teacher_cfg.get('test_cfg'))
+        self.init_weights_teacher(teacher_pretrained)
+        self.teacher.eval()
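+        # The teacher stays in eval mode; every teacher forward pass in
+        # forward_train() runs under torch.no_grad(), so it is never trained.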
+
+        self.student = build_detector(student_cfg.model,
+                                        train_cfg=student_cfg.get('train_cfg'),
+                                        test_cfg=student_cfg.get('test_cfg'))
+        self.student.init_weights()
+        if init_student:
+            assert init_student in ['neck', 'head', 'neck_head']
+            def check_key(key, init_student):
+                if 'neck' in key and 'neck' in init_student:
+                    return True
+                elif 'head' in key and 'head' in init_student:
+                    return True
+                else:
+                    return False
+
+            t_checkpoint = _load_checkpoint(teacher_pretrained, map_location='cpu')
+            all_name = []
+            for name, v in t_checkpoint["state_dict"].items():
+                if check_key(name, init_student):
+                    all_name.append((name, v))
+
+            state_dict = OrderedDict(all_name)
+            load_state_dict(self.student, state_dict)
+
+        self.distill_losses = nn.ModuleDict()
+        self.distill_cfg = distill_cfg
+        student_modules = dict(self.student.named_modules())
+        teacher_modules = dict(self.teacher.named_modules())
+        def register_hooks(student_module, teacher_module):
+            # Forward hooks stash the hooked module's output in a buffer so
+            # that forward_train() can read it when computing the distill loss.
+            def hook_teacher_forward(module, input, output):
+                self.register_buffer(teacher_module, output)
+
+            def hook_student_forward(module, input, output):
+                self.register_buffer(student_module, output)
+            return hook_teacher_forward, hook_student_forward
+        
+        for item_loc in distill_cfg:
+            student_module = 'student_' + item_loc.student_module.replace('.', '_')
+            teacher_module = 'teacher_' + item_loc.teacher_module.replace('.', '_')
+
+            self.register_buffer(student_module, None)
+            self.register_buffer(teacher_module, None)
+
+            hook_teacher_forward, hook_student_forward = register_hooks(student_module, teacher_module)
+            teacher_modules[item_loc.teacher_module].register_forward_hook(hook_teacher_forward)
+            student_modules[item_loc.student_module].register_forward_hook(hook_student_forward)
+
+            for item_loss in item_loc.methods:
+                loss_name = item_loss.name
+                self.distill_losses[loss_name] = build_distill_loss(item_loss)
+
+
+    def base_parameters(self):
+        return nn.ModuleList([self.student, self.distill_losses])
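+
+    # Presumably this is what the training script hands to the optimizer:
+    # only the student and the distillation losses are trainable; the frozen
+    # teacher is not included.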
+        
+    def discriminator_parameters(self):
+        # Only meaningful for adversarial distillers; FeatureMimicking does
+        # not define a discriminator, so return None instead of raising.
+        return getattr(self, 'discriminator', None)
+
+    @property
+    def with_neck(self):
+        """bool: whether the detector has a neck"""
+        return hasattr(self.student, 'neck') and self.student.neck is not None
+
+    # TODO: these properties need to be carefully handled
+    # for both single stage & two stage detectors
+    @property
+    def with_shared_head(self):
+        """bool: whether the detector has a shared head in the RoI Head"""
+        return hasattr(self.student, 'roi_head') and self.student.roi_head.with_shared_head
+
+    @property
+    def with_bbox(self):
+        """bool: whether the detector has a bbox head"""
+        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
+                or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))
+
+    @property
+    def with_mask(self):
+        """bool: whether the detector has a mask head"""
+        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
+                or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))
+
+    def init_weights_teacher(self, path=None):
+        """Load pretrained weights into the teacher detector.
+
+        Args:
+            path (str, optional): Path to the teacher checkpoint.
+                Defaults to None.
+        """
+        load_checkpoint(self.teacher, path, map_location='cpu')
+
+
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None,
+                      proposals=None,
+                      **kwargs):
+        """
+        Args:
+            img (Tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components (the student's
+                detection losses plus the distillation losses).
+        """
+        with torch.no_grad():
+            self.teacher.eval()
+            f_t = self.teacher.extract_feat(img)
+
+
+        f_s = self.student.extract_feat(img)
+        losses = dict()
+
+        proposal_cfg = self.student.train_cfg.get('rpn_proposal',
+                                            self.student.test_cfg.rpn)
+        s_rpn_losses, s_proposal_list = self.student.rpn_head.forward_train(
+            f_s,
+            img_metas,
+            gt_bboxes,
+            gt_labels=None,
+            gt_bboxes_ignore=gt_bboxes_ignore,
+            proposal_cfg=proposal_cfg)
+        losses.update(s_rpn_losses)
+
+        s_roi_losses = self.student.roi_head.forward_train(f_s, img_metas, s_proposal_list,
+                                                 gt_bboxes, gt_labels,
+                                                 gt_bboxes_ignore, gt_masks,
+                                                 **kwargs)
+        losses.update(s_roi_losses)
+
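+        # Run the teacher RoI head on the student's proposals purely to fire
+        # the registered forward hooks and fill the teacher feature buffers;
+        # the returned losses are not used.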
+        with torch.no_grad():
+            self.teacher.roi_head.forward_train(f_t, img_metas, s_proposal_list,
+                                                gt_bboxes, gt_labels,
+                                                gt_bboxes_ignore, gt_masks,
+                                                **kwargs)
+
+        buffer_dict = dict(self.named_buffers())
+        for item_loc in self.distill_cfg:
+            student_module = 'student_' + item_loc.student_module.replace('.', '_')
+            teacher_module = 'teacher_' + item_loc.teacher_module.replace('.', '_')
+
+            student_feat = buffer_dict[student_module]
+            teacher_feat = buffer_dict[teacher_module]
+
+            for item_loss in item_loc.methods:
+                loss_name = item_loss.name
+                losses[loss_name] = self.distill_losses[loss_name](student_feat, teacher_feat)
+        return losses
+
+    
+    def simple_test(self, img, img_metas, **kwargs):
+        return self.student.simple_test(img, img_metas, **kwargs)
+
+    def aug_test(self, imgs, img_metas, **kwargs):
+        return self.student.aug_test(imgs, img_metas, **kwargs)
+
+    def extract_feat(self, imgs):
+        """Extract features from images."""
+        return self.student.extract_feat(imgs)
\ No newline at end of file