From a002b205393f44146da5cb7b8d686d1b8cb40709 Mon Sep 17 00:00:00 2001 From: wanggh <wangguohua_key@163.com> Date: Fri, 19 Nov 2021 16:33:41 +0800 Subject: [PATCH] before merge --- ...4_window7_mstrain_480-800_adamw_1x_coco.py | 21 ++++++++++--------- tools/gen_checkpoint.py | 10 ++++----- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py index dd42cba7..a74ad36b 100644 --- a/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py +++ b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py @@ -66,15 +66,16 @@ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), wei 'relative_position_bias_table': dict(decay_mult=0.), 'norm': dict(decay_mult=0.)})) lr_config = dict(step=[8, 11]) -runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) +#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) +runner = dict(type='EpochBasedRunner', max_epochs=12) # do not use mmdet version fp16 -fp16 = None -optimizer_config = dict( - type="DistOptimizerHook", - update_interval=1, - grad_clip=None, - coalesce=True, - bucket_size_mb=-1, - use_fp16=True, -) +# fp16 = None +# optimizer_config = dict( +# type="DistOptimizerHook", +# update_interval=1, +# grad_clip=None, +# coalesce=True, +# bucket_size_mb=-1, +# use_fp16=True, +# ) diff --git a/tools/gen_checkpoint.py b/tools/gen_checkpoint.py index 6eaa3ac9..f6108ee8 100644 --- a/tools/gen_checkpoint.py +++ b/tools/gen_checkpoint.py @@ -23,8 +23,8 @@ def get_sd(filename, return_sd=True): def merge(target, backbone, head): tsd = target['state_dict'] - bsd = target['state_dict'] - hsd = target['state_dict'] + bsd = backbone['state_dict'] + hsd = head['state_dict'] for key in tsd.keys(): if 'backbone' in key: assert key in bsd @@ -41,11 +41,11 @@ def main(): backbone = get_sd(args.backbone, return_sd=False) head = get_sd(args.head, return_sd=False) - target = backbone.copy() - #target = head.copy() + #target = backbone.copy() + target = head.copy() target = merge(target, backbone, head) - os.makedirs(os.path.basename(args.out), exist_ok=True) + #os.makedirs(os.path.basename(args.out), exist_ok=True) torch.save(target, args.out) print("saved checkpoint in {}".format(args.out)) -- GitLab