diff --git a/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py b/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..2676c37634957f9d82c75925cd05ea5f2486e55b --- /dev/null +++ b/configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py @@ -0,0 +1,55 @@ +_base_ = [ + '../../_base_/models/faster_rcnn_r50_fpn.py', + '../../_base_/datasets/coco_detection.py', + '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py' +] + +# model settings +find_unused_parameters=True +weight=1 +distiller = dict( + type='BackboneDistiller', + teacher_pretrained = '/data/wanggh/project/pytorch/Swin-Transformer-Object-Detection/work_dirs/faster_rcnn_r152_fpn_1x_coco/latest.pth', + init_student = 'neck_head', + train_head = False, + distill_cfg = [ dict(feature_level = 0, + methods=[dict(type='MSELoss', + name='loss_mb_0', + student_channels = 256, + teacher_channels = 256, + weight = weight, + ) + ] + ), + dict(feature_level = 1, + methods=[dict(type='MSELoss', + name='loss_mb_1', + student_channels = 512, + teacher_channels = 512, + weight = weight, + ) + ] + ), + dict(feature_level = 2, + methods=[dict(type='MSELoss', + name='loss_mb_2', + student_channels = 1024, + teacher_channels = 1024, + weight = weight, + ) + ] + ), + dict(feature_level = 3, + methods=[dict(type='MSELoss', + name='loss_mb_3', + student_channels = 2048, + teacher_channels = 2048, + weight = weight, + ) + ] + ), + ] + ) + +student_cfg = 'configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py' +teacher_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' diff --git a/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py b/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py index 104c815351dd3e37fd634c65c6b6d85f28afa0f2..b53c2b2389af73b80ba241223508b8be8827f5bb 100644 --- a/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py +++ b/configs/distillers/mimic_backbone/mb_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py @@ -10,7 +10,7 @@ weight=1 distiller = dict( type='BackboneDistiller', teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth', - init_student = True, + init_student = '', distill_cfg = [ dict(feature_level = 0, methods=[dict(type='MSELoss', name='loss_mb_0', diff --git a/configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py b/configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..58269670a8e99e6c79c8ab5e3d13448251ba8890 --- /dev/null +++ b/configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py @@ -0,0 +1,6 @@ +_base_ = './faster_rcnn_r50_fpn_1x_coco.py' +model = dict( + backbone=dict( + depth=152, + init_cfg=dict(type='Pretrained', + checkpoint='torchvision://resnet152'))) diff --git a/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py b/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py index 009bd93d06b3284c7b31f33f82d636f774e86b74..6517e7a1ee1fadb7f3569b47038a18e75fa8a327 100644 --- a/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py +++ b/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py @@ -2,4 +2,4 @@ _base_ = [ '../_base_/models/faster_rcnn_r50_fpn.py', '../_base_/datasets/coco_detection.py', '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' -] +] \ No newline at end of file diff --git a/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py b/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py index 86278d8a074ede03726f78c9112a97bb17bb7f57..062f3c1a838bfa7340aa6090cfc16e5bfdc85f9d 100644 --- a/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py +++ b/configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py @@ -12,6 +12,9 @@ model = dict( optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) #optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) optimizer_config = dict(grad_clip=None) +data = dict( + samples_per_gpu=4, + workers_per_gpu=4) # learning policy # actual epoch = 3 * 3 = 9 lr_config = dict( diff --git a/mmdet/distillation/distillers/backbone_distiller.py b/mmdet/distillation/distillers/backbone_distiller.py index c14b1c4eea75226ed24ed91897372fc24cfbf1ed..f6ff660af56f04755235702af1fe2c8569b66f41 100644 --- a/mmdet/distillation/distillers/backbone_distiller.py +++ b/mmdet/distillation/distillers/backbone_distiller.py @@ -38,9 +38,9 @@ class BackboneDistiller(BaseDetector): if init_student: assert init_student in ['neck', 'head', 'neck_head'] def check_key(key, init_student): - if key.startswith('neck.') and 'neck' in init_student: + if 'neck' in key and 'neck' in init_student: return True - elif key.startswith('head.') and 'head' in init_student: + elif 'head' in key and 'head' in init_student: return True else: return False diff --git a/submit_work.sh b/submit_work.sh index 9e35e75f75813a0d128b28013f1dd267cb09db28..4bfde547f981c45429ac1f90d3ae87d185204a71 100644 --- a/submit_work.sh +++ b/submit_work.sh @@ -11,5 +11,6 @@ while [[ $num > 0 ]]; do done sleep 2 # when $pid finished, run these -PORT=29504 CUDA_VISIBLE_DEVICES=0,1,2,3 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH3_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4 -#PORT=29502 CUDA_VISIBLE_DEVICES=4,5,6,7 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH2_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4 \ No newline at end of file +#PORT=29504 CUDA_VISIBLE_DEVICES=0,1,2,3 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH3_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4 +#PORT=29502 CUDA_VISIBLE_DEVICES=4,5,6,7 tools/dist_train.sh configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py 4 +PORT=29502 tools/dist_train.sh configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py 8 \ No newline at end of file diff --git a/tools/gen_checkpoint.py b/tools/gen_checkpoint.py index 59f138dc2db645d5ce0a66b92a804872843cc315..b3ba6d6441f9f28c6a16510471a12fd7d201ed09 100644 --- a/tools/gen_checkpoint.py +++ b/tools/gen_checkpoint.py @@ -8,9 +8,9 @@ def parse_args(): parser = argparse.ArgumentParser( description='generate model') parser.add_argument('--backbone', help='the backbone checkpoint file') - parser.add_argument('--backbone-neck', help='the backbone-neck checkpoint file') + parser.add_argument('--neck', help='the neck checkpoint file') parser.add_argument('--head', help='the head checkpoint file') - parser.add_argument('--new-backbone', help='the trained checkpoint file') + parser.add_argument('--distill', help='the distilled model checkpoint file') parser.add_argument('--out', help='output result file in pickle format') args = parser.parse_args() return args @@ -23,87 +23,47 @@ def get_sd(filename, return_sd=True): else: return ck -def merge(backbone, head): +def merge(backbone, neck, head): target = dict() target['state_dict'] = dict() tsd = target['state_dict'] bsd = backbone['state_dict'] + nsd = backbone['state_dict'] hsd = head['state_dict'] for key in bsd.keys(): if 'backbone' in key: tsd[key] = bsd[key] - for key in hsd.keys(): - if 'backbone' not in key: - tsd[key] = hsd[key] - return target - -def merge_bn_h(backbone, head): - target = dict() - target['state_dict'] = dict() - tsd = target['state_dict'] - bsd = backbone['state_dict'] - hsd = head['state_dict'] - - for key in bsd.keys(): - if 'backbone' in key or 'neck' in key: - tsd[key] = bsd[key] - else: - assert 'head' in key + for key in nsd.keys(): + if 'neck' in key: + tsd[key] = nsd[key] for key in hsd.keys(): if 'head' in key: tsd[key] = hsd[key] return target -def gen_backbone(backbone, new_backbone): - target = backbone.copy() - tsd = target['model'] - nbsd = new_backbone['state_dict'] - for key in tsd.keys(): - nk = 'backbone.{}'.format(key) - if nk not in nbsd: - print("{} not find".format(key)) - continue - tsd[key] = nbsd[nk] - return target - -def gen_imagenet_h(backbone, head): +def gen_student(distill): target = dict() target['state_dict'] = dict() tsd = target['state_dict'] - bsd = backbone['model'] - hsd = head['state_dict'] - for key in hsd.keys(): - if 'backbone' not in key: - tsd[key] = hsd[key] - else: - bkey = key[9:] - if bkey not in bsd: - print("{} not load".format(key)) - continue - tsd[key] = bsd[bkey] + distill_sd = distill['state_dict'] + for key in distill_sd.keys(): + if key.startswith('student.'): + tsd[key[8:]] = distill_sd[key] return target + def main(): args = parse_args() print("generate checkpoint") - if args.backbone and args.head: + if args.distill: + distill = get_sd(args.distill, return_sd=False) + target = gen_student(distill) + else: backbone = get_sd(args.backbone, return_sd=False) + neck = get_sd(args.neck, return_sd=False) head = get_sd(args.head, return_sd=False) - target = merge(backbone, head) - elif args.backbone_neck and args.head: - backbone = get_sd(args.backbone_neck, return_sd=False) - head = get_sd(args.head, return_sd=False) - print("backbone+neck:{} head:{}".format(args.backbone_neck, args.head)) - target = merge_bn_h(backbone, head) - elif args.new_backbone and args.head: - backbone = get_sd(args.new_backbone, return_sd=False) - head = get_sd(args.head, return_sd=False) - target = gen_imagenet_h(backbone, head) - elif args.new_backbone: - backbone = get_sd(args.backbone, return_sd=False) - nb = get_sd(args.new_backbone, return_sd=False) - target = gen_backbone(backbone, nb) + target = merge(backbone, neck, head) #os.makedirs(os.path.basename(args.out), exist_ok=True) torch.save(target, args.out) print("saved checkpoint in {}".format(args.out))