Skip to content
Snippets Groups Projects
Commit f17f4ee4 authored by wanggh's avatar wanggh
Browse files

Merge branch 'master' of git.nju.edu.cn:wanggh/Swin-Transformer-Object-Detection

parents ba760f89 ffe0c5c6
No related branches found
No related tags found
No related merge requests found
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='HeadDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_tiny_patch4_window7.pth',
init_student = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 96,
teacher_channels = 96,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 192,
teacher_channels = 192,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 384,
teacher_channels = 384,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 768,
teacher_channels = 768,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
from .detection_distiller import DetectionDistiller
from .backbone_distiller import BackboneDistiller
from .fpn_distiller import FPNDistiller
from .head_distiller import HeadDistiller
__all__ = [
'DetectionDistiller',
'BackboneDistiller',
'FPNDistiller'
'FPNDistiller',
'HeadDistiller'
]
\ No newline at end of file
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmdet.models.detectors.base import BaseDetector
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
from ..builder import DISTILLER,build_distill_loss
from collections import OrderedDict
@DISTILLER.register_module()
class HeadDistiller(BaseDetector):
"""Head distiller for detectors.
It typically consists of teacher_model and student_model.
"""
def __init__(self,
teacher_cfg,
student_cfg,
distill_cfg=None,
teacher_pretrained=None,
init_student=False):
super(HeadDistiller, self).__init__()
self.teacher = build_detector(teacher_cfg.model,
train_cfg=teacher_cfg.get('train_cfg'),
test_cfg=teacher_cfg.get('test_cfg'))
self.init_weights_teacher(teacher_pretrained)
self.teacher.eval()
self.student= build_detector(student_cfg.model,
train_cfg=student_cfg.get('train_cfg'),
test_cfg=student_cfg.get('test_cfg'))
if init_student:
t_checkpoint = _load_checkpoint(teacher_pretrained)
all_name = []
for name, v in t_checkpoint["state_dict"].items():
if name.startswith("backbone."):
continue
else:
all_name.append((name, v))
state_dict = OrderedDict(all_name)
load_state_dict(self.student, state_dict)
self.distill_losses = nn.ModuleDict()
self.distill_cfg = distill_cfg
for item_loc in distill_cfg:
for item_loss in item_loc.methods:
loss_name = item_loss.name
self.distill_losses[loss_name] = build_distill_loss(item_loss)
def base_parameters(self):
return nn.ModuleList([self.student, self.distill_losses])
def discriminator_parameters(self):
return self.discriminator
@property
def with_neck(self):
"""bool: whether the detector has a neck"""
return hasattr(self.student, 'neck') and self.student.neck is not None
# TODO: these properties need to be carefully handled
# for both single stage & two stage detectors
@property
def with_shared_head(self):
"""bool: whether the detector has a shared head in the RoI Head"""
return hasattr(self.student, 'roi_head') and self.student.roi_head.with_shared_head
@property
def with_bbox(self):
"""bool: whether the detector has a bbox head"""
return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))
@property
def with_mask(self):
"""bool: whether the detector has a mask head"""
return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))
def init_weights_teacher(self, path=None):
"""Load the pretrained model in teacher detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
checkpoint = load_checkpoint(self.teacher, path, map_location='cpu')
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses).
"""
with torch.no_grad():
self.teacher.eval()
f_t = self.teacher.extract_feat(img)
proposal_cfg = self.teacher.train_cfg.get('rpn_proposal',
self.teacher.test_cfg.rpn)
t_rpn_losses, t_proposal_list = self.teacher.rpn_head.forward_train(
f_t,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=gt_bboxes_ignore,
proposal_cfg=proposal_cfg)
f_s = self.student.extract_feat(img)
losses = dict()
proposal_cfg = self.student.train_cfg.get('rpn_proposal',
self.student.test_cfg.rpn)
s_rpn_losses, s_proposal_list = self.student.rpn_head.forward_train(
f_s,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=gt_bboxes_ignore,
proposal_cfg=proposal_cfg)
losses.update(s_rpn_losses)
s_roi_losses = self.student.roi_head.forward_train(f_s, img_metas, s_proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
losses.update(s_roi_losses)
s_roi_with_pt_losses = self.student.roi_head.forward_train(f_s, img_metas, t_proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
for key in s_roi_with_pt_losses:
losses["s_w_tp_{}".format(key)] = s_roi_with_pt_losses[key]
# losses['s_roi_w_pt_cls'] = s_roi_with_pt_losses['loss_cls']
# losses['s_roi_w_pt_bbox'] = s_roi_with_pt_losses['loss_bbox']
# losses['s_roi_w_pt_mask'] = s_roi_with_pt_losses['loss_mask']
t_roi_with_ps_losses = self.teacher.roi_head.forward_train(f_t, img_metas, s_proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
for key in t_roi_with_ps_losses:
losses["t_w_sp_{}".format(key)] = t_roi_with_ps_losses[key]
# losses['t_roi_w_ps_cls'] = t_roi_with_ps_losses['loss_cls']
# losses['t_roi_w_ps_bbox'] = t_roi_with_ps_losses['loss_bbox']
# losses['t_roi_w_ps_mask'] = t_roi_with_ps_losses['loss_mask']
return losses
def simple_test(self, img, img_metas, **kwargs):
return self.student.simple_test(img, img_metas, **kwargs)
def aug_test(self, imgs, img_metas, **kwargs):
return self.student.aug_test(img, img_metas, **kwargs)
def extract_feat(self, imgs):
"""Extract features from images."""
return self.student.extract_feat(imgs)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment