Skip to content
Snippets Groups Projects
Commit 7c7132fa authored by Guo-Hua Wang's avatar Guo-Hua Wang
Browse files

add gen ck

parent 6a979e21
No related branches found
No related tags found
No related merge requests found
Pipeline #9814 failed
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
data_root = '/opt/Dataset/COCO2017/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
......
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
data_root = '/opt/Dataset/COCO2017/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
......
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
data_root = '/opt/Dataset/COCO2017/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
......
import argparse
import os
import warnings
import torch
def parse_args():
parser = argparse.ArgumentParser(
description='generate model')
parser.add_argument('--backbone', help='the backbone checkpoint file')
parser.add_argument('--head', help='the head checkpoint file')
parser.add_argument('--out', help='output result file in pickle format')
args = parser.parse_args()
return args
def get_sd(filename, return_sd=True):
print("loading {}".format(filename))
ck = torch.load(filename)
if return_sd:
return ck['state_dict']
else:
return ck
def merge(target, backbone, head):
tsd = target['state_dict']
bsd = target['state_dict']
hsd = target['state_dict']
for key in tsd.keys():
if 'backbone' in key:
assert key in bsd
tsd[key] = bsd[key]
else:
assert key in hsd
tsd[key] = hsd[key]
return target
def main():
args = parse_args()
print("generate checkpoint")
backbone = get_sd(args.backbone, return_sd=False)
head = get_sd(args.head, return_sd=False)
target = backbone.copy()
#target = head.copy()
target = merge(target, backbone, head)
os.makedirs(os.path.basename(args.out), exist_ok=True)
torch.save(target, args.out)
print("saved checkpoint in {}".format(args.out))
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment