Unverified commit ecb4a07c authored by HikariTJU, committed by GitHub

[Feature]: Add Localization Distillation for Object Detection (#4758)

* add Localization Distillation for Object Detection https://arxiv.org/abs/2102.12252

* fix lint

* fix lint

* fix lint

* fix lint

* fix lint

* fix config

* add kd detector

* edit loss name

* override setattr but failed

* move init_detector into init

* separate ld and gfocal, fix unused param error

* small fix

* small fix

* add test, override magic functions, create kd_loss.py

* del teacher model warnings

* fix reference

* add ignored bbox test, fix docstring

* small fix

* docstring fix

* change names

* fix

* fix

* fix test

* fix import

* fix

* docstring fix

* fix

* retest

* add test

* docstring fix
parent 8db767f7
Showing 835 additions and 3 deletions
configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_1x.py
_base_ = ['./ld_r18_gflv1_r101_fpn_coco_1x.py']
teacher_ckpt = 'http://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_dconv_c3-c5_mstrain_2x_coco/gfl_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20200630_102002-134b07df.pth' # noqa
model = dict(
pretrained='torchvision://resnet101',
teacher_config='configs/gfl/gfl_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py',
teacher_ckpt=teacher_ckpt,
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5))
lr_config = dict(step=[16, 22])
runner = dict(type='EpochBasedRunner', max_epochs=24)
# multi-scale training
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 480), (1333, 800)],
multiscale_mode='range',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(train=dict(pipeline=train_pipeline))
configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
teacher_ckpt = 'http://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' # noqa
model = dict(
type='KnowledgeDistillationSingleStageDetector',
pretrained='torchvision://resnet18',
teacher_config='configs/gfl/gfl_r101_fpn_mstrain_2x_coco.py',
teacher_ckpt=teacher_ckpt,
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[64, 128, 256, 512],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='LDHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
loss_cls=dict(
type='QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
loss_ld=dict(
type='KnowledgeDistillationKLDivLoss', loss_weight=0.25, T=10),
reg_max=16,
loss_bbox=dict(type='GIoULoss', loss_weight=2.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
configs/ld/ld_r34_gflv1_r101_fpn_coco_1x.py
_base_ = ['./ld_r18_gflv1_r101_fpn_coco_1x.py']
model = dict(
pretrained='torchvision://resnet34',
backbone=dict(
type='ResNet',
depth=34,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[64, 128, 256, 512],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5))
configs/ld/ld_r50_gflv1_r101_fpn_coco_1x.py
_base_ = ['./ld_r18_gflv1_r101_fpn_coco_1x.py']
model = dict(
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5))
configs/ld/README.md

# Localization Distillation for Object Detection
## Introduction
[ALGORITHM]
```latex
@Article{zheng2021LD,
  title   = {Localization Distillation for Object Detection},
  author  = {Zhaohui Zheng and Rongguang Ye and Ping Wang and Jun Wang and Dongwei Ren and Wangmeng Zuo},
  journal = {arXiv preprint arXiv:2102.12252},
  year    = {2021}
}
```
### GFocalV1 with LD
| Teacher | Student | Training schedule | Mini-batch size | AP (val) | AP50 (val) | AP75 (val) | Config |
| :-------: | :-----: | :---------------: | :-------------: | :------: | :--------: | :--------: | :--------------: |
| -- | R-18 | 1x | 6 | 35.8 | 53.1 | 38.2 | |
| R-101 | R-18 | 1x | 6 | 36.5 | 52.9 | 39.3 | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py) |
| -- | R-34 | 1x | 6 | 38.9 | 56.6 | 42.2 | |
| R-101 | R-34 | 1x | 6 | 39.8 | 56.6 | 43.1 | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r34_gflv1_r101_fpn_coco_1x.py) |
| -- | R-50 | 1x | 6 | 40.1 | 58.2 | 43.1 | |
| R-101 | R-50 | 1x | 6 | 41.1 | 58.7 | 44.9 | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r50_gflv1_r101_fpn_coco_1x.py) |
| -- | R-101 | 2x | 6 | 44.6 | 62.9 | 48.4 | |
| R-101-DCN | R-101 | 2x | 6 | 45.4 | 63.1 | 49.5 | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/ld/ld_r101_gflv1_r101dcn_fpn_coco_1x.py) |
## Note
- Meaning of the config name: `ld_r18(student)_gflv1(GFocalV1 baseline)_r101(teacher)_fpn(neck)_coco(dataset)_1x(12 epochs).py`
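- The LD term itself is a temperature-scaled KL divergence applied to the discretized bbox (edge) distributions of GFocalV1. Below is a minimal sketch of the loss in isolation, assuming mmdetection with this commit is installed; it mirrors `test_kd_loss` further down.

```python
import torch
from mmdet.models import build_loss

# Build the LD loss with the same settings as the configs above.
loss_ld = build_loss(
    dict(type='KnowledgeDistillationKLDivLoss', loss_weight=0.25, T=10))

# Fake student / teacher box-distribution logits for a handful of bbox
# edges; with reg_max=16 each edge is a distribution over 17 bins.
pred = torch.randn(8, 17)
soft_target = torch.randn(8, 17)

print(loss_ld(pred, soft_target))  # positive scalar
print(loss_ld(pred, pred))         # ~0: identical distributions
```

Training uses the standard mmdetection entry point, e.g. `python tools/train.py configs/ld/ld_r18_gflv1_r101_fpn_coco_1x.py`; `KnowledgeDistillationSingleStageDetector` builds the teacher from `teacher_config`, loads `teacher_ckpt`, and keeps the teacher frozen in eval mode.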
mmdet/models/dense_heads/__init__.py
@@ -13,6 +13,7 @@ from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .gfl_head import GFLHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .ld_head import LDHead
from .nasfcos_head import NASFCOSHead
from .paa_head import PAAHead
from .pisa_retinanet_head import PISARetinaHead
@@ -36,5 +37,5 @@ __all__ = [
    'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
    'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
    'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'TransformerHead',
    'StageCascadeRPNHead', 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead'
]
mmdet/models/dense_heads/ld_head.py
import torch
from mmcv.runner import force_fp32
from mmdet.core import (bbox2distance, bbox_overlaps, distance2bbox,
multi_apply, reduce_mean)
from ..builder import HEADS, build_loss
from .gfl_head import GFLHead
@HEADS.register_module()
class LDHead(GFLHead):
"""Localization distillation Head. (Short description)
It utilizes the learned bbox distributions to transfer the localization
dark knowledge from teacher to student. Original paper: `Localization
Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
loss_ld (dict): Config of Localization Distillation Loss (LD),
T is the temperature for distillation.
"""
def __init__(self,
num_classes,
in_channels,
                 loss_ld=dict(
                     type='KnowledgeDistillationKLDivLoss',
                     loss_weight=0.25,
                     T=10),
**kwargs):
super(LDHead, self).__init__(num_classes, in_channels, **kwargs)
self.loss_ld = build_loss(loss_ld)
def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
bbox_targets, stride, soft_targets, num_total_samples):
"""Compute loss of a single scale level.
Args:
anchors (Tensor): Box reference for each scale level with shape
(N, num_total_anchors, 4).
cls_score (Tensor): Cls and quality joint scores for each scale
level has shape (N, num_classes, H, W).
bbox_pred (Tensor): Box distribution logits for each scale
level with shape (N, 4*(n+1), H, W), n is max value of integral
set.
labels (Tensor): Labels of each anchors with shape
(N, num_total_anchors).
label_weights (Tensor): Label weights of each anchor with shape
(N, num_total_anchors)
bbox_targets (Tensor): BBox regression targets of each anchor wight
shape (N, num_total_anchors, 4).
stride (tuple): Stride in this scale level.
num_total_samples (int): Number of positive samples that is
reduced over all GPUs.
Returns:
dict[tuple, Tensor]: Loss components and weight targets.
"""
assert stride[0] == stride[1], 'h stride is not equal to w stride!'
anchors = anchors.reshape(-1, 4)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1, 4 * (self.reg_max + 1))
soft_targets = soft_targets.permute(0, 2, 3,
1).reshape(-1,
4 * (self.reg_max + 1))
bbox_targets = bbox_targets.reshape(-1, 4)
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().squeeze(1)
score = label_weights.new_zeros(labels.shape)
if len(pos_inds) > 0:
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_anchors = anchors[pos_inds]
pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
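            # Weight each positive sample by the student's predicted
            # classification quality (max sigmoid score), as in GFL.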
weight_targets = cls_score.detach().sigmoid()
weight_targets = weight_targets.max(dim=1)[0][pos_inds]
pos_bbox_pred_corners = self.integral(pos_bbox_pred)
pos_decode_bbox_pred = distance2bbox(pos_anchor_centers,
pos_bbox_pred_corners)
pos_decode_bbox_targets = pos_bbox_targets / stride[0]
score[pos_inds] = bbox_overlaps(
pos_decode_bbox_pred.detach(),
pos_decode_bbox_targets,
is_aligned=True)
pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
pos_soft_targets = soft_targets[pos_inds]
soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)
target_corners = bbox2distance(pos_anchor_centers,
pos_decode_bbox_targets,
self.reg_max).reshape(-1)
            # bbox regression loss (GIoU) on the decoded boxes
loss_bbox = self.loss_bbox(
pos_decode_bbox_pred,
pos_decode_bbox_targets,
weight=weight_targets,
avg_factor=1.0)
            # distribution focal loss (DFL) over the discretized edge offsets
loss_dfl = self.loss_dfl(
pred_corners,
target_corners,
weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
avg_factor=4.0)
            # localization distillation (LD) loss against the teacher's
            # edge distributions
loss_ld = self.loss_ld(
pred_corners,
soft_corners,
weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
avg_factor=4.0)
else:
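            # No positives at this level: emit zero losses that still depend
            # on bbox_pred so that every parameter receives a gradient
            # (avoids unused-parameter errors in distributed training).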
loss_ld = bbox_pred.sum() * 0
loss_bbox = bbox_pred.sum() * 0
loss_dfl = bbox_pred.sum() * 0
weight_targets = bbox_pred.new_tensor(0)
# cls (qfl) loss
loss_cls = self.loss_cls(
cls_score, (labels, score),
weight=label_weights,
avg_factor=num_total_samples)
return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
def forward_train(self,
x,
out_teacher,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=None,
proposal_cfg=None,
**kwargs):
"""
Args:
x (list[Tensor]): Features from FPN.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (Tensor): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts,).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
Returns:
tuple[dict, list]: The loss components and proposals of each image.
- losses (dict[str, Tensor]): A dictionary of loss components.
- proposal_list (list[Tensor]): Proposals of each image.
"""
outs = self(x)
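        # The teacher head returns (cls_scores, bbox_preds); only the bbox
        # distribution logits are used as soft targets for LD.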
soft_target = out_teacher[1]
if gt_labels is None:
loss_inputs = outs + (gt_bboxes, soft_target, img_metas)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, soft_target, img_metas)
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
if proposal_cfg is None:
return losses
else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
return losses, proposal_list
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
soft_target,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
cls_scores (list[Tensor]): Cls and quality scores for each scale
level has shape (N, num_classes, H, W).
bbox_preds (list[Tensor]): Box distribution logits for each scale
level with shape (N, 4*(n+1), H, W), n is max value of integral
set.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor] | None): specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels)
if cls_reg_targets is None:
return None
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = reduce_mean(
torch.tensor(num_total_pos, dtype=torch.float,
device=device)).item()
num_total_samples = max(num_total_samples, 1.0)
losses_cls, losses_bbox, losses_dfl, losses_ld, \
avg_factor = multi_apply(
self.loss_single,
anchor_list,
cls_scores,
bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
self.anchor_generator.strides,
soft_target,
num_total_samples=num_total_samples)
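        # Normalize the bbox and DFL losses by the total quality weight,
        # averaged across GPUs; the 1e-6 guards against division by zero.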
avg_factor = sum(avg_factor) + 1e-6
avg_factor = reduce_mean(avg_factor).item()
losses_bbox = [x / avg_factor for x in losses_bbox]
losses_dfl = [x / avg_factor for x in losses_dfl]
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_dfl=losses_dfl,
loss_ld=losses_ld)
mmdet/models/detectors/__init__.py
@@ -11,6 +11,7 @@ from .fsaf import FSAF
from .gfl import GFL
from .grid_rcnn import GridRCNN
from .htc import HybridTaskCascade
from .kd_one_stage import KnowledgeDistillationSingleStageDetector
from .mask_rcnn import MaskRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .nasfcos import NASFCOS
@@ -29,7 +30,8 @@ from .yolact import YOLACT
from .yolo import YOLOV3
__all__ = [
    'ATSS', 'BaseDetector', 'SingleStageDetector',
    'KnowledgeDistillationSingleStageDetector', 'TwoStageDetector', 'RPN',
    'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
    'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
    'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
...
mmdet/models/detectors/kd_one_stage.py
import mmcv
import torch
from mmcv.runner import load_checkpoint
from .. import build_detector
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class KnowledgeDistillationSingleStageDetector(SingleStageDetector):
r"""Implementation of `Distilling the Knowledge in a Neural Network.
<https://arxiv.org/abs/1503.02531>`_.
Args:
teacher_config (str | dict): Config file path
or the config object of teacher model.
teacher_ckpt (str, optional): Checkpoint path of teacher model.
If left as None, the model will not load any weights.
"""
def __init__(self,
backbone,
neck,
bbox_head,
teacher_config,
teacher_ckpt=None,
eval_teacher=True,
train_cfg=None,
test_cfg=None,
pretrained=None):
super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)
self.eval_teacher = eval_teacher
# Build teacher model
if isinstance(teacher_config, str):
teacher_config = mmcv.Config.fromfile(teacher_config)
self.teacher_model = build_detector(teacher_config['model'])
if teacher_ckpt is not None:
load_checkpoint(
self.teacher_model, teacher_ckpt, map_location='cpu')
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground-truth boxes of each image in
                [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img)
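        # Run the frozen teacher without gradients; its head outputs serve
        # only as soft targets for the student's LD loss.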
with torch.no_grad():
teacher_x = self.teacher_model.extract_feat(img)
out_teacher = self.teacher_model.bbox_head(teacher_x)
losses = self.bbox_head.forward_train(x, out_teacher, img_metas,
gt_bboxes, gt_labels,
gt_bboxes_ignore)
return losses
def cuda(self, device=None):
"""Since teacher_model is registered as a plain object, it is necessary
to put the teacher model to cuda when calling cuda function."""
self.teacher_model.cuda(device=device)
return super().cuda(device=device)
def train(self, mode=True):
"""Set the same train mode for teacher and student model."""
if self.eval_teacher:
self.teacher_model.train(False)
else:
self.teacher_model.train(mode)
super().train(mode)
def __setattr__(self, name, value):
"""Set attribute, i.e. self.name = value
This reloading prevent the teacher model from being registered as a
nn.Module. The teacher module is registered as a plain object, so that
the teacher parameters will not show up when calling
``self.parameters``, ``self.modules``, ``self.children`` methods.
"""
if name == 'teacher_model':
object.__setattr__(self, name, value)
else:
super().__setattr__(name, value)
mmdet/models/losses/__init__.py
@@ -9,6 +9,7 @@ from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
from .ghm_loss import GHMC, GHMR
from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss,
                       bounded_iou_loss, iou_loss)
from .kd_loss import KnowledgeDistillationKLDivLoss
from .mse_loss import MSELoss, mse_loss
from .pisa_loss import carl_loss, isr_p
from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss
@@ -24,5 +25,5 @@ __all__ = [
    'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
    'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
    'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
    'VarifocalLoss', 'KnowledgeDistillationKLDivLoss'
]
mmdet/models/losses/kd_loss.py
import mmcv
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weighted_loss
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def knowledge_distillation_kl_div_loss(pred,
soft_label,
T,
detach_target=True):
r"""Loss function for knowledge distilling using KL divergence.
Args:
pred (Tensor): Predicted logits with shape (N, n + 1).
soft_label (Tensor): Target logits with shape (N, N + 1).
T (int): Temperature for distillation.
detach_target (bool): Remove soft_label from automatic differentiation
Returns:
torch.Tensor: Loss tensor with shape (N,).
"""
assert pred.size() == soft_label.size()
target = F.softmax(soft_label / T, dim=1)
if detach_target:
target = target.detach()
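    # Scale by T^2 so that gradient magnitudes stay roughly constant as the
    # temperature changes (Hinton et al., 2015).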
kd_loss = F.kl_div(
F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
T * T)
return kd_loss
@LOSSES.register_module()
class KnowledgeDistillationKLDivLoss(nn.Module):
"""Loss function for knowledge distilling using KL divergence.
Args:
reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
loss_weight (float): Loss weight of current loss.
T (int): Temperature for distillation.
"""
def __init__(self, reduction='mean', loss_weight=1.0, T=10):
super(KnowledgeDistillationKLDivLoss, self).__init__()
assert T >= 1
self.reduction = reduction
self.loss_weight = loss_weight
self.T = T
def forward(self,
pred,
soft_label,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
pred,
soft_label,
weight,
reduction=reduction,
avg_factor=avg_factor,
T=self.T)
return loss_kd
tests/test_losses.py
@@ -78,6 +78,37 @@ def test_varifocal_loss():
        loss_cls(fake_pred, fake_target, fake_weight), torch.tensor(0.0))
def test_kd_loss():
    # test that the temperature should be no less than 1
with pytest.raises(AssertionError):
loss_cfg = dict(
type='KnowledgeDistillationKLDivLoss', loss_weight=1.0, T=0.5)
build_loss(loss_cfg)
# test that pred and target should be of the same size
loss_cls_cfg = dict(
type='KnowledgeDistillationKLDivLoss', loss_weight=1.0, T=1)
loss_cls = build_loss(loss_cls_cfg)
with pytest.raises(AssertionError):
fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long()
loss_cls(fake_pred, fake_label)
# test the calculation
loss_cls = build_loss(loss_cls_cfg)
fake_pred = torch.Tensor([[100.0, 100.0]])
fake_target = torch.Tensor([[1.0, 1.0]])
assert torch.allclose(loss_cls(fake_pred, fake_target), torch.tensor(0.0))
# test the loss with weights
loss_cls = build_loss(loss_cls_cfg)
fake_pred = torch.Tensor([[100.0, -100.0], [100.0, 100.0]])
fake_target = torch.Tensor([[1.0, 0.0], [1.0, 1.0]])
fake_weight = torch.Tensor([0.0, 1.0])
assert torch.allclose(
loss_cls(fake_pred, fake_target, fake_weight), torch.tensor(0.0))
def test_accuracy():
    # test for empty pred
    pred = torch.empty(0, 4)
...
tests/test_ld_head.py
import mmcv
import torch
from mmdet.models.dense_heads import GFLHead, LDHead
def test_ld_head_loss():
"""Tests vfnet head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
train_cfg = mmcv.Config(
dict(
assigner=dict(type='ATSSAssigner', topk=9, ignore_iof_thr=0.1),
allowed_border=-1,
pos_weight=-1,
debug=False))
self = LDHead(
num_classes=4,
in_channels=1,
train_cfg=train_cfg,
loss_ld=dict(type='KnowledgeDistillationKLDivLoss', loss_weight=1.0),
loss_cls=dict(
type='QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]))
teacher_model = GFLHead(
num_classes=4,
in_channels=1,
train_cfg=train_cfg,
loss_cls=dict(
type='QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]))
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32, 64]
]
cls_scores, bbox_preds = self.forward(feat)
rand_soft_target = teacher_model.forward(feat)[1]
# Test that empty ground truth encourages the network to predict
# background
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
gt_bboxes_ignore = None
empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
rand_soft_target, img_metas, gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero, ld loss should
# be non-negative but there should be no box loss.
empty_cls_loss = sum(empty_gt_losses['loss_cls'])
empty_box_loss = sum(empty_gt_losses['loss_bbox'])
empty_ld_loss = sum(empty_gt_losses['loss_ld'])
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_box_loss.item() == 0, (
'there should be no box loss when there are no true boxes')
assert empty_ld_loss.item() >= 0, 'ld loss should be non-negative'
# When truth is non-empty then both cls and box loss should be nonzero
# for random inputs
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
rand_soft_target, img_metas, gt_bboxes_ignore)
onegt_cls_loss = sum(one_gt_losses['loss_cls'])
onegt_box_loss = sum(one_gt_losses['loss_bbox'])
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
gt_bboxes_ignore = gt_bboxes
# When truth is non-empty but ignored then the cls loss should be nonzero,
# but there should be no box loss.
ignore_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
rand_soft_target, img_metas, gt_bboxes_ignore)
ignore_cls_loss = sum(ignore_gt_losses['loss_cls'])
ignore_box_loss = sum(ignore_gt_losses['loss_bbox'])
assert ignore_cls_loss.item() > 0, 'cls loss should be non-zero'
assert ignore_box_loss.item() == 0, 'gt bbox ignored loss should be zero'
# When truth is non-empty and not ignored then both cls and box loss should
# be nonzero for random inputs
gt_bboxes_ignore = [torch.randn(1, 4)]
not_ignore_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes,
gt_labels, rand_soft_target, img_metas,
gt_bboxes_ignore)
not_ignore_cls_loss = sum(not_ignore_gt_losses['loss_cls'])
not_ignore_box_loss = sum(not_ignore_gt_losses['loss_bbox'])
assert not_ignore_cls_loss.item() > 0, 'cls loss should be non-zero'
assert not_ignore_box_loss.item(
) > 0, 'gt bbox not ignored loss should be non-zero'
tests/test_forward.py
@@ -495,6 +495,61 @@ def test_detr_forward():
        batch_results.append(result)
def test_kd_single_stage_forward():
model = _get_detector_cfg('ld/ld_r18_gflv1_r101_fpn_coco_1x.py')
model['pretrained'] = None
from mmdet.models import build_detector
detector = build_detector(model)
input_shape = (1, 3, 100, 100)
mm_inputs = _demo_mm_inputs(input_shape)
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
# Test forward train with non-empty truth batch
detector.train()
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
loss, _ = detector._parse_losses(losses)
assert float(loss.item()) > 0
# Test forward train with an empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
loss, _ = detector._parse_losses(losses)
assert float(loss.item()) > 0
# Test forward test
detector.eval()
with torch.no_grad():
img_list = [g[None, :] for g in imgs]
batch_results = []
for one_img, one_meta in zip(img_list, img_metas):
result = detector.forward([one_img], [[one_meta]],
rescale=True,
return_loss=False)
batch_results.append(result)
def test_inference_detector():
    from mmdet.apis import inference_detector
    from mmdet.models import build_detector
...