diff --git a/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py index dd42cba7ca95c008218e966aca6becb2a2dabc8d..670773629503de74d095d15ea0e32cda53ab492e 100644 --- a/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py +++ b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py @@ -66,15 +66,16 @@ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), wei 'relative_position_bias_table': dict(decay_mult=0.), 'norm': dict(decay_mult=0.)})) lr_config = dict(step=[8, 11]) -runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) - +#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) +runner = dict(type='EpochBasedRunner', max_epochs=12) + # do not use mmdet version fp16 -fp16 = None -optimizer_config = dict( - type="DistOptimizerHook", - update_interval=1, - grad_clip=None, - coalesce=True, - bucket_size_mb=-1, - use_fp16=True, -) +# fp16 = None +# optimizer_config = dict( +# type="DistOptimizerHook", +# update_interval=1, +# grad_clip=None, +# coalesce=True, +# bucket_size_mb=-1, +# use_fp16=True, +# ) diff --git a/tools/gen_checkpoint.py b/tools/gen_checkpoint.py index 7510d9e13753604b1312bc499467a8e959aec56b..6e731b4f1a1b0704dc204c4369b5efa8ba1fd067 100644 --- a/tools/gen_checkpoint.py +++ b/tools/gen_checkpoint.py @@ -8,6 +8,7 @@ def parse_args(): parser = argparse.ArgumentParser( description='generate model') parser.add_argument('--backbone', help='the backbone checkpoint file') + parser.add_argument('--backbone-neck', help='the backbone-neck checkpoint file') parser.add_argument('--head', help='the head checkpoint file') parser.add_argument('--new-backbone', help='the trained checkpoint file') parser.add_argument('--out', help='output result file in pickle format') @@ -37,6 +38,23 @@ def merge(backbone, head): tsd[key] = hsd[key] return target +def merge_bn_h(backbone, head): + target = dict() + target['state_dict'] = dict() + tsd = target['state_dict'] + bsd = backbone['state_dict'] + hsd = head['state_dict'] + + for key in bsd.keys(): + if 'backbone' in key or 'neck' in key: + tsd[key] = bsd[key] + else: + assert 'head' in key + for key in hsd.keys(): + if 'head' in key: + tsd[key] = hsd[key] + return target + def gen_backbone(backbone, new_backbone): target = backbone.copy() tsd = target['model'] @@ -53,11 +71,17 @@ def main(): args = parse_args() print("generate checkpoint") - backbone = get_sd(args.backbone, return_sd=False) - if args.head: + if args.backbone and args.head: + backbone = get_sd(args.backbone, return_sd=False) head = get_sd(args.head, return_sd=False) target = merge(backbone, head) + elif args.backbone_neck and args.head: + backbone = get_sd(args.backbone_neck, return_sd=False) + head = get_sd(args.head, return_sd=False) + print("backbone+neck:{} head:{}".format(args.backbone_neck, args.head)) + target = merge_bn_h(backbone, head) elif args.new_backbone: + backbone = get_sd(args.backbone, return_sd=False) nb = get_sd(args.new_backbone, return_sd=False) target = gen_backbone(backbone, nb) #os.makedirs(os.path.basename(args.out), exist_ok=True)