diff --git a/configs/distillers/mimic_backbone/mb_tH_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py b/configs/distillers/mimic_backbone/mb_tH_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda44e5a1df554f52971cb09c829eb0606234e04
--- /dev/null
+++ b/configs/distillers/mimic_backbone/mb_tH_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py
@@ -0,0 +1,55 @@
+_base_ = [
+    '../../_base_/models/faster_rcnn_r50_fpn.py',
+    '../../_base_/datasets/coco_detection.py',
+    '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
+find_unused_parameters=True
+weight=0
+distiller = dict(
+    type='BackboneDistiller',
+    teacher_pretrained = '/data/wanggh/project/pytorch/Swin-Transformer-Object-Detection/work_dirs/faster_rcnn_r152_fpn_1x_coco/latest.pth',
+    init_student = 'neck_head',
+    train_head = True,
+    distill_cfg = [ dict(feature_level = 0,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_0',
+                                       student_channels = 256,
+                                       teacher_channels = 256,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 1,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_1',
+                                       student_channels = 512,
+                                       teacher_channels = 512,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 2,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_2',
+                                       student_channels = 1024,
+                                       teacher_channels = 1024,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                    dict(feature_level = 3,
+                         methods=[dict(type='MSELoss',
+                                       name='loss_mb_3',
+                                       student_channels = 2048,
+                                       teacher_channels = 2048,
+                                       weight = weight,
+                                       )
+                                ]
+                        ),
+                   ]
+    )
+
+student_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
+teacher_cfg = 'configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py'
diff --git a/mmdet/distillation/distillers/backbone_distiller.py b/mmdet/distillation/distillers/backbone_distiller.py
index f6ff660af56f04755235702af1fe2c8569b66f41..673d00aac0ec6f426bf19f4eb9ebc72ce24890f8 100644
--- a/mmdet/distillation/distillers/backbone_distiller.py
+++ b/mmdet/distillation/distillers/backbone_distiller.py
@@ -32,9 +32,10 @@ class BackboneDistiller(BaseDetector):
 
         self.teacher.eval()
 
-        self.student= build_detector(student_cfg.model,
+        self.student = build_detector(student_cfg.model,
                                       train_cfg=student_cfg.get('train_cfg'),
                                       test_cfg=student_cfg.get('test_cfg'))
+        self.student.init_weights()
         if init_student:
             assert init_student in ['neck', 'head', 'neck_head']
             def check_key(key, init_student):
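For reference, each `distill_cfg` entry above binds one backbone feature level (channel widths 256/512/1024/2048 for the ResNet stages) to an `MSELoss` term scaled by `weight`, which is 0 in this config, so the mimic terms contribute nothing to the gradient. The repository's `MSELoss` implementation is not part of this diff; the block below is only a minimal sketch of what such a per-level feature-mimic loss could look like, assuming the student and teacher maps share spatial size and using a hypothetical 1x1 `align` conv when `student_channels != teacher_channels`.

```python
# Hypothetical sketch of a per-level feature-mimic loss matching the
# distill_cfg entries above; not the repository's actual MSELoss.
import torch.nn as nn
import torch.nn.functional as F


class MimicMSELoss(nn.Module):
    """MSE between a student feature map and the frozen teacher's map."""

    def __init__(self, student_channels, teacher_channels, weight=1.0):
        super().__init__()
        self.weight = weight
        # Adapt the student's channel count only when it differs from the teacher's.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)
        else:
            self.align = None

    def forward(self, feat_student, feat_teacher):
        if self.align is not None:
            feat_student = self.align(feat_student)
        # Teacher features are treated as fixed regression targets.
        return self.weight * F.mse_loss(feat_student, feat_teacher.detach())
```

With matching channel counts at every level, as configured here, the optional 1x1 adapter is never created and each term reduces to a plain weighted MSE between the two feature maps.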