diff --git a/configs/distillers/mimic_head/mb_mask_rcnn_swinT_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py b/configs/distillers/mimic_head/mb_mask_rcnn_swinT_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..2e897355e4c43c803911489ec8a0b2204d756369 --- /dev/null +++ b/configs/distillers/mimic_head/mb_mask_rcnn_swinT_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py @@ -0,0 +1,77 @@ +_base_ = [ + '../../_base_/models/mask_rcnn_swin_fpn.py', + '../../_base_/datasets/coco_instance.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +find_unused_parameters=True +weight=1 +distiller = dict( + type='HeadDistiller', + teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_tiny_patch4_window7.pth', + init_student = True, + distill_cfg = [ dict(feature_level = 0, + methods=[dict(type='MSELoss', + name='loss_mb_0', + student_channels = 96, + teacher_channels = 96, + weight = weight, + ) + ] + ), + dict(feature_level = 1, + methods=[dict(type='MSELoss', + name='loss_mb_1', + student_channels = 192, + teacher_channels = 192, + weight = weight, + ) + ] + ), + dict(feature_level = 2, + methods=[dict(type='MSELoss', + name='loss_mb_2', + student_channels = 384, + teacher_channels = 384, + weight = weight, + ) + ] + ), + dict(feature_level = 3, + methods=[dict(type='MSELoss', + name='loss_mb_3', + student_channels = 768, + teacher_channels = 768, + weight = weight, + ) + ] + ), + ] + ) + +student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py' +teacher_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py' + +data = dict( + samples_per_gpu=2, + workers_per_gpu=2,) +#data = dict(train=dict(pipeline=train_pipeline)) + +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05, + paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.)})) +lr_config = dict(step=[8, 11]) +#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) +runner = dict(type='EpochBasedRunner', max_epochs=12) +# do not use mmdet version fp16 +# fp16 = None +# optimizer_config = dict( +# type="DistOptimizerHook", +# update_interval=1, +# grad_clip=None, +# coalesce=True, +# bucket_size_mb=-1, +# use_fp16=True, +# ) \ No newline at end of file diff --git a/mmdet/distillation/distillers/__init__.py b/mmdet/distillation/distillers/__init__.py index a291ac1fb2cd5ff39b3aa34b6702024a2806e912..4dc0dd0dc8f8fef9e029c5aef3dc80f038ad9b0e 100644 --- a/mmdet/distillation/distillers/__init__.py +++ b/mmdet/distillation/distillers/__init__.py @@ -1,9 +1,11 @@ from .detection_distiller import DetectionDistiller from .backbone_distiller import BackboneDistiller from .fpn_distiller import FPNDistiller +from .head_distiller import HeadDistiller __all__ = [ 'DetectionDistiller', 'BackboneDistiller', - 'FPNDistiller' + 'FPNDistiller', + 'HeadDistiller' ] \ No newline at end of file diff --git a/mmdet/distillation/distillers/head_distiller.py b/mmdet/distillation/distillers/head_distiller.py new file mode 100644 index 0000000000000000000000000000000000000000..3da811209b625abfd10a1418a06cf329d8a0616e --- /dev/null +++ b/mmdet/distillation/distillers/head_distiller.py @@ -0,0 +1,181 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch +from mmdet.models.detectors.base import BaseDetector +from mmdet.models import build_detector +from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict +from ..builder import DISTILLER,build_distill_loss +from collections import OrderedDict + + + + +@DISTILLER.register_module() +class HeadDistiller(BaseDetector): + """Head distiller for detectors. + It typically consists of teacher_model and student_model. + """ + def __init__(self, + teacher_cfg, + student_cfg, + distill_cfg=None, + teacher_pretrained=None, + init_student=False): + + super(HeadDistiller, self).__init__() + + self.teacher = build_detector(teacher_cfg.model, + train_cfg=teacher_cfg.get('train_cfg'), + test_cfg=teacher_cfg.get('test_cfg')) + self.init_weights_teacher(teacher_pretrained) + + + self.teacher.eval() + self.student= build_detector(student_cfg.model, + train_cfg=student_cfg.get('train_cfg'), + test_cfg=student_cfg.get('test_cfg')) + if init_student: + t_checkpoint = _load_checkpoint(teacher_pretrained) + all_name = [] + for name, v in t_checkpoint["state_dict"].items(): + if name.startswith("backbone."): + continue + else: + all_name.append((name, v)) + + state_dict = OrderedDict(all_name) + load_state_dict(self.student, state_dict) + + self.distill_losses = nn.ModuleDict() + self.distill_cfg = distill_cfg + for item_loc in distill_cfg: + for item_loss in item_loc.methods: + loss_name = item_loss.name + self.distill_losses[loss_name] = build_distill_loss(item_loss) + + def base_parameters(self): + return nn.ModuleList([self.student, self.distill_losses]) + + def discriminator_parameters(self): + return self.discriminator + + @property + def with_neck(self): + """bool: whether the detector has a neck""" + return hasattr(self.student, 'neck') and self.student.neck is not None + + # TODO: these properties need to be carefully handled + # for both single stage & two stage detectors + @property + def with_shared_head(self): + """bool: whether the detector has a shared head in the RoI Head""" + return hasattr(self.student, 'roi_head') and self.student.roi_head.with_shared_head + + @property + def with_bbox(self): + """bool: whether the detector has a bbox head""" + return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox) + or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None)) + + @property + def with_mask(self): + """bool: whether the detector has a mask head""" + return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask) + or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None)) + + def init_weights_teacher(self, path=None): + """Load the pretrained model in teacher detector. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + checkpoint = load_checkpoint(self.teacher, path, map_location='cpu') + + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + gt_masks=None, + proposals=None, + **kwargs): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + Returns: + dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses). + """ + with torch.no_grad(): + self.teacher.eval() + f_t = self.teacher.extract_feat(img) + + proposal_cfg = self.teacher.train_cfg.get('rpn_proposal', + self.teacher.test_cfg.rpn) + t_rpn_losses, t_proposal_list = self.teacher.rpn_head.forward_train( + f_t, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=gt_bboxes_ignore, + proposal_cfg=proposal_cfg) + + f_s = self.student.extract_feat(img) + losses = dict() + + proposal_cfg = self.student.train_cfg.get('rpn_proposal', + self.student.test_cfg.rpn) + s_rpn_losses, s_proposal_list = self.student.rpn_head.forward_train( + f_s, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=gt_bboxes_ignore, + proposal_cfg=proposal_cfg) + losses.update(s_rpn_losses) + + s_roi_losses = self.student.roi_head.forward_train(f_s, img_metas, s_proposal_list, + gt_bboxes, gt_labels, + gt_bboxes_ignore, gt_masks, + **kwargs) + losses.update(s_roi_losses) + + s_roi_with_pt_losses = self.student.roi_head.forward_train(f_s, img_metas, t_proposal_list, + gt_bboxes, gt_labels, + gt_bboxes_ignore, gt_masks, + **kwargs) + for key in s_roi_with_pt_losses: + losses["s_w_tp_{}".format(key)] = s_roi_with_pt_losses[key] + # losses['s_roi_w_pt_cls'] = s_roi_with_pt_losses['loss_cls'] + # losses['s_roi_w_pt_bbox'] = s_roi_with_pt_losses['loss_bbox'] + # losses['s_roi_w_pt_mask'] = s_roi_with_pt_losses['loss_mask'] + + + t_roi_with_ps_losses = self.teacher.roi_head.forward_train(f_t, img_metas, s_proposal_list, + gt_bboxes, gt_labels, + gt_bboxes_ignore, gt_masks, + **kwargs) + for key in t_roi_with_ps_losses: + losses["t_w_sp_{}".format(key)] = t_roi_with_ps_losses[key] + # losses['t_roi_w_ps_cls'] = t_roi_with_ps_losses['loss_cls'] + # losses['t_roi_w_ps_bbox'] = t_roi_with_ps_losses['loss_bbox'] + # losses['t_roi_w_ps_mask'] = t_roi_with_ps_losses['loss_mask'] + return losses + + + def simple_test(self, img, img_metas, **kwargs): + return self.student.simple_test(img, img_metas, **kwargs) + + def aug_test(self, imgs, img_metas, **kwargs): + return self.student.aug_test(img, img_metas, **kwargs) + + def extract_feat(self, imgs): + """Extract features from images.""" + return self.student.extract_feat(imgs) \ No newline at end of file