import torch
import torch.nn as nn
from mmdet.models.detectors.base import BaseDetector
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
from collections import OrderedDict

from ..builder import DISTILLER, build_distill_loss


@DISTILLER.register_module()
class DeFeat(BaseDetector):
    """DeFeat distiller for detectors.

    It typically consists of a teacher_model and a student_model.
    """

    def __init__(self,
                 teacher_cfg,
                 student_cfg,
                 distill_cfg=None,
                 teacher_pretrained=None,
                 load_teacher_part=None):
        super(DeFeat, self).__init__()
        # Build the teacher detector, load its weights, and freeze it.
        self.teacher = build_detector(teacher_cfg.model,
                                      train_cfg=teacher_cfg.get('train_cfg'),
                                      test_cfg=teacher_cfg.get('test_cfg'))
        self.init_weights_teacher(teacher_pretrained)
        self.teacher.eval()
        # Build the student detector, optionally initializing its neck and/or
        # head from the teacher checkpoint.
        self.student = build_detector(student_cfg.model,
                                      train_cfg=student_cfg.get('train_cfg'),
                                      test_cfg=student_cfg.get('test_cfg'))
        self.init_weights_student(load_teacher_part, teacher_pretrained)
        # Register one distillation loss module per configured method.
        self.distill_losses = nn.ModuleDict()
        self.distill_cfg = distill_cfg
        if distill_cfg is not None:
            for item_loc in distill_cfg:
                for item_loss in item_loc.methods:
                    loss_name = item_loss.name
                    self.distill_losses[loss_name] = build_distill_loss(item_loss)

    def base_parameters(self):
        return nn.ModuleList([self.student, self.distill_losses])

    def discriminator_parameters(self):
        # Note: assumes a `self.discriminator` attribute is set elsewhere
        # (e.g. by a subclass or an adversarial variant); it is not created
        # in this class.
        return self.discriminator

    @property
    def with_neck(self):
        """bool: whether the detector has a neck"""
        return hasattr(self.student, 'neck') and self.student.neck is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self):
        """bool: whether the detector has a shared head in the RoI Head"""
        return (hasattr(self.student, 'roi_head')
                and self.student.roi_head.with_shared_head)

    @property
    def with_bbox(self):
        """bool: whether the detector has a bbox head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
                or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the detector has a mask head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
                or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))

    def init_weights_teacher(self, path=None):
        """Load pretrained weights into the teacher detector.

        Args:
            path (str, optional): Path to the pre-trained checkpoint.
                Defaults to None.
        """
        load_checkpoint(self.teacher, path, map_location='cpu')

    def init_weights_student(self, load_teacher_part, teacher_pretrained):
        """Initialize the student, optionally copying the teacher's neck
        and/or head weights (e.g. load_teacher_part='neck' copies every
        teacher parameter whose name contains 'neck')."""
        self.student.init_weights()
        if load_teacher_part:
            assert load_teacher_part in ['neck', 'head', 'neck_head']

            def check_key(key, load_teacher_part):
                # Keep only the parameters belonging to the requested parts.
                if 'neck' in key and 'neck' in load_teacher_part:
                    return True
                elif 'head' in key and 'head' in load_teacher_part:
                    return True
                else:
                    return False

            t_checkpoint = _load_checkpoint(teacher_pretrained, map_location='cpu')
            all_name = []
            for name, v in t_checkpoint['state_dict'].items():
                if check_key(name, load_teacher_part):
                    all_name.append((name, v))
            state_dict = OrderedDict(all_name)
            load_state_dict(self.student, state_dict)

    def _map_roi_levels(self, rois, num_levels):
        """Map each box to an FPN level via scale = sqrt(w * h) and
        level = floor(log2(scale / 56)), clamped to [0, num_levels - 1]
        (the same rule as mmdet's SingleRoIExtractor with finest_scale=56)."""
        scale = torch.sqrt(
            (rois[:, 2] - rois[:, 0] + 1) * (rois[:, 3] - rois[:, 1] + 1))
        target_lvls = torch.floor(torch.log2(scale / 56 + 1e-6))
        target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
        return target_lvls
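
    # Worked example of the rule above (an illustration, not part of the
    # original code): a 112x112 GT box has scale = sqrt(113 * 113) = 113 with
    # the +1 offsets, and floor(log2(113 / 56)) = 1, so it lands on level 1;
    # a 448x448 box gives floor(log2(449 / 56)) = 3, i.e. level 3.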

    def get_gt_mask(self, cls_scores, img_metas, gt_bboxes):
        """Build per-level binary foreground masks from the GT boxes.

        Each GT box is assigned to one FPN level (see `_map_roi_levels`)
        and rasterized onto that level's feature map.

        Returns:
            list[Tensor]: One (batch, h, w) mask per feature level.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        featmap_strides = self.student.rpn_head.anchor_generator.strides
        if isinstance(featmap_strides[0], tuple):
            featmap_strides = [strides[0] for strides in featmap_strides]
        # Per-level margin (in feature-map cells) by which each GT region is
        # dilated; all zeros means the mask covers exactly the projected box.
        imit_range = [0, 0, 0, 0, 0]
        with torch.no_grad():
            mask_batch = []
            for batch in range(len(gt_bboxes)):
                mask_level = []
                target_lvls = self._map_roi_levels(gt_bboxes[batch], len(featmap_sizes))
                for level in range(len(featmap_sizes)):
                    # GT boxes assigned to this level; gt_bboxes[batch] is (n, 4).
                    gt_level = gt_bboxes[batch][target_lvls == level]
                    h, w = featmap_sizes[level][0], featmap_sizes[level][1]
                    # Allocate on the feature map's device rather than
                    # assuming CUDA is available.
                    mask_per_img = torch.zeros(
                        [h, w], dtype=torch.float,
                        device=cls_scores[level].device)
                    for ins in range(gt_level.shape[0]):
                        # Project the box onto the feature map and clip.
                        gt_level_map = gt_level[ins] / featmap_strides[level]
                        lx = max(int(gt_level_map[0]) - imit_range[level], 0)
                        rx = min(int(gt_level_map[2]) + imit_range[level], w)
                        ly = max(int(gt_level_map[1]) - imit_range[level], 0)
                        ry = min(int(gt_level_map[3]) + imit_range[level], h)
                        if (lx == rx) or (ly == ry):
                            # Degenerate box: mark at least one cell.
                            mask_per_img[ly, lx] += 1
                        else:
                            mask_per_img[ly:ry, lx:rx] += 1
                    mask_per_img = (mask_per_img > 0).float()
                    mask_level.append(mask_per_img)
                mask_batch.append(mask_level)
            # Regroup from per-image lists into per-level batched tensors.
            mask_batch_level = []
            for level in range(len(mask_batch[0])):
                tmp = []
                for batch in range(len(mask_batch)):
                    tmp.append(mask_batch[batch][level])
                mask_batch_level.append(torch.stack(tmp, dim=0))
        return mask_batch_level

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dicts where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components (the student's
                losses plus the distillation losses).
        """
        losses = dict()
        # Teacher forward: backbone + neck features only, no gradients.
        with torch.no_grad():
            self.teacher.eval()
            f_t = self.teacher.backbone(img)
            f_t = self.teacher.neck(f_t)
        # Student forward through backbone and neck.
        f_s = self.student.backbone(img)
        f_s = self.student.neck(f_s)
        # Student RPN losses and proposals.
        rpn_outs = self.student.rpn_head(f_s)
        loss_inputs = rpn_outs + (gt_bboxes, img_metas)
        rpn_losses = self.student.rpn_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        proposal_cfg = self.student.train_cfg.get('rpn_proposal',
                                                  self.student.test_cfg.rpn)
        proposal_list = self.student.rpn_head.get_bboxes(
            *rpn_outs, img_metas, cfg=proposal_cfg)
        losses.update(rpn_losses)
        # Distillation: GT-based foreground masks per FPN level.
        neck_mask_batch = self.get_gt_mask(rpn_outs[0], img_metas, gt_bboxes)
        for item_loc in self.distill_cfg:
            feature_level = item_loc.feature_level
            f_s_l = f_s[feature_level]
            f_t_l = f_t[feature_level]
            mask = neck_mask_batch[feature_level]
            mask = mask.unsqueeze(1).repeat(1, f_s_l.size(1), 1, 1)
            for item_loss in item_loc.methods:
                loss_name = item_loss.name
                # Losses whose name contains 'n' (negative/background terms)
                # use the inverted mask; use a local copy so one inversion
                # does not leak into the next loss in the list.
                loss_mask = 1 - mask if 'n' in loss_name else mask
                losses[loss_name] = self.distill_losses[loss_name](
                    f_s_l, f_t_l, loss_mask)
        # Student RoI head losses.
        roi_losses = self.student.roi_head.forward_train(
            f_s, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_bboxes_ignore, gt_masks, **kwargs)
        losses.update(roi_losses)
        return losses

    def simple_test(self, img, img_metas, **kwargs):
        return self.student.simple_test(img, img_metas, **kwargs)

    def aug_test(self, imgs, img_metas, **kwargs):
        return self.student.aug_test(imgs, img_metas, **kwargs)

    def extract_feat(self, imgs):
        """Extract features from images."""
        return self.student.extract_feat(imgs)
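
# A minimal usage sketch (not part of the original module), assuming
# mmcv-style config files; every path below is an illustrative placeholder:
#
#   from mmcv import Config
#   teacher_cfg = Config.fromfile('configs/teacher_r101.py')    # hypothetical path
#   student_cfg = Config.fromfile('configs/student_r50.py')     # hypothetical path
#   distiller = DeFeat(teacher_cfg,
#                      student_cfg,
#                      distill_cfg=student_cfg.distill_cfg,     # see sketch in class body
#                      teacher_pretrained='ckpts/teacher.pth')  # hypothetical path
#   losses = distiller.forward_train(img, img_metas, gt_bboxes, gt_labels)
#   # `losses` is the combined dict of student losses and distillation losses.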