Newer
Older
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmdet.models.detectors.base import BaseDetector
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
from ..builder import DISTILLER,build_distill_loss
from collections import OrderedDict
@DISTILLER.register_module()
It typically consists of teacher_model and student_model.
"""
def __init__(self,
             teacher_cfg,
             student_cfg,
             distill_cfg=None,
             teacher_pretrained=None,
             init_student=False):
    """Build the teacher and student detectors and wire up distillation.

    Args:
        teacher_cfg: Config object; its ``model`` entry and optional
            ``train_cfg`` / ``test_cfg`` entries are passed to
            ``build_detector`` to create the teacher.
        student_cfg: Same layout as ``teacher_cfg``, used for the student.
        distill_cfg (iterable, optional): One entry per distillation
            location. Each entry exposes ``student_module`` and
            ``teacher_module`` (dotted names as returned by
            ``named_modules()``) and ``methods``, a list of loss configs
            that each carry a ``name``. ``None`` is treated as "no
            distillation losses".
        teacher_pretrained (str, optional): Checkpoint handed to
            ``init_weights_teacher`` and, when ``init_student`` is set,
            used to initialise the student as well.
        init_student (bool): If True, copy every teacher checkpoint entry
            whose key does not start with ``"backbone."`` into the student.
    """
    # Submodules and buffers are assigned below, which requires the parent
    # initializer to have run first.
    super().__init__()

    self.teacher = build_detector(teacher_cfg.model,
                                  train_cfg=teacher_cfg.get('train_cfg'),
                                  test_cfg=teacher_cfg.get('test_cfg'))
    self.init_weights_teacher(teacher_pretrained)
    self.teacher.eval()

    self.student = build_detector(student_cfg.model,
                                  train_cfg=student_cfg.get('train_cfg'),
                                  test_cfg=student_cfg.get('test_cfg'))

    if init_student:
        # Re-use the teacher weights for everything except the backbone.
        t_checkpoint = _load_checkpoint(teacher_pretrained, map_location='cpu')
        state_dict = OrderedDict(
            (name, value)
            for name, value in t_checkpoint["state_dict"].items()
            if not name.startswith("backbone."))
        load_state_dict(self.student, state_dict)

    self.distill_losses = nn.ModuleDict()
    # Store an empty list rather than None so later iteration is safe.
    self.distill_cfg = distill_cfg if distill_cfg is not None else []

    student_modules = dict(self.student.named_modules())
    teacher_modules = dict(self.teacher.named_modules())

    def make_hooks(student_key, teacher_key):
        # A factory binds the buffer names per location; defining the hooks
        # directly in the loop would make them all see the last names.
        def hook_teacher_forward(module, inputs, output):
            self.register_buffer(teacher_key, output)

        def hook_student_forward(module, inputs, output):
            self.register_buffer(student_key, output)

        return hook_teacher_forward, hook_student_forward

    for item_loc in self.distill_cfg:
        # Buffer names cannot contain dots, so flatten the module path.
        student_key = 'student_' + item_loc.student_module.replace('.', '_')
        teacher_key = 'teacher_' + item_loc.teacher_module.replace('.', '_')

        # Placeholders; the forward hooks overwrite them on every pass.
        self.register_buffer(student_key, None)
        self.register_buffer(teacher_key, None)

        hook_teacher_forward, hook_student_forward = make_hooks(
            student_key, teacher_key)
        teacher_modules[item_loc.teacher_module].register_forward_hook(
            hook_teacher_forward)
        student_modules[item_loc.student_module].register_forward_hook(
            hook_student_forward)

        for item_loss in item_loc.methods:
            self.distill_losses[item_loss.name] = build_distill_loss(item_loss)
def base_parameters(self):
    """Return the student detector and the distillation losses as one list.

    Returns:
        nn.ModuleList: ``[self.student, self.distill_losses]``.
    """
    trainable_parts = [self.student, self.distill_losses]
    return nn.ModuleList(trainable_parts)
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def discriminator_parameters(self):
    """Return the ``discriminator`` attribute of this distiller.

    NOTE(review): ``discriminator`` is not assigned anywhere in the visible
    part of this file - confirm it is set elsewhere before relying on this.
    """
    discriminator = self.discriminator
    return discriminator
@property
def with_neck(self):
    """bool: True when the student has a ``neck`` attribute that is not None."""
    neck = getattr(self.student, 'neck', None)
    return neck is not None
# TODO: these properties need to be carefully handled
# for both single stage & two stage detectors
@property
def with_shared_head(self):
    """bool: whether the student's RoI head has a shared head.

    False when the student has no ``roi_head``; otherwise the value of
    ``roi_head.with_shared_head`` is returned as-is.
    """
    student = self.student
    if not hasattr(student, 'roi_head'):
        return False
    return student.roi_head.with_shared_head
@property
def with_bbox(self):
    """bool: whether the student has a bbox head.

    True when the student's ``roi_head`` reports ``with_bbox``, or when the
    student has a ``bbox_head`` attribute that is not None.
    """
    student = self.student
    if hasattr(student, 'roi_head'):
        roi_has_bbox = student.roi_head.with_bbox
        if roi_has_bbox:
            return roi_has_bbox
    # Fall back to a detector-level bbox head.
    return hasattr(student, 'bbox_head') and student.bbox_head is not None
@property
def with_mask(self):
    """bool: whether the student has a mask head.

    True when the student's ``roi_head`` reports ``with_mask``, or when the
    student has a ``mask_head`` attribute that is not None.
    """
    student = self.student
    if hasattr(student, 'roi_head'):
        roi_has_mask = student.roi_head.with_mask
        if roi_has_mask:
            return roi_has_mask
    # Fall back to a detector-level mask head.
    return hasattr(student, 'mask_head') and student.mask_head is not None
def init_weights_teacher(self, path=None):
    """Load pretrained weights into the teacher detector.

    Args:
        path (str, optional): Path to the pre-trained teacher checkpoint.
            Defaults to None, in which case nothing is loaded and a
            warning is emitted.
    """
    if path is None:
        # There is no checkpoint to read; leave the teacher as built and
        # make the situation visible instead of failing inside the loader.
        import warnings
        warnings.warn('No teacher checkpoint given; '
                      'teacher weights are left as initialised.')
        return
    # Load onto the CPU; the loader's return value is not needed here.
    load_checkpoint(self.teacher, path, map_location='cpu')
def forward_train(self, img, img_metas, **kwargs):
    """Run one training step and collect student and distillation losses.

    Args:
        img (Tensor): Input images of shape (N, C, H, W).
            Typically these should be mean centered and std scaled.
        img_metas (list[dict]): A List of image info dict where each dict
            has: 'img_shape', 'scale_factor', 'flip', and may also contain
            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            For details on the values of these keys see
            :class:`mmdet.datasets.pipelines.Collect`.
        **kwargs: Forwarded unchanged to the student's ``forward_train``.
            Must contain ``gt_bboxes``, which is also passed to every
            distillation loss.

    Returns:
        dict[str, Tensor]: A dictionary of loss components (student's
        losses and distiller's losses).

    Raises:
        KeyError: If ``gt_bboxes`` is missing from ``kwargs`` or a
            configured feature was not captured during the forward passes.
    """
    # The teacher pass is run only for its forward hooks, which store the
    # tapped features on this module; its return value is not needed.
    with torch.no_grad():
        self.teacher.eval()
        self.teacher.extract_feat(img)

    student_loss = self.student.forward_train(img, img_metas, **kwargs)
    gt_bboxes = kwargs['gt_bboxes']

    for item_loc in self.distill_cfg or ():
        student_key = 'student_' + item_loc.student_module.replace('.', '_')
        teacher_key = 'teacher_' + item_loc.teacher_module.replace('.', '_')
        # Read the two captured features directly instead of materialising
        # every buffer of both detectors on each iteration.
        student_feat = getattr(self, student_key)
        teacher_feat = getattr(self, teacher_key)
        if student_feat is None or teacher_feat is None:
            missing = student_key if student_feat is None else teacher_key
            raise KeyError(missing)

        for item_loss in item_loc.methods:
            loss_name = item_loss.name
            student_loss[loss_name] = self.distill_losses[loss_name](
                student_feat, teacher_feat, gt_bboxes, img_metas)

    return student_loss
def simple_test(self, img, img_metas, **kwargs):
    """Delegate single-image testing to the student detector."""
    student = self.student
    return student.simple_test(img, img_metas, **kwargs)
def aug_test(self, imgs, img_metas, **kwargs):
    """Delegate test-time-augmentation inference to the student detector.

    Args:
        imgs: Augmented inputs, forwarded unchanged to the student.
        img_metas: Image meta information, forwarded unchanged.
        **kwargs: Extra options forwarded to the student's ``aug_test``.

    Returns:
        Whatever the student's ``aug_test`` returns.
    """
    # Forward ``imgs`` (the actual parameter); the previous body referenced
    # an undefined name ``img`` and raised NameError on every call.
    return self.student.aug_test(imgs, img_metas, **kwargs)
def extract_feat(self, imgs):
    """Return the features the student detector extracts from ``imgs``."""
    student = self.student
    return student.extract_feat(imgs)