from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import _load_checkpoint, load_checkpoint, load_state_dict
from mmdet.models import build_detector
from mmdet.models.detectors.base import BaseDetector

from ..builder import DISTILLER, build_distill_loss

@DISTILLER.register_module()
class FGFI(BaseDetector):
    """FGFI distiller for detectors.

    It typically consists of a teacher_model and a student_model.
    """
    def __init__(self,
                 teacher_cfg,
                 student_cfg,
                 distill_cfg=None,
                 teacher_pretrained=None):
        super(FGFI, self).__init__()
        # Build and freeze the teacher; it only supplies feature targets.
        self.teacher = build_detector(teacher_cfg.model,
                                      train_cfg=teacher_cfg.get('train_cfg'),
                                      test_cfg=teacher_cfg.get('test_cfg'))
        self.init_weights_teacher(teacher_pretrained)
        self.teacher.eval()
        self.student = build_detector(student_cfg.model,
                                      train_cfg=student_cfg.get('train_cfg'),
                                      test_cfg=student_cfg.get('test_cfg'))
        self.distill_losses = nn.ModuleDict()
        self.distill_cfg = distill_cfg
        for item_loc in distill_cfg:
            for item_loss in item_loc.methods:
                loss_name = item_loss.name
                self.distill_losses[loss_name] = build_distill_loss(item_loss)

    def base_parameters(self):
        return nn.ModuleList([self.student, self.distill_losses])

    def discriminator_parameters(self):
        # Assumes a discriminator module has been attached to the distiller
        # externally; FGFI itself never creates one.
        return self.discriminator
    @property
    def with_neck(self):
        """bool: whether the detector has a neck"""
        return hasattr(self.student, 'neck') and self.student.neck is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self):
        """bool: whether the detector has a shared head in the RoI Head"""
        return (hasattr(self.student, 'roi_head')
                and self.student.roi_head.with_shared_head)

    @property
    def with_bbox(self):
        """bool: whether the detector has a bbox head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
                or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the detector has a mask head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
                or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))
    def init_weights_teacher(self, path=None):
        """Load pretrained weights into the teacher detector.

        Args:
            path (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        load_checkpoint(self.teacher, path, map_location='cpu')
    def init_weights_student(self, load_teacher_part, teacher_pretrained):
        """Initialise the student, optionally copying teacher neck/head weights."""
        self.student.init_weights()
        if load_teacher_part:
            assert load_teacher_part in ['neck', 'head', 'neck_head']

            def check_key(key, load_teacher_part):
                # Keep a teacher parameter only if it belongs to a requested part.
                if 'neck' in key and 'neck' in load_teacher_part:
                    return True
                elif 'head' in key and 'head' in load_teacher_part:
                    return True
                else:
                    return False

            t_checkpoint = _load_checkpoint(teacher_pretrained, map_location='cpu')
            all_name = []
            for name, v in t_checkpoint['state_dict'].items():
                if check_key(name, load_teacher_part):
                    all_name.append((name, v))
            state_dict = OrderedDict(all_name)
            load_state_dict(self.student, state_dict)
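    # Hedged usage sketch (the checkpoint path is an assumption): warm-start
    # the student's neck and heads from the teacher while leaving its backbone
    # randomly initialised. This only makes sense when teacher and student
    # share the same neck/head architecture.
    #
    #   distiller.init_weights_student(load_teacher_part='neck_head',
    #                                  teacher_pretrained='checkpoints/teacher.pth')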
    def _map_roi_levels(self, rois, num_levels):
        """Assign each gt box to an FPN level by scale, mirroring mmdet's
        SingleRoIExtractor.map_roi_levels with finest_scale = 56."""
        scale = torch.sqrt(
            (rois[:, 2] - rois[:, 0] + 1) * (rois[:, 3] - rois[:, 1] + 1))
        target_lvls = torch.floor(torch.log2(scale / 56 + 1e-6))
        target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
        return target_lvls
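    # Worked example for _map_roi_levels with num_levels=4: a gt box with
    # x2 - x1 = y2 - y1 = 111 gives scale = sqrt(112 * 112) = 112, and
    # log2(112 / 56) = 1, so the box is assigned to level 1; boxes smaller
    # than roughly 56x56 clamp to level 0, and very large ones to level 3.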
    def get_roi_mask(self, cls_scores, img_metas, gt_bboxes, phi=0.5):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        from mmdet.core import bbox_overlaps
        with torch.no_grad():
            anchor_list, _ = self.student.rpn_head.get_anchors(featmap_sizes, img_metas)
            mask_batch = []
            for batch in range(len(gt_bboxes)):
                mask_level = []
                target_lvls = self._map_roi_levels(gt_bboxes[batch], len(anchor_list[batch]))
                for level in range(len(anchor_list[batch])):
                    gt_level = gt_bboxes[batch][target_lvls == level]
                    h, w = featmap_sizes[level][0], featmap_sizes[level][1]
                    # Start from an empty mask so levels without gt boxes stay zero.
                    mask_per_img = torch.zeros([h, w], dtype=torch.double,
                                               device=gt_bboxes[batch].device)
                    if gt_level.shape[0] > 0:
                        IoU_map = bbox_overlaps(anchor_list[batch][level], gt_level)
                        max_iou, _ = torch.max(IoU_map, dim=0)
                        IoU_map = IoU_map.view(h, w, self.student.rpn_head.num_anchors, -1)
                        for ins in range(gt_level.shape[0]):
                            # FGFI keeps anchor locations whose IoU with a gt box
                            # exceeds phi times the best IoU for that box.
                            max_iou_per_gt = max_iou[ins] * phi
                            mask_per_gt = torch.sum(IoU_map[:, :, :, ins] > max_iou_per_gt,
                                                    dim=2)
                            mask_per_img += mask_per_gt
                    # Binarise the accumulated per-gt hits into a 0/1 mask.
                    mask_per_img = (mask_per_img > 0).double()
                    mask_level.append(mask_per_img)
                mask_batch.append(mask_level)

            # Regroup from per-image lists into one stacked tensor per FPN level.
            mask_batch_level = []
            for i in range(len(mask_batch[0])):
                tmp = []
                for batch in range(len(mask_batch)):
                    tmp.append(mask_batch[batch][i])
                mask_batch_level.append(torch.stack(tmp, dim=0))
        return mask_batch_level
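    # Shape sketch (values assumed for illustration): for a batch of 2 images
    # at 800x800 with a stride-4..64 FPN, get_roi_mask returns one tensor per
    # level, e.g. (2, 200, 200) at the finest level, each aligned with the
    # corresponding rpn_outs[0][i] feature map.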
    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dicts where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components (the student's
                losses plus the distiller's losses).
        """
        losses = dict()
        # Teacher features are targets only; no gradients flow through them.
        with torch.no_grad():
            self.teacher.eval()
            f_t = self.teacher.backbone(img)
            f_t = self.teacher.neck(f_t)
        f_s = self.student.backbone(img)
        f_s = self.student.neck(f_s)

        # Standard RPN training on the student, plus proposals for the RoI head.
        rpn_outs = self.student.rpn_head(f_s)
        loss_inputs = rpn_outs + (gt_bboxes, img_metas)
        rpn_losses = self.student.rpn_head.loss(*loss_inputs,
                                                gt_bboxes_ignore=gt_bboxes_ignore)
        proposal_cfg = self.student.train_cfg.get('rpn_proposal',
                                                  self.student.test_cfg.rpn)
        proposal_list = self.student.rpn_head.get_bboxes(*rpn_outs, img_metas,
                                                         cfg=proposal_cfg)
        losses.update(rpn_losses)

        with torch.no_grad():
            neck_mask_batch = self.get_roi_mask(rpn_outs[0], img_metas, gt_bboxes,
                                                phi=0.5)
        # Feature-imitation losses on each configured FPN level.
        for item_loc in self.distill_cfg:
            feature_level = item_loc.feature_level
            f_s_l = f_s[feature_level]
            f_t_l = f_t[feature_level]
            mask = neck_mask_batch[feature_level]
            mask = mask.unsqueeze(1).repeat(1, f_s_l.size(1), 1, 1)
            # Logged only: mmdet sums just the entries whose key contains 'loss'.
            losses['{}_ratio'.format(feature_level)] = mask.sum() / mask.numel()
            for item_loss in item_loc.methods:
                loss_name = item_loss.name
                # Losses whose name contains 'n' imitate the background region.
                loss_mask = 1 - mask if 'n' in loss_name else mask
                losses[loss_name] = self.distill_losses[loss_name](f_s_l, f_t_l,
                                                                   loss_mask)
        roi_losses = self.student.roi_head.forward_train(f_s, img_metas,
                                                         proposal_list,
                                                         gt_bboxes, gt_labels,
                                                         gt_bboxes_ignore, gt_masks,
                                                         **kwargs)
        losses.update(roi_losses)
        return losses
    def simple_test(self, img, img_metas, **kwargs):
        return self.student.simple_test(img, img_metas, **kwargs)

    def aug_test(self, imgs, img_metas, **kwargs):
        return self.student.aug_test(imgs, img_metas, **kwargs)

    def extract_feat(self, imgs):
        """Extract features from images."""
        return self.student.extract_feat(imgs)
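
# Hedged usage sketch; the config paths and checkpoint name below are
# assumptions for illustration, not files shipped with this repo.
#
#   from mmcv import Config
#   teacher_cfg = Config.fromfile('configs/faster_rcnn_r101_fpn_1x_coco.py')
#   student_cfg = Config.fromfile('configs/faster_rcnn_r50_fpn_1x_coco.py')
#   distiller = FGFI(teacher_cfg, student_cfg,
#                    distill_cfg=student_cfg.distill_cfg,
#                    teacher_pretrained='checkpoints/teacher_r101.pth')
#   losses = distiller.forward_train(img, img_metas, gt_bboxes, gt_labels)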