Skip to content
Snippets Groups Projects
Commit e78febbb authored by wanggh's avatar wanggh
Browse files

Merge branch 'master' of git.nju.edu.cn:wanggh/Swin-Transformer-Object-Detection

parents 3653775f e0480293
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ def parse_args():
parser = argparse.ArgumentParser(
description='generate model')
parser.add_argument('--backbone', help='the backbone checkpoint file')
parser.add_argument('--backbone-neck', help='the backbone-neck checkpoint file')
parser.add_argument('--head', help='the head checkpoint file')
parser.add_argument('--new-backbone', help='the trained checkpoint file')
parser.add_argument('--out', help='output result file in pickle format')
......@@ -36,6 +37,23 @@ def merge(backbone, head):
tsd[key] = hsd[key]
return target
def merge_bn_h(backbone, head):
target = dict()
target['state_dict'] = dict()
tsd = target['state_dict']
bsd = backbone['state_dict']
hsd = head['state_dict']
for key in bsd.keys():
if 'backbone' in key or 'neck' in key:
tsd[key] = bsd[key]
else:
assert 'head' in key
for key in hsd.keys():
if 'head' in key:
tsd[key] = hsd[key]
return target
def gen_backbone(backbone, new_backbone):
target = backbone.copy()
tsd = target['model']
......@@ -48,15 +66,42 @@ def gen_backbone(backbone, new_backbone):
tsd[key] = nbsd[nk]
return target
def gen_imagenet_h(backbone, head):
target = dict()
target['state_dict'] = dict()
tsd = target['state_dict']
bsd = backbone['model']
hsd = head['state_dict']
for key in hsd.keys():
if 'backbone' not in key:
tsd[key] = hsd[key]
else:
bkey = key[9:]
if bkey not in bsd:
print("{} not load".format(key))
continue
tsd[key] = bsd[bkey]
return target
def main():
args = parse_args()
print("generate checkpoint")
backbone = get_sd(args.backbone, return_sd=False)
if args.head:
if args.backbone and args.head:
backbone = get_sd(args.backbone, return_sd=False)
head = get_sd(args.head, return_sd=False)
target = merge(backbone, head)
elif args.backbone_neck and args.head:
backbone = get_sd(args.backbone_neck, return_sd=False)
head = get_sd(args.head, return_sd=False)
print("backbone+neck:{} head:{}".format(args.backbone_neck, args.head))
target = merge_bn_h(backbone, head)
elif args.new_backbone and args.head:
backbone = get_sd(args.new_backbone, return_sd=False)
head = get_sd(args.head, return_sd=False)
target = gen_imagenet_h(backbone, head)
elif args.new_backbone:
backbone = get_sd(args.backbone, return_sd=False)
nb = get_sd(args.new_backbone, return_sd=False)
target = gen_backbone(backbone, nb)
#os.makedirs(os.path.basename(args.out), exist_ok=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment