Skip to content
Snippets Groups Projects
Commit 1065c679 authored by wanggh's avatar wanggh
Browse files

add DeFeat

parent 91aa1b4e
No related branches found
No related tags found
1 merge request!5add DeFeat
_base_ = [
'../../_base_/models/faster_rcnn_r50_fpn.py',
'../../_base_/datasets/voc0712.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
pos_w=1
neg_w=2
distiller = dict(
type='DeFeat',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/faster_rcnn/faster_rcnn_r152_fpn_1x_voc0712.pth',
load_teacher_part = 'neck_head',
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_p_0',
student_channels = 256,
teacher_channels = 256,
weight = pos_w,
),
dict(type='MSELoss',
name='loss_n_0',
student_channels = 256,
teacher_channels = 256,
weight = neg_w,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_p_1',
student_channels = 256,
teacher_channels = 256,
weight = pos_w,
),
dict(type='MSELoss',
name='loss_n_1',
student_channels = 256,
teacher_channels = 256,
weight = neg_w,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_p_2',
student_channels = 256,
teacher_channels = 256,
weight = pos_w,
),
dict(type='MSELoss',
name='loss_n_2',
student_channels = 256,
teacher_channels = 256,
weight = neg_w,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_p_3',
student_channels = 256,
teacher_channels = 256,
weight = pos_w,
),
dict(type='MSELoss',
name='loss_n_3',
student_channels = 256,
teacher_channels = 256,
weight = neg_w,
)
]
),
]
)
student_cfg = 'configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py'
teacher_cfg = 'configs/pascal_voc/faster_rcnn_r152_fpn_1x_voc0712.py'
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
# actual epoch = 3 * 3 = 9
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=200,
warmup_ratio=0.001,
step=[3])
# runtime settings
runner = dict(
type='EpochBasedRunner', max_epochs=4) # actual epoch = 4 * 3 = 12
\ No newline at end of file
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmdet.models.detectors.base import BaseDetector
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
from ..builder import DISTILLER,build_distill_loss
from collections import OrderedDict
@DISTILLER.register_module()
class DeFeat(BaseDetector):
"""DeFeat distiller for detectors.
It typically consists of teacher_model and student_model.
"""
def __init__(self,
teacher_cfg,
student_cfg,
distill_cfg=None,
teacher_pretrained=None,
load_teacher_part=None):
super(DeFeat, self).__init__()
self.teacher = build_detector(teacher_cfg.model,
train_cfg=teacher_cfg.get('train_cfg'),
test_cfg=teacher_cfg.get('test_cfg'))
self.init_weights_teacher(teacher_pretrained)
self.teacher.eval()
self.student= build_detector(student_cfg.model,
train_cfg=student_cfg.get('train_cfg'),
test_cfg=student_cfg.get('test_cfg'))
self.init_weights_student(load_teacher_part, teacher_pretrained)
self.distill_losses = nn.ModuleDict()
self.distill_cfg = distill_cfg
for item_loc in distill_cfg:
for item_loss in item_loc.methods:
loss_name = item_loss.name
self.distill_losses[loss_name] = build_distill_loss(item_loss)
def base_parameters(self):
return nn.ModuleList([self.student, self.distill_losses])
def discriminator_parameters(self):
return self.discriminator
@property
def with_neck(self):
"""bool: whether the detector has a neck"""
return hasattr(self.student, 'neck') and self.student.neck is not None
# TODO: these properties need to be carefully handled
# for both single stage & two stage detectors
@property
def with_shared_head(self):
"""bool: whether the detector has a shared head in the RoI Head"""
return hasattr(self.student, 'roi_head') and self.student.roi_head.with_shared_head
@property
def with_bbox(self):
"""bool: whether the detector has a bbox head"""
return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))
@property
def with_mask(self):
"""bool: whether the detector has a mask head"""
return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))
def init_weights_teacher(self, path=None):
"""Load the pretrained model in teacher detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
checkpoint = load_checkpoint(self.teacher, path, map_location='cpu')
def init_weights_student(self, load_teacher_part, teacher_pretrained):
self.student.init_weights()
if load_teacher_part:
assert load_teacher_part in ['neck', 'head', 'neck_head']
def check_key(key, load_teacher_part):
if 'neck' in key and 'neck' in load_teacher_part:
return True
elif 'head' in key and 'head' in load_teacher_part:
return True
else:
return False
t_checkpoint = _load_checkpoint(teacher_pretrained, map_location='cpu')
all_name = []
for name, v in t_checkpoint["state_dict"].items():
if check_key(name, load_teacher_part):
all_name.append((name, v))
state_dict = OrderedDict(all_name)
load_state_dict(self.student, state_dict)
def _map_roi_levels(self, rois, num_levels):
scale = torch.sqrt(
(rois[:, 2] - rois[:, 0] + 1) * (rois[:, 3] - rois[:, 1] + 1))
target_lvls = torch.floor(torch.log2(scale / 56 + 1e-6))
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls
def get_gt_mask(self, cls_scores, img_metas, gt_bboxes):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
featmap_strides = self.student.rpn_head.anchor_generator.strides
if isinstance(featmap_strides[0], tuple):
featmap_strides = [strides[0] for strides in featmap_strides]
imit_range = [0, 0, 0, 0, 0]
with torch.no_grad():
mask_batch = []
for batch in range(len(gt_bboxes)):
mask_level = []
target_lvls = self._map_roi_levels(gt_bboxes[batch], len(featmap_sizes))
for level in range(len(featmap_sizes)):
gt_level = gt_bboxes[batch][target_lvls==level] # gt_bboxes: BatchsizexNpointx4coordinate
h, w = featmap_sizes[level][0], featmap_sizes[level][1]
mask_per_img = torch.zeros([h, w], dtype=torch.float).cuda()
for ins in range(gt_level.shape[0]):
gt_level_map = gt_level[ins] / featmap_strides[level]
lx = max(int(gt_level_map[0]) - imit_range[level], 0)
rx = min(int(gt_level_map[2]) + imit_range[level], w)
ly = max(int(gt_level_map[1]) - imit_range[level], 0)
ry = min(int(gt_level_map[3]) + imit_range[level], h)
if (lx == rx) or (ly == ry):
mask_per_img[ly, lx] += 1
else:
mask_per_img[ly:ry, lx:rx] += 1
mask_per_img = (mask_per_img > 0).float()
mask_level.append(mask_per_img)
mask_batch.append(mask_level)
mask_batch_level = []
for level in range(len(mask_batch[0])):
tmp = []
for batch in range(len(mask_batch)):
tmp.append(mask_batch[batch][level])
mask_batch_level.append(torch.stack(tmp, dim=0))
return mask_batch_level
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses).
"""
losses = dict()
with torch.no_grad():
self.teacher.eval()
f_t = self.teacher.backbone(img)
f_t = self.teacher.neck(f_t)
f_s = self.student.backbone(img)
f_s = self.student.neck(f_s)
rpn_outs = self.student.rpn_head(f_s)
loss_inputs = rpn_outs + (gt_bboxes, img_metas)
rpn_losses = self.student.rpn_head.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
proposal_cfg = self.student.train_cfg.get('rpn_proposal', self.student.test_cfg.rpn)
proposal_list = self.student.rpn_head.get_bboxes(*rpn_outs, img_metas, cfg=proposal_cfg)
losses.update(rpn_losses)
neck_mask_batch = self.get_gt_mask(rpn_outs[0], img_metas, gt_bboxes)
for item_loc in self.distill_cfg:
feature_level = item_loc.feature_level
f_s_l = f_s[feature_level]
f_t_l = f_t[feature_level]
mask = neck_mask_batch[feature_level]
mask = mask.unsqueeze(1).repeat(1, f_s_l.size(1), 1, 1)
for item_loss in item_loc.methods:
loss_name = item_loss.name
if 'n' in loss_name:
mask = 1 - mask
losses[loss_name] = self.distill_losses[loss_name](f_s_l, f_t_l, mask)
roi_losses = self.student.roi_head.forward_train(f_s, img_metas, proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
losses.update(roi_losses)
return losses
def simple_test(self, img, img_metas, **kwargs):
return self.student.simple_test(img, img_metas, **kwargs)
def aug_test(self, imgs, img_metas, **kwargs):
return self.student.aug_test(img, img_metas, **kwargs)
def extract_feat(self, imgs):
"""Extract features from images."""
return self.student.extract_feat(imgs)
\ No newline at end of file
......@@ -4,6 +4,7 @@ from .fpn_distiller import FPNDistiller
from .head_distiller import HeadDistiller
from .feature_mimicking import FeatureMimicking
from .FGFI import FGFI
from .DeFeat import DeFeat
__all__ = [
'DetectionDistiller',
......@@ -12,4 +13,5 @@ __all__ = [
'HeadDistiller',
'FeatureMimicking',
'FGFI',
'DeFeat',
]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment