diff --git a/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py b/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py new file mode 100644 index 0000000000000000000000000000000000000000..37c66a9e1c0c0fd9be181540c749f6c71c01a6fc --- /dev/null +++ b/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py @@ -0,0 +1,43 @@ +_base_ = ['./ld_r18_gflv1_r101_fpn_coco_1x.py'] +teacher_ckpt = 'http://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_dconv_c3-c5_mstrain_2x_coco/gfl_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20200630_102002-134b07df.pth' # noqa +model = dict( + pretrained='torchvision://resnet101', + teacher_config='configs/gfl/gfl_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py', + teacher_ckpt=teacher_ckpt, + backbone=dict( + type='ResNet', + depth=101, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5)) + +lr_config = dict(step=[16, 22]) +runner = dict(type='EpochBasedRunner', max_epochs=24) +# multi-scale training +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 480), (1333, 800)], + multiscale_mode='range', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +data = dict(train=dict(pipeline=train_pipeline)) diff --git a/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py b/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8ce4a1caf95d7e66e79e14219d3d9a8f74321d --- /dev/null +++ b/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py @@ -0,0 +1,62 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +teacher_ckpt = 'http://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' # noqa +model = dict( + type='KnowledgeDistillationSingleStageDetector', + pretrained='torchvision://resnet18', + teacher_config='configs/gfl/gfl_r101_fpn_mstrain_2x_coco.py', + teacher_ckpt=teacher_ckpt, + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[64, 128, 256, 512], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='LDHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25), + loss_ld=dict( + type='KnowledgeDistillationKLDivLoss', loss_weight=0.25, T=10), + reg_max=16, + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)), + # training and testing settings + train_cfg=dict( + assigner=dict(type='ATSSAssigner', topk=9), + 
allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    test_cfg=dict(
+        nms_pre=1000,
+        min_bbox_size=0,
+        score_thr=0.05,
+        nms=dict(type='nms', iou_threshold=0.6),
+        max_per_img=100))
+
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
diff --git a/configs/ld/ld_r34_gflv1_r101_fpn_coco_1x.py b/configs/ld/ld_r34_gflv1_r101_fpn_coco_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..905651d1f1d7cd956147111bba6d427e59ce1895
--- /dev/null
+++ b/configs/ld/ld_r34_gflv1_r101_fpn_coco_1x.py
@@ -0,0 +1,19 @@
+_base_ = ['./ld_r18_gflv1_r101_fpn_coco_1x.py']
+model = dict(
+    pretrained='torchvision://resnet34',
+    backbone=dict(
+        type='ResNet',
+        depth=34,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[64, 128, 256, 512],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5))
diff --git a/configs/ld/ld_r50_gflv1_r101_fpn_coco_1x.py b/configs/ld/ld_r50_gflv1_r101_fpn_coco_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..923c626363c2f49e8ad15616a09b6cb52260923a
--- /dev/null
+++ b/configs/ld/ld_r50_gflv1_r101_fpn_coco_1x.py
@@ -0,0 +1,19 @@
+_base_ = ['./ld_r18_gflv1_r101_fpn_coco_1x.py']
+model = dict(
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5))
diff --git a/configs/ld/readme.md b/configs/ld/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..ab26e946adee4431ca2b42e0f51c6ebed13e08ec
--- /dev/null
+++ b/configs/ld/readme.md
@@ -0,0 +1,31 @@
+# Localization Distillation for Object Detection
+
+## Introduction
+
+[ALGORITHM]
+
+```latex
+@Article{zheng2021LD,
+  title={Localization Distillation for Object Detection},
+  author={Zhaohui Zheng and Rongguang Ye and Ping Wang and Jun Wang and Dongwei Ren and Wangmeng Zuo},
+  journal={arXiv preprint arXiv:2102.12252},
+  year={2021}
+}
+```
+
+### GFocalV1 with LD
+
+| Teacher   | Student | Training schedule | Mini-batch size | AP (val) | AP50 (val) | AP75 (val) | Config |
+| :-------: | :-----: | :---------------: | :-------------: | :------: | :--------: | :--------: | :----: |
+| --        | R-18    | 1x                | 6               | 35.8     | 53.1       | 38.2       |        |
+| R-101     | R-18    | 1x                | 6               | 36.5     | 52.9       | 39.3       | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py) |
+| --        | R-34    | 1x                | 6               | 38.9     | 56.6       | 42.2       |        |
+| R-101     | R-34    | 1x                | 6               | 39.8     | 56.6       | 43.1       | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r34_gflv1_r101_fpn_coco_1x.py) |
+| --        | R-50    | 1x                | 6               | 40.1     | 58.2       | 43.1       |        |
+| R-101     | R-50    | 1x                | 6               | 41.1     | 58.7       | 44.9       | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r50_gflv1_r101_fpn_coco_1x.py) |
+| --        | R-101   | 2x                | 6               | 44.6     | 62.9       | 48.4       |        |
+| R-101-DCN | R-101   | 2x                | 6               | 45.4     | 63.1       | 49.5       | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_2x.py) |
+
+## Note
+
+- Meaning of the config name: ld_r18 (student model), gflv1 (distillation based on GFocalV1), r101 (teacher model), fpn (neck), coco (dataset), 1x (12-epoch schedule).
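For a quick sanity check, here is a minimal sketch (not part of this patch) of loading one of the configs above and building the LD detector. Note that construction also builds the teacher from `teacher_config` and, because `teacher_ckpt` is set, downloads its checkpoint:

```python
# Sketch: assumes an mmdetection checkout with this patch applied.
from mmcv import Config

from mmdet.models import build_detector

cfg = Config.fromfile('configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py')
# train_cfg/test_cfg live inside cfg.model in these configs,
# so no extra arguments are needed.
detector = build_detector(cfg.model)
```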
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index 4172cc09ebc366a171424444ec888902d50ef597..f004dd95d97df16167f932587b3ce73b05b04a37 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -13,6 +13,7 @@ from .ga_retina_head import GARetinaHead
 from .ga_rpn_head import GARPNHead
 from .gfl_head import GFLHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
+from .ld_head import LDHead
 from .nasfcos_head import NASFCOSHead
 from .paa_head import PAAHead
 from .pisa_retinanet_head import PISARetinaHead
@@ -36,5 +37,5 @@ __all__ = [
     'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
     'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
     'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'TransformerHead',
-    'StageCascadeRPNHead', 'CascadeRPNHead', 'EmbeddingRPNHead'
+    'StageCascadeRPNHead', 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead'
 ]
diff --git a/mmdet/models/dense_heads/ld_head.py b/mmdet/models/dense_heads/ld_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..501e1f7befa086f0b2f818531807411fc383d7bd
--- /dev/null
+++ b/mmdet/models/dense_heads/ld_head.py
@@ -0,0 +1,261 @@
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import (bbox2distance, bbox_overlaps, distance2bbox,
+                        multi_apply, reduce_mean)
+from ..builder import HEADS, build_loss
+from .gfl_head import GFLHead
+
+
+@HEADS.register_module()
+class LDHead(GFLHead):
+    """Localization Distillation (LD) head.
+
+    It utilizes the learned bbox distributions to transfer the localization
+    dark knowledge from teacher to student. Original paper: `Localization
+    Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        loss_ld (dict): Config of the Localization Distillation (LD) loss;
+            T is the distillation temperature.
+    """
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 loss_ld=dict(
+                     type='KnowledgeDistillationKLDivLoss',
+                     loss_weight=0.25,
+                     T=10),
+                 **kwargs):
+
+        super(LDHead, self).__init__(num_classes, in_channels, **kwargs)
+        self.loss_ld = build_loss(loss_ld)
+
+    def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
+                    bbox_targets, stride, soft_targets, num_total_samples):
+        """Compute loss of a single scale level.
+
+        Args:
+            anchors (Tensor): Box reference for each scale level with shape
+                (N, num_total_anchors, 4).
+            cls_score (Tensor): Cls and quality joint scores for each scale
+                level with shape (N, num_classes, H, W).
+            bbox_pred (Tensor): Box distribution logits for each scale
+                level with shape (N, 4*(n+1), H, W), n is the max value of
+                the integral set.
+            labels (Tensor): Labels of each anchor with shape
+                (N, num_total_anchors).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors).
+            bbox_targets (Tensor): BBox regression targets of each anchor
+                with shape (N, num_total_anchors, 4).
+            stride (tuple): Stride in this scale level.
+            soft_targets (Tensor): Teacher's box distribution logits for
+                this scale level with shape (N, 4*(n+1), H, W).
+            num_total_samples (int): Number of positive samples that is
+                reduced over all GPUs.
+
+        Returns:
+            tuple[Tensor]: Loss components (cls, bbox, dfl, ld) and the sum
+                of the weight targets.
+        """
+        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
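+        # Flatten per-image tensors to per-anchor rows so that positive
+        # anchors can be indexed uniformly below.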
+        anchors = anchors.reshape(-1, 4)
+        cls_score = cls_score.permute(0, 2, 3,
+                                      1).reshape(-1, self.cls_out_channels)
+        bbox_pred = bbox_pred.permute(0, 2, 3,
+                                      1).reshape(-1, 4 * (self.reg_max + 1))
+        soft_targets = soft_targets.permute(0, 2, 3,
+                                            1).reshape(-1,
+                                                       4 * (self.reg_max + 1))
+
+        bbox_targets = bbox_targets.reshape(-1, 4)
+        labels = labels.reshape(-1)
+        label_weights = label_weights.reshape(-1)
+
+        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
+        bg_class_ind = self.num_classes
+        pos_inds = ((labels >= 0)
+                    & (labels < bg_class_ind)).nonzero().squeeze(1)
+        score = label_weights.new_zeros(labels.shape)
+
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox_targets[pos_inds]
+            pos_bbox_pred = bbox_pred[pos_inds]
+            pos_anchors = anchors[pos_inds]
+            pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
+
+            weight_targets = cls_score.detach().sigmoid()
+            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
+            pos_bbox_pred_corners = self.integral(pos_bbox_pred)
+            pos_decode_bbox_pred = distance2bbox(pos_anchor_centers,
+                                                 pos_bbox_pred_corners)
+            pos_decode_bbox_targets = pos_bbox_targets / stride[0]
+            score[pos_inds] = bbox_overlaps(
+                pos_decode_bbox_pred.detach(),
+                pos_decode_bbox_targets,
+                is_aligned=True)
+            pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
+            pos_soft_targets = soft_targets[pos_inds]
+            soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)
+
+            target_corners = bbox2distance(pos_anchor_centers,
+                                           pos_decode_bbox_targets,
+                                           self.reg_max).reshape(-1)
+
+            # regression loss
+            loss_bbox = self.loss_bbox(
+                pos_decode_bbox_pred,
+                pos_decode_bbox_targets,
+                weight=weight_targets,
+                avg_factor=1.0)
+
+            # dfl loss
+            loss_dfl = self.loss_dfl(
+                pred_corners,
+                target_corners,
+                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+                avg_factor=4.0)
+
+            # ld loss
+            loss_ld = self.loss_ld(
+                pred_corners,
+                soft_corners,
+                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+                avg_factor=4.0)
+
+        else:
+            loss_ld = bbox_pred.sum() * 0
+            loss_bbox = bbox_pred.sum() * 0
+            loss_dfl = bbox_pred.sum() * 0
+            weight_targets = bbox_pred.new_tensor(0)
+
+        # cls (qfl) loss
+        loss_cls = self.loss_cls(
+            cls_score, (labels, score),
+            weight=label_weights,
+            avg_factor=num_total_samples)
+
+        return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
+
+    def forward_train(self,
+                      x,
+                      out_teacher,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels=None,
+                      gt_bboxes_ignore=None,
+                      proposal_cfg=None,
+                      **kwargs):
+        """
+        Args:
+            x (list[Tensor]): Features from FPN.
+            out_teacher (tuple[Tensor]): Outputs of the teacher head; its
+                box distribution logits are used as soft targets.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+            proposal_cfg (mmcv.Config): Test / postprocessing configuration.
+                If None, test_cfg would be used.
+
+        Returns:
+            tuple[dict, list]: The loss components and proposals of each image.
+
+            - losses (dict[str, Tensor]): A dictionary of loss components.
+            - proposal_list (list[Tensor]): Proposals of each image, only
+              returned when ``proposal_cfg`` is not None.
+ """ + outs = self(x) + soft_target = out_teacher[1] + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, soft_target, img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, soft_target, img_metas) + losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + if proposal_cfg is None: + return losses + else: + proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) + return losses, proposal_list + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + gt_labels, + soft_target, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Cls and quality scores for each scale + level has shape (N, num_classes, H, W). + bbox_preds (list[Tensor]): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.anchor_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels) + if cls_reg_targets is None: + return None + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets + + num_total_samples = reduce_mean( + torch.tensor(num_total_pos, dtype=torch.float, + device=device)).item() + num_total_samples = max(num_total_samples, 1.0) + + losses_cls, losses_bbox, losses_dfl, losses_ld, \ + avg_factor = multi_apply( + self.loss_single, + anchor_list, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + self.anchor_generator.strides, + soft_target, + num_total_samples=num_total_samples) + + avg_factor = sum(avg_factor) + 1e-6 + avg_factor = reduce_mean(avg_factor).item() + losses_bbox = [x / avg_factor for x in losses_bbox] + losses_dfl = [x / avg_factor for x in losses_dfl] + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_dfl=losses_dfl, + loss_ld=losses_ld) diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 5ebf6220398983dc532883b76611c68472720c9c..04011130435cf9fdfadeb821919046b1bddab7d4 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -11,6 +11,7 @@ from .fsaf import FSAF from .gfl import GFL from .grid_rcnn import GridRCNN from .htc import HybridTaskCascade +from .kd_one_stage import KnowledgeDistillationSingleStageDetector from .mask_rcnn import MaskRCNN from .mask_scoring_rcnn import MaskScoringRCNN from .nasfcos import NASFCOS @@ -29,7 +30,8 @@ from .yolact import YOLACT from .yolo import YOLOV3 __all__ = [ - 'ATSS', 'BaseDetector', 'SingleStageDetector', 
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 5ebf6220398983dc532883b76611c68472720c9c..04011130435cf9fdfadeb821919046b1bddab7d4 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -11,6 +11,7 @@ from .fsaf import FSAF
 from .gfl import GFL
 from .grid_rcnn import GridRCNN
 from .htc import HybridTaskCascade
+from .kd_one_stage import KnowledgeDistillationSingleStageDetector
 from .mask_rcnn import MaskRCNN
 from .mask_scoring_rcnn import MaskScoringRCNN
 from .nasfcos import NASFCOS
@@ -29,7 +30,8 @@ from .yolact import YOLACT
 from .yolo import YOLOV3
 
 __all__ = [
-    'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
+    'ATSS', 'BaseDetector', 'SingleStageDetector',
+    'KnowledgeDistillationSingleStageDetector', 'TwoStageDetector', 'RPN',
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
     'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
     'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
diff --git a/mmdet/models/detectors/kd_one_stage.py b/mmdet/models/detectors/kd_one_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..671ec19015c87fefd065b84ae887147f90cc892b
--- /dev/null
+++ b/mmdet/models/detectors/kd_one_stage.py
@@ -0,0 +1,100 @@
+import mmcv
+import torch
+from mmcv.runner import load_checkpoint
+
+from .. import build_detector
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class KnowledgeDistillationSingleStageDetector(SingleStageDetector):
+    r"""Implementation of `Distilling the Knowledge in a Neural Network.
+    <https://arxiv.org/abs/1503.02531>`_.
+
+    Args:
+        teacher_config (str | dict): Config file path
+            or the config object of the teacher model.
+        teacher_ckpt (str, optional): Checkpoint path of the teacher model.
+            If left as None, the model will not load any weights.
+        eval_teacher (bool): Whether to keep the teacher in eval mode during
+            training. Default: True.
+    """
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 teacher_config,
+                 teacher_ckpt=None,
+                 eval_teacher=True,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+                         pretrained)
+        self.eval_teacher = eval_teacher
+        # Build the teacher model
+        if isinstance(teacher_config, str):
+            teacher_config = mmcv.Config.fromfile(teacher_config)
+        self.teacher_model = build_detector(teacher_config['model'])
+        if teacher_ckpt is not None:
+            load_checkpoint(
+                self.teacher_model, teacher_ckpt, map_location='cpu')
+
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None):
+        """
+        Args:
+            img (Tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A list of image info dicts where each
+                dict has: 'img_shape', 'scale_factor', 'flip', and may also
+                contain 'filename', 'ori_shape', 'pad_shape', and
+                'img_norm_cfg'. For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Each item is the ground truth boxes of
+                one image in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+                boxes can be ignored when computing the loss.
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        x = self.extract_feat(img)
+        with torch.no_grad():
+            teacher_x = self.teacher_model.extract_feat(img)
+            out_teacher = self.teacher_model.bbox_head(teacher_x)
+        losses = self.bbox_head.forward_train(x, out_teacher, img_metas,
+                                              gt_bboxes, gt_labels,
+                                              gt_bboxes_ignore)
+        return losses
+
+    def cuda(self, device=None):
+        """Since teacher_model is registered as a plain object, it has to be
+        moved to CUDA explicitly when ``cuda`` is called."""
+        self.teacher_model.cuda(device=device)
+        return super().cuda(device=device)
+
+    def train(self, mode=True):
+        """Set the same train mode for the teacher and student models, except
+        that the teacher stays in eval mode when ``eval_teacher`` is True."""
+        if self.eval_teacher:
+            self.teacher_model.train(False)
+        else:
+            self.teacher_model.train(mode)
+        super().train(mode)
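Before the `__setattr__` override below, a self-contained sketch (hypothetical `Student` class, not mmdet code) of why storing the teacher via `object.__setattr__` keeps its weights invisible to the student's optimizer:

```python
import torch.nn as nn


class Student(nn.Module):

    def __init__(self, teacher):
        super().__init__()
        self.layer = nn.Linear(2, 2)
        # Bypass nn.Module.__setattr__: the teacher becomes a plain
        # attribute, not a registered submodule.
        object.__setattr__(self, 'teacher_model', teacher)


student = Student(teacher=nn.Linear(2, 2))
# Only the student's own weight and bias are returned; the teacher's are not.
assert len(list(student.parameters())) == 2
```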
+    def __setattr__(self, name, value):
+        """Set attribute, i.e. ``self.name = value``.
+
+        This override prevents the teacher model from being registered as a
+        nn.Module. The teacher module is registered as a plain object, so
+        that the teacher parameters will not show up when calling
+        ``self.parameters``, ``self.modules`` or ``self.children``.
+        """
+        if name == 'teacher_model':
+            object.__setattr__(self, name, value)
+        else:
+            super().__setattr__(name, value)
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
index bb887d3735df692aa0c7b3496c18add6b9c52391..297aa228277768eb0ba0e8a377f19704d1feeca8 100644
--- a/mmdet/models/losses/__init__.py
+++ b/mmdet/models/losses/__init__.py
@@ -9,6 +9,7 @@ from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
 from .ghm_loss import GHMC, GHMR
 from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss,
                        bounded_iou_loss, iou_loss)
+from .kd_loss import KnowledgeDistillationKLDivLoss
 from .mse_loss import MSELoss, mse_loss
 from .pisa_loss import carl_loss, isr_p
 from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss
@@ -24,5 +25,5 @@ __all__ = [
     'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
     'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
     'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
-    'VarifocalLoss'
+    'VarifocalLoss', 'KnowledgeDistillationKLDivLoss'
 ]
diff --git a/mmdet/models/losses/kd_loss.py b/mmdet/models/losses/kd_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3abb68d4f7b3eec98b873f69c1105a22eb33913
--- /dev/null
+++ b/mmdet/models/losses/kd_loss.py
@@ -0,0 +1,87 @@
+import mmcv
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def knowledge_distillation_kl_div_loss(pred,
+                                       soft_label,
+                                       T,
+                                       detach_target=True):
+    r"""Loss function for knowledge distillation using KL divergence.
+
+    Args:
+        pred (Tensor): Predicted logits with shape (N, n + 1).
+        soft_label (Tensor): Target logits with shape (N, n + 1).
+        T (int): Temperature for distillation.
+        detach_target (bool): Remove soft_label from automatic
+            differentiation. Default: True.
+
+    Returns:
+        torch.Tensor: Loss tensor with shape (N,).
+    """
+    assert pred.size() == soft_label.size()
+    target = F.softmax(soft_label / T, dim=1)
+    if detach_target:
+        target = target.detach()
+
+    kd_loss = F.kl_div(
+        F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
+            T * T)
+
+    return kd_loss
+
+
+@LOSSES.register_module()
+class KnowledgeDistillationKLDivLoss(nn.Module):
+    """Loss function for knowledge distillation using KL divergence.
+
+    Args:
+        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
+        loss_weight (float): Loss weight of the current loss.
+        T (int): Temperature for distillation.
+    """
+
+    def __init__(self, reduction='mean', loss_weight=1.0, T=10):
+        super(KnowledgeDistillationKLDivLoss, self).__init__()
+        assert T >= 1
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.T = T
+
+    def forward(self,
+                pred,
+                soft_label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        """Forward function.
+
+        Args:
+            pred (Tensor): Predicted logits with shape (N, n + 1).
+            soft_label (Tensor): Target logits with shape (N, n + 1).
+            weight (torch.Tensor, optional): The weight of loss for each
+                prediction. Defaults to None.
+            avg_factor (int, optional): Average factor that is used to
+                average the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Defaults to None.
+        """
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+
+        loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
+            pred,
+            soft_label,
+            weight,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            T=self.T)
+
+        return loss_kd
diff --git a/tests/test_metrics/test_losses.py b/tests/test_metrics/test_losses.py
index 8a85cee43cda5d803f0c91d0736f7bdf78ab2f4b..5370f0eb90a1e9fb60a73d0579193827bc270d3f 100644
--- a/tests/test_metrics/test_losses.py
+++ b/tests/test_metrics/test_losses.py
@@ -78,6 +78,37 @@ def test_varifocal_loss():
         loss_cls(fake_pred, fake_target, fake_weight), torch.tensor(0.0))
 
 
+def test_kd_loss():
+    # test that the temperature should be no less than 1
+    with pytest.raises(AssertionError):
+        loss_cfg = dict(
+            type='KnowledgeDistillationKLDivLoss', loss_weight=1.0, T=0.5)
+        build_loss(loss_cfg)
+
+    # test that pred and target should be of the same size
+    loss_cls_cfg = dict(
+        type='KnowledgeDistillationKLDivLoss', loss_weight=1.0, T=1)
+    loss_cls = build_loss(loss_cls_cfg)
+    with pytest.raises(AssertionError):
+        fake_pred = torch.Tensor([[100, -100]])
+        fake_label = torch.Tensor([1]).long()
+        loss_cls(fake_pred, fake_label)
+
+    # test the calculation
+    loss_cls = build_loss(loss_cls_cfg)
+    fake_pred = torch.Tensor([[100.0, 100.0]])
+    fake_target = torch.Tensor([[1.0, 1.0]])
+    assert torch.allclose(loss_cls(fake_pred, fake_target), torch.tensor(0.0))
+
+    # test the loss with weights
+    loss_cls = build_loss(loss_cls_cfg)
+    fake_pred = torch.Tensor([[100.0, -100.0], [100.0, 100.0]])
+    fake_target = torch.Tensor([[1.0, 0.0], [1.0, 1.0]])
+    fake_weight = torch.Tensor([0.0, 1.0])
+    assert torch.allclose(
+        loss_cls(fake_pred, fake_target, fake_weight), torch.tensor(0.0))
+
+
 def test_accuracy():
     # test for empty pred
    pred = torch.empty(0, 4)
diff --git a/tests/test_models/test_dense_heads/test_ld_head.py b/tests/test_models/test_dense_heads/test_ld_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a7541adc762a9e0bbb367ceb56bc74fb2ae8789
--- /dev/null
+++ b/tests/test_models/test_dense_heads/test_ld_head.py
@@ -0,0 +1,120 @@
+import mmcv
+import torch
+
+from mmdet.models.dense_heads import GFLHead, LDHead
+
+
+def test_ld_head_loss():
+    """Tests LD head loss when truth is empty and non-empty."""
+    s = 256
+    img_metas = [{
+        'img_shape': (s, s, 3),
+        'scale_factor': 1,
+        'pad_shape': (s, s, 3)
+    }]
+    train_cfg = mmcv.Config(
+        dict(
+            assigner=dict(type='ATSSAssigner', topk=9, ignore_iof_thr=0.1),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False))
+
+    self = LDHead(
+        num_classes=4,
+        in_channels=1,
+        train_cfg=train_cfg,
+        loss_ld=dict(type='KnowledgeDistillationKLDivLoss', loss_weight=1.0),
+        loss_cls=dict(
+            type='QualityFocalLoss',
+            use_sigmoid=True,
+            beta=2.0,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]))
+
+    teacher_model = GFLHead(
+        num_classes=4,
+        in_channels=1,
+        train_cfg=train_cfg,
+        loss_cls=dict(
+            type='QualityFocalLoss',
+            use_sigmoid=True,
+            beta=2.0,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
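+            # one scale per octave with ratio 1.0 gives a single square
+            # anchor per location, the same anchor setting GFL uses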
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]))
+
+    feat = [
+        torch.rand(1, 1, s // feat_size, s // feat_size)
+        for feat_size in [4, 8, 16, 32, 64]
+    ]
+    cls_scores, bbox_preds = self.forward(feat)
+    rand_soft_target = teacher_model.forward(feat)[1]
+
+    # Test that empty ground truth encourages the network to predict
+    # background
+    gt_bboxes = [torch.empty((0, 4))]
+    gt_labels = [torch.LongTensor([])]
+    gt_bboxes_ignore = None
+
+    empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                                rand_soft_target, img_metas, gt_bboxes_ignore)
+    # When there is no truth, the cls loss should be nonzero and the ld loss
+    # should be non-negative, but there should be no box loss.
+    empty_cls_loss = sum(empty_gt_losses['loss_cls'])
+    empty_box_loss = sum(empty_gt_losses['loss_bbox'])
+    empty_ld_loss = sum(empty_gt_losses['loss_ld'])
+    assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
+    assert empty_box_loss.item() == 0, (
+        'there should be no box loss when there are no true boxes')
+    assert empty_ld_loss.item() >= 0, 'ld loss should be non-negative'
+
+    # When truth is non-empty, both cls and box loss should be nonzero for
+    # random inputs
+    gt_bboxes = [
+        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
+    ]
+    gt_labels = [torch.LongTensor([2])]
+    one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                              rand_soft_target, img_metas, gt_bboxes_ignore)
+    onegt_cls_loss = sum(one_gt_losses['loss_cls'])
+    onegt_box_loss = sum(one_gt_losses['loss_bbox'])
+
+    assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
+    assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
+
+    gt_bboxes_ignore = gt_bboxes
+
+    # When truth is non-empty but ignored, the cls loss should be nonzero,
+    # but there should be no box loss.
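+    # (with gt_bboxes_ignore equal to the gt boxes, the assigner's
+    # ignore_iof_thr=0.1 discards the overlapping anchors, so no positive
+    # samples remain)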
+ ignore_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + rand_soft_target, img_metas, gt_bboxes_ignore) + ignore_cls_loss = sum(ignore_gt_losses['loss_cls']) + ignore_box_loss = sum(ignore_gt_losses['loss_bbox']) + + assert ignore_cls_loss.item() > 0, 'cls loss should be non-zero' + assert ignore_box_loss.item() == 0, 'gt bbox ignored loss should be zero' + + # When truth is non-empty and not ignored then both cls and box loss should + # be nonzero for random inputs + gt_bboxes_ignore = [torch.randn(1, 4)] + + not_ignore_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, + gt_labels, rand_soft_target, img_metas, + gt_bboxes_ignore) + not_ignore_cls_loss = sum(not_ignore_gt_losses['loss_cls']) + not_ignore_box_loss = sum(not_ignore_gt_losses['loss_bbox']) + + assert not_ignore_cls_loss.item() > 0, 'cls loss should be non-zero' + assert not_ignore_box_loss.item( + ) > 0, 'gt bbox not ignored loss should be non-zero' diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index 4e5589c8076c4c74dce58e82c7dd68ab7067a960..416bf73192e492fcfc3d74e6c3d6cafd00ab386b 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -495,6 +495,61 @@ def test_detr_forward(): batch_results.append(result) +def test_kd_single_stage_forward(): + model = _get_detector_cfg('ld/ld_r18_gflv1_r101_fpn_coco_1x.py') + model['pretrained'] = None + + from mmdet.models import build_detector + detector = build_detector(model) + + input_shape = (1, 3, 100, 100) + mm_inputs = _demo_mm_inputs(input_shape) + + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + + # Test forward train with non-empty truth batch + detector.train() + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward train with an empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[0]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward test + detector.eval() + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + rescale=True, + return_loss=False) + batch_results.append(result) + + def test_inference_detector(): from mmdet.apis import inference_detector from mmdet.models import build_detector