diff --git a/configs/distillers/mimic_head/mb_mask_rcnn_swinT_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py b/configs/distillers/mimic_head/mb_mask_rcnn_swinT_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e897355e4c43c803911489ec8a0b2204d756369
--- /dev/null
+++ b/configs/distillers/mimic_head/mb_mask_rcnn_swinT_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py
@@ -0,0 +1,77 @@
+# Head distillation on COCO: a Mask R-CNN Swin-T FPN student on the 1x
+# schedule, paired with a teacher described by the 3x config below.
+_base_ = [
+    '../../_base_/models/mask_rcnn_swin_fpn.py',
+    '../../_base_/datasets/coco_instance.py',
+    '../../_base_/schedules/schedule_1x.py',
+    '../../_base_/default_runtime.py',
+]
+
+# model settings
+find_unused_parameters = True
+# Shared weight of every feature-level loss declared below.
+weight = 1
+distiller = dict(
+    type='HeadDistiller',
+    # NOTE(review): machine-specific absolute path; adjust per environment.
+    teacher_pretrained='/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_tiny_patch4_window7.pth',
+    init_student=True,
+    # One MSE entry per feature level; student and teacher channel counts
+    # are identical at each level.
+    distill_cfg=[
+        dict(
+            feature_level=level,
+            methods=[
+                dict(
+                    type='MSELoss',
+                    name=loss_name,
+                    student_channels=channels,
+                    teacher_channels=channels,
+                    weight=weight,
+                ),
+            ],
+        )
+        for level, (loss_name, channels) in enumerate((
+            ('loss_mb_0', 96),
+            ('loss_mb_1', 192),
+            ('loss_mb_2', 384),
+            ('loss_mb_3', 768),
+        ))
+    ],
+)
+
+# Paths of the student and teacher detector configs.
+student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py'
+teacher_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
+
+data = dict(
+    samples_per_gpu=2,
+    workers_per_gpu=2,
+)
+
+# AdamW replaces the inherited optimizer outright (``_delete_=True``); the
+# decay multiplier is zero for parameters matching the custom keys.
+optimizer = dict(
+    _delete_=True,
+    type='AdamW',
+    lr=0.0001,
+    betas=(0.9, 0.999),
+    weight_decay=0.05,
+    paramwise_cfg=dict(
+        custom_keys={
+            'absolute_pos_embed': dict(decay_mult=0.),
+            'relative_position_bias_table': dict(decay_mult=0.),
+            'norm': dict(decay_mult=0.),
+        },
+    ),
+)
+# Step milestones for the inherited LR policy; 12 epochs in total.
+lr_config = dict(step=[8, 11])
+runner = dict(type='EpochBasedRunner', max_epochs=12)
+
+# The Amp runner / fp16 optimizer-hook variant is kept for reference only:
+#   runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+#   fp16 = None  (do not use mmdet version fp16)
+#   optimizer_config = dict(type="DistOptimizerHook", update_interval=1,
+#                           grad_clip=None, coalesce=True,
+#                           bucket_size_mb=-1, use_fp16=True)
\ No newline at end of file
diff --git a/mmdet/distillation/distillers/__init__.py b/mmdet/distillation/distillers/__init__.py
index a291ac1fb2cd5ff39b3aa34b6702024a2806e912..4dc0dd0dc8f8fef9e029c5aef3dc80f038ad9b0e 100644
--- a/mmdet/distillation/distillers/__init__.py
+++ b/mmdet/distillation/distillers/__init__.py
@@ -1,9 +1,11 @@
 from .detection_distiller import DetectionDistiller
 from .backbone_distiller import BackboneDistiller
 from .fpn_distiller import FPNDistiller
+from .head_distiller import HeadDistiller
 
 __all__ = [
     'DetectionDistiller',
     'BackboneDistiller',
-    'FPNDistiller'
+    'FPNDistiller',
+    'HeadDistiller'  # name used as ``type`` in the mimic_head distiller config
 ]
\ No newline at end of file
diff --git a/mmdet/distillation/distillers/head_distiller.py b/mmdet/distillation/distillers/head_distiller.py
new file mode 100644
index 0000000000000000000000000000000000000000..3da811209b625abfd10a1418a06cf329d8a0616e
--- /dev/null
+++ b/mmdet/distillation/distillers/head_distiller.py
@@ -0,0 +1,181 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+from mmdet.models.detectors.base import BaseDetector
+from mmdet.models import build_detector
+from mmcv.runner import  load_checkpoint, _load_checkpoint, load_state_dict
+from ..builder import DISTILLER,build_distill_loss
+from collections import OrderedDict
+
+
+
+
+@DISTILLER.register_module()
+class HeadDistiller(BaseDetector):
+    """Teacher/student distiller that swaps RPN proposals between RoI heads.
+
+    The student is trained with its own RPN and RoI losses.  In addition its
+    RoI head is run on the teacher's proposals and the teacher's RoI head is
+    run on the student's proposals.  Testing is delegated to the student.
+
+    Args:
+        teacher_cfg: Config whose ``model``, ``train_cfg`` and ``test_cfg``
+            are used to build the teacher detector.
+        student_cfg: Config read the same way to build the student detector.
+        distill_cfg (list, optional): Entries whose ``methods`` are built with
+            ``build_distill_loss`` and stored under each method's ``name``.
+            None is treated as an empty list.
+        teacher_pretrained (str, optional): Path of the teacher checkpoint.
+        init_student (bool): If True, load the checkpoint's weights whose
+            names do not start with ``backbone.`` into the student.
+    """
+
+    def __init__(self, teacher_cfg, student_cfg, distill_cfg=None,
+                 teacher_pretrained=None, init_student=False):
+        super(HeadDistiller, self).__init__()
+
+        self.teacher = build_detector(
+            teacher_cfg.model, train_cfg=teacher_cfg.get('train_cfg'),
+            test_cfg=teacher_cfg.get('test_cfg'))
+        self.init_weights_teacher(teacher_pretrained)
+        self.teacher.eval()
+
+        self.student = build_detector(
+            student_cfg.model, train_cfg=student_cfg.get('train_cfg'),
+            test_cfg=student_cfg.get('test_cfg'))
+        if init_student:
+            self._init_student_from_teacher(teacher_pretrained)
+
+        self.distill_losses = nn.ModuleDict()
+        self.distill_cfg = distill_cfg
+        # The default distill_cfg=None must not break the loop below.
+        for item_loc in distill_cfg or []:
+            for item_loss in item_loc.methods:
+                self.distill_losses[item_loss.name] = build_distill_loss(item_loss)
+
+    def _init_student_from_teacher(self, path):
+        """Load the checkpoint's non-backbone weights into the student."""
+        if path is None:
+            raise ValueError(
+                'init_student=True requires teacher_pretrained to be set')
+        # Load on CPU, consistent with init_weights_teacher.
+        t_checkpoint = _load_checkpoint(path, map_location='cpu')
+        # Accept both wrapped ({"state_dict": ...}) and bare weight dicts.
+        weights = t_checkpoint.get('state_dict', t_checkpoint)
+        state_dict = OrderedDict(
+            (name, v) for name, v in weights.items()
+            if not name.startswith('backbone.'))
+        load_state_dict(self.student, state_dict)
+
+    def base_parameters(self):
+        """Return the student and the distill losses as one ModuleList."""
+        return nn.ModuleList([self.student, self.distill_losses])
+
+    def discriminator_parameters(self):
+        """Return ``self.discriminator`` if it exists, else None."""
+        # Nothing in this class sets ``discriminator``; avoid AttributeError.
+        return getattr(self, 'discriminator', None)
+
+    @property
+    def with_neck(self):
+        """bool: whether the student has a neck"""
+        return hasattr(self.student, 'neck') and self.student.neck is not None
+
+    # TODO: these properties need to be carefully handled
+    # for both single stage & two stage detectors
+    @property
+    def with_shared_head(self):
+        """bool: whether the student has a shared head in the RoI Head"""
+        return hasattr(self.student, 'roi_head') and self.student.roi_head.with_shared_head
+
+    @property
+    def with_bbox(self):
+        """bool: whether the student has a bbox head"""
+        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
+                or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))
+
+    @property
+    def with_mask(self):
+        """bool: whether the student has a mask head"""
+        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
+                or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))
+
+    def init_weights_teacher(self, path=None):
+        """Load the pretrained model in teacher detector.
+
+        Args:
+            path (str, optional): Path to pre-trained weights. When None,
+                nothing is loaded. Defaults to None.
+        """
+        if path is not None:
+            load_checkpoint(self.teacher, path, map_location='cpu')
+
+    @staticmethod
+    def _rpn_forward_train(detector, feats, img_metas, gt_bboxes,
+                           gt_bboxes_ignore):
+        """Run ``detector.rpn_head.forward_train`` and return its result."""
+        proposal_cfg = detector.train_cfg.get('rpn_proposal',
+                                              detector.test_cfg.rpn)
+        return detector.rpn_head.forward_train(
+            feats, img_metas, gt_bboxes, gt_labels=None,
+            gt_bboxes_ignore=gt_bboxes_ignore, proposal_cfg=proposal_cfg)
+
+    def forward_train(self, img, img_metas, gt_bboxes, gt_labels,
+                      gt_bboxes_ignore=None, gt_masks=None, proposals=None,
+                      **kwargs):
+        """Compute the student losses and the cross-proposal RoI losses.
+
+        Args:
+            img (Tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+            proposals: Accepted but not used by this method.
+
+        Returns:
+            dict[str, Tensor]: The student's RPN and RoI losses, plus the
+            student-on-teacher-proposal losses prefixed ``s_w_tp_`` and the
+            teacher-on-student-proposal losses prefixed ``t_w_sp_``.
+        """
+        with torch.no_grad():
+            self.teacher.eval()
+            f_t = self.teacher.extract_feat(img)
+            _, t_proposal_list = self._rpn_forward_train(
+                self.teacher, f_t, img_metas, gt_bboxes, gt_bboxes_ignore)
+
+        f_s = self.student.extract_feat(img)
+        losses = dict()
+        s_rpn_losses, s_proposal_list = self._rpn_forward_train(
+            self.student, f_s, img_metas, gt_bboxes, gt_bboxes_ignore)
+        losses.update(s_rpn_losses)
+
+        roi_targets = (gt_bboxes, gt_labels, gt_bboxes_ignore, gt_masks)
+        losses.update(self.student.roi_head.forward_train(
+            f_s, img_metas, s_proposal_list, *roi_targets, **kwargs))
+
+        s_roi_with_pt_losses = self.student.roi_head.forward_train(
+            f_s, img_metas, t_proposal_list, *roi_targets, **kwargs)
+        for key, value in s_roi_with_pt_losses.items():
+            losses["s_w_tp_{}".format(key)] = value
+
+        # NOTE(review): runs with grad enabled, unlike the teacher RPN above.
+        t_roi_with_ps_losses = self.teacher.roi_head.forward_train(
+            f_t, img_metas, s_proposal_list, *roi_targets, **kwargs)
+        for key, value in t_roi_with_ps_losses.items():
+            losses["t_w_sp_{}".format(key)] = value
+        return losses
+
+    def simple_test(self, img, img_metas, **kwargs):
+        """Delegate ``simple_test`` to the student."""
+        return self.student.simple_test(img, img_metas, **kwargs)
+
+    def aug_test(self, imgs, img_metas, **kwargs):
+        """Delegate ``aug_test`` to the student."""
+        return self.student.aug_test(imgs, img_metas, **kwargs)
+
+    def extract_feat(self, imgs):
+        """Extract features from images."""
+        return self.student.extract_feat(imgs)
\ No newline at end of file