Skip to content
Snippets Groups Projects
Commit e0480293 authored by Guo-Hua Wang's avatar Guo-Hua Wang
Browse files

opt code

parent 9e4746ea
No related branches found
No related tags found
No related merge requests found
......@@ -66,15 +66,16 @@ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), wei
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[8, 11])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
runner = dict(type='EpochBasedRunner', max_epochs=12)
# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
type="DistOptimizerHook",
update_interval=1,
grad_clip=None,
coalesce=True,
bucket_size_mb=-1,
use_fp16=True,
)
# fp16 = None
# optimizer_config = dict(
# type="DistOptimizerHook",
# update_interval=1,
# grad_clip=None,
# coalesce=True,
# bucket_size_mb=-1,
# use_fp16=True,
# )
......@@ -8,6 +8,7 @@ def parse_args():
parser = argparse.ArgumentParser(
description='generate model')
parser.add_argument('--backbone', help='the backbone checkpoint file')
parser.add_argument('--backbone-neck', help='the backbone-neck checkpoint file')
parser.add_argument('--head', help='the head checkpoint file')
parser.add_argument('--new-backbone', help='the trained checkpoint file')
parser.add_argument('--out', help='output result file in pickle format')
......@@ -37,6 +38,23 @@ def merge(backbone, head):
tsd[key] = hsd[key]
return target
def merge_bn_h(backbone, head):
target = dict()
target['state_dict'] = dict()
tsd = target['state_dict']
bsd = backbone['state_dict']
hsd = head['state_dict']
for key in bsd.keys():
if 'backbone' in key or 'neck' in key:
tsd[key] = bsd[key]
else:
assert 'head' in key
for key in hsd.keys():
if 'head' in key:
tsd[key] = hsd[key]
return target
def gen_backbone(backbone, new_backbone):
target = backbone.copy()
tsd = target['model']
......@@ -53,11 +71,17 @@ def main():
args = parse_args()
print("generate checkpoint")
backbone = get_sd(args.backbone, return_sd=False)
if args.head:
if args.backbone and args.head:
backbone = get_sd(args.backbone, return_sd=False)
head = get_sd(args.head, return_sd=False)
target = merge(backbone, head)
elif args.backbone_neck and args.head:
backbone = get_sd(args.backbone_neck, return_sd=False)
head = get_sd(args.head, return_sd=False)
print("backbone+neck:{} head:{}".format(args.backbone_neck, args.head))
target = merge_bn_h(backbone, head)
elif args.new_backbone:
backbone = get_sd(args.backbone, return_sd=False)
nb = get_sd(args.new_backbone, return_sd=False)
target = gen_backbone(backbone, nb)
#os.makedirs(os.path.basename(args.out), exist_ok=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment