From a002b205393f44146da5cb7b8d686d1b8cb40709 Mon Sep 17 00:00:00 2001
From: wanggh <wangguohua_key@163.com>
Date: Fri, 19 Nov 2021 16:33:41 +0800
Subject: [PATCH] before merge

---
 ...4_window7_mstrain_480-800_adamw_1x_coco.py | 21 ++++++++++---------
 tools/gen_checkpoint.py                       | 10 ++++-----
 2 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py
index dd42cba7..a74ad36b 100644
--- a/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py
+++ b/configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py
@@ -66,15 +66,16 @@ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), wei
                                                  'relative_position_bias_table': dict(decay_mult=0.),
                                                  'norm': dict(decay_mult=0.)}))
 lr_config = dict(step=[8, 11])
-runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+runner = dict(type='EpochBasedRunner', max_epochs=12)
 
 # do not use mmdet version fp16
-fp16 = None
-optimizer_config = dict(
-    type="DistOptimizerHook",
-    update_interval=1,
-    grad_clip=None,
-    coalesce=True,
-    bucket_size_mb=-1,
-    use_fp16=True,
-)
+# fp16 = None
+# optimizer_config = dict(
+#     type="DistOptimizerHook",
+#     update_interval=1,
+#     grad_clip=None,
+#     coalesce=True,
+#     bucket_size_mb=-1,
+#     use_fp16=True,
+# )
diff --git a/tools/gen_checkpoint.py b/tools/gen_checkpoint.py
index 6eaa3ac9..f6108ee8 100644
--- a/tools/gen_checkpoint.py
+++ b/tools/gen_checkpoint.py
@@ -23,8 +23,8 @@ def get_sd(filename, return_sd=True):
 
 def merge(target, backbone, head):
     tsd = target['state_dict']
-    bsd = target['state_dict']
-    hsd = target['state_dict']
+    bsd = backbone['state_dict']
+    hsd = head['state_dict']
     for key in tsd.keys():
         if 'backbone' in key:
             assert key in bsd
@@ -41,11 +41,11 @@ def main():
     backbone = get_sd(args.backbone, return_sd=False)
     head = get_sd(args.head, return_sd=False)
 
-    target = backbone.copy()
-    #target = head.copy()
+    #target = backbone.copy()
+    target = head.copy()
 
     target = merge(target, backbone, head)
-    os.makedirs(os.path.basename(args.out), exist_ok=True)
+    #os.makedirs(os.path.basename(args.out), exist_ok=True)
     torch.save(target, args.out)
     print("saved checkpoint in {}".format(args.out))
     
-- 
GitLab