Skip to content
Snippets Groups Projects
Commit 03e63e5e authored by Guo-Hua Wang's avatar Guo-Hua Wang
Browse files

Merge branch 'master' of git.nju.edu.cn:wanggh/Swin-Transformer-Object-Detection

parents 12152d61 e78febbb
No related branches found
No related tags found
No related merge requests found
......@@ -68,7 +68,7 @@ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), wei
lr_config = dict(step=[8, 11])
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
# fp16 = None
# optimizer_config = dict(
......
......@@ -29,7 +29,6 @@ def merge(backbone, head):
tsd = target['state_dict']
bsd = backbone['state_dict']
hsd = head['state_dict']
for key in bsd.keys():
if 'backbone' in key:
tsd[key] = bsd[key]
......@@ -67,6 +66,23 @@ def gen_backbone(backbone, new_backbone):
tsd[key] = nbsd[nk]
return target
def gen_imagenet_h(backbone, head):
target = dict()
target['state_dict'] = dict()
tsd = target['state_dict']
bsd = backbone['model']
hsd = head['state_dict']
for key in hsd.keys():
if 'backbone' not in key:
tsd[key] = hsd[key]
else:
bkey = key[9:]
if bkey not in bsd:
print("{} not load".format(key))
continue
tsd[key] = bsd[bkey]
return target
def main():
args = parse_args()
print("generate checkpoint")
......@@ -80,6 +96,10 @@ def main():
head = get_sd(args.head, return_sd=False)
print("backbone+neck:{} head:{}".format(args.backbone_neck, args.head))
target = merge_bn_h(backbone, head)
elif args.new_backbone and args.head:
backbone = get_sd(args.new_backbone, return_sd=False)
head = get_sd(args.head, return_sd=False)
target = gen_imagenet_h(backbone, head)
elif args.new_backbone:
backbone = get_sd(args.backbone, return_sd=False)
nb = get_sd(args.new_backbone, return_sd=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment