Skip to content
Snippets Groups Projects
Commit ac9aa52b authored by wanggh's avatar wanggh
Browse files

Merge branch 'master' into gpu3

parents 2c0a9f5a 6cec7a1b
No related branches found
No related tags found
1 merge request!2Gpu3
Showing
with 730 additions and 21 deletions
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='BackboneDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth',
init_student = 'neck_head',
train_head = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 96,
teacher_channels = 96,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 192,
teacher_channels = 192,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 384,
teacher_channels = 384,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 768,
teacher_channels = 768,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/mask_rcnn_swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='FPNDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/cascade_mask_rcnn_swin_base_patch4_window7.pth',
init_student = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='FPNDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/cascade_mask_rcnn_swin_small_patch4_window7.pth',
init_student = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='FPNDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/cascade_mask_rcnn_swin_tiny_patch4_window7.pth',
init_student = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='FPNDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth',
init_student = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/mask_rcnn_swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='FPNDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth',
init_student = 'head',
train_head = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/mask_rcnn_swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='FPNDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth',
init_student = '',
train_head = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/mask_rcnn_swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
_base_ = [
'../../_base_/models/mask_rcnn_swin_fpn.py',
'../../_base_/datasets/coco_instance.py',
'../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
type='FPNDistiller',
teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth',
init_student = 'neck_head',
train_head = True,
distill_cfg = [ dict(feature_level = 0,
methods=[dict(type='MSELoss',
name='loss_mb_0',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 1,
methods=[dict(type='MSELoss',
name='loss_mb_1',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 2,
methods=[dict(type='MSELoss',
name='loss_mb_2',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
dict(feature_level = 3,
methods=[dict(type='MSELoss',
name='loss_mb_3',
student_channels = 256,
teacher_channels = 256,
weight = weight,
)
]
),
]
)
student_cfg = 'configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py'
teacher_cfg = 'configs/swin/mask_rcnn_swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
#data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
\ No newline at end of file
......@@ -126,15 +126,16 @@ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), wei
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
type="DistOptimizerHook",
update_interval=1,
grad_clip=None,
coalesce=True,
bucket_size_mb=-1,
use_fp16=True,
)
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
......@@ -20,7 +20,8 @@ class BackboneDistiller(BaseDetector):
student_cfg,
distill_cfg=None,
teacher_pretrained=None,
init_student=False):
init_student=None,
train_head=False):
super(BackboneDistiller, self).__init__()
......@@ -35,12 +36,19 @@ class BackboneDistiller(BaseDetector):
train_cfg=student_cfg.get('train_cfg'),
test_cfg=student_cfg.get('test_cfg'))
if init_student:
assert init_student in ['neck', 'head', 'neck_head']
def check_key(key, init_student):
if key.startswith('neck.') and 'neck' in init_student:
return True
elif key.startswith('head.') and 'head' in init_student:
return True
else:
return False
t_checkpoint = _load_checkpoint(teacher_pretrained)
all_name = []
for name, v in t_checkpoint["state_dict"].items():
if name.startswith("backbone."):
continue
else:
if check_key(name, init_student):
all_name.append((name, v))
state_dict = OrderedDict(all_name)
......@@ -52,6 +60,7 @@ class BackboneDistiller(BaseDetector):
for item_loss in item_loc.methods:
loss_name = item_loss.name
self.distill_losses[loss_name] = build_distill_loss(item_loss)
self.train_head = train_head
def base_parameters(self):
return nn.ModuleList([self.student, self.distill_losses])
......@@ -92,7 +101,15 @@ class BackboneDistiller(BaseDetector):
checkpoint = load_checkpoint(self.teacher, path, map_location='cpu')
def forward_train(self, img, img_metas, **kwargs):
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
......@@ -117,6 +134,25 @@ class BackboneDistiller(BaseDetector):
for item_loss in item_loc.methods:
loss_name = item_loss.name
losses[loss_name] = self.distill_losses[loss_name](f_s[feature_level], f_t[feature_level])
if self.train_head:
x = self.student.neck(f_s)
proposal_cfg = self.student.train_cfg.get('rpn_proposal',
self.student.test_cfg.rpn)
rpn_losses, proposal_list = self.student.rpn_head.forward_train(
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=gt_bboxes_ignore,
proposal_cfg=proposal_cfg)
losses.update(rpn_losses)
roi_losses = self.student.roi_head.forward_train(x, img_metas, proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
losses.update(roi_losses)
return losses
def simple_test(self, img, img_metas, **kwargs):
......
......@@ -20,7 +20,8 @@ class FPNDistiller(BaseDetector):
student_cfg,
distill_cfg=None,
teacher_pretrained=None,
init_student=False):
init_student=None,
train_head=False):
super(FPNDistiller, self).__init__()
......@@ -35,12 +36,19 @@ class FPNDistiller(BaseDetector):
train_cfg=student_cfg.get('train_cfg'),
test_cfg=student_cfg.get('test_cfg'))
if init_student:
assert init_student in ['neck', 'head', 'neck_head']
def check_key(key, init_student):
if key.startswith('neck.') and 'neck' in init_student:
return True
elif key.startswith('head.') and 'head' in init_student:
return True
else:
return False
t_checkpoint = _load_checkpoint(teacher_pretrained)
all_name = []
for name, v in t_checkpoint["state_dict"].items():
if name.startswith("backbone."):
continue
else:
if check_key(name, init_student):
all_name.append((name, v))
state_dict = OrderedDict(all_name)
......@@ -52,6 +60,7 @@ class FPNDistiller(BaseDetector):
for item_loss in item_loc.methods:
loss_name = item_loss.name
self.distill_losses[loss_name] = build_distill_loss(item_loss)
self.train_head = train_head
def base_parameters(self):
return nn.ModuleList([self.student, self.distill_losses])
......@@ -92,7 +101,15 @@ class FPNDistiller(BaseDetector):
checkpoint = load_checkpoint(self.teacher, path, map_location='cpu')
def forward_train(self, img, img_metas, **kwargs):
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
......@@ -119,6 +136,26 @@ class FPNDistiller(BaseDetector):
for item_loss in item_loc.methods:
loss_name = item_loss.name
losses[loss_name] = self.distill_losses[loss_name](f_s[feature_level], f_t[feature_level])
if self.train_head:
x = f_s
proposal_cfg = self.student.train_cfg.get('rpn_proposal',
self.student.test_cfg.rpn)
rpn_losses, proposal_list = self.student.rpn_head.forward_train(
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=gt_bboxes_ignore,
proposal_cfg=proposal_cfg)
losses.update(rpn_losses)
roi_losses = self.student.roi_head.forward_train(x, img_metas, proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
losses.update(roi_losses)
return losses
def simple_test(self, img, img_metas, **kwargs):
......
#!/bin/bash
# run by: bash scripts/submit_work.sh $pid
# set pid first
pid=$1
echo wait for $pid
num=`ps -aux | grep $pid | grep python | wc -l`
while [[ $num > 0 ]]; do
#echo wait for $pid
sleep 1
num=`ps -aux | grep $pid | grep python | wc -l`
done
sleep 2
# when $pid finished, run these
PORT=29504 CUDA_VISIBLE_DEVICES=0,1,2,3 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH3_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
#PORT=29502 CUDA_VISIBLE_DEVICES=4,5,6,7 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH2_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment