Commit 59d8aef4 authored by wanggh

opt gen checkpoint

parent f51b534d
_base_ = [
    '../../_base_/models/faster_rcnn_r50_fpn.py',
    '../../_base_/datasets/coco_detection.py',
    '../../_base_/schedules/schedule_1x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
weight=1
distiller = dict(
    type='BackboneDistiller',
    teacher_pretrained = '/data/wanggh/project/pytorch/Swin-Transformer-Object-Detection/work_dirs/faster_rcnn_r152_fpn_1x_coco/latest.pth',
    init_student = 'neck_head',
    train_head = False,
    distill_cfg = [ dict(feature_level = 0,
                         methods=[dict(type='MSELoss',
                                       name='loss_mb_0',
                                       student_channels = 256,
                                       teacher_channels = 256,
                                       weight = weight,
                                       )
                                  ]
                         ),
                    dict(feature_level = 1,
                         methods=[dict(type='MSELoss',
                                       name='loss_mb_1',
                                       student_channels = 512,
                                       teacher_channels = 512,
                                       weight = weight,
                                       )
                                  ]
                         ),
                    dict(feature_level = 2,
                         methods=[dict(type='MSELoss',
                                       name='loss_mb_2',
                                       student_channels = 1024,
                                       teacher_channels = 1024,
                                       weight = weight,
                                       )
                                  ]
                         ),
                    dict(feature_level = 3,
                         methods=[dict(type='MSELoss',
                                       name='loss_mb_3',
                                       student_channels = 2048,
                                       teacher_channels = 2048,
                                       weight = weight,
                                       )
                                  ]
                         ),
                   ]
    )
student_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
teacher_cfg = 'configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py'
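For orientation, distill_cfg pairs each of the four backbone stage outputs (256, 512, 1024 and 2048 channels for both ResNet-50 and ResNet-152) with an MSE mimic term scaled by `weight`. Below is a minimal sketch of what one such per-level loss might compute; the module name and interface are illustrative assumptions, not the repo's actual MSELoss implementation.

```python
import torch.nn as nn
import torch.nn.functional as F

class FeatureMSELoss(nn.Module):
    """Illustrative per-level mimic loss: pull a student feature map toward
    the detached teacher feature map from the same backbone stage."""

    def __init__(self, student_channels, teacher_channels, weight=1.0):
        super().__init__()
        # A 1x1 conv could align channels if they differed; for the
        # r152 -> r50 pairing above both sides are already 256/512/1024/2048.
        self.align = (None if student_channels == teacher_channels
                      else nn.Conv2d(student_channels, teacher_channels, 1))
        self.weight = weight

    def forward(self, feat_student, feat_teacher):
        if self.align is not None:
            feat_student = self.align(feat_student)
        return self.weight * F.mse_loss(feat_student, feat_teacher.detach())
```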
@@ -10,7 +10,7 @@ weight=1
distiller = dict(
    type='BackboneDistiller',
    teacher_pretrained = '/mnt/data3/wangguohua/model/mmdet/swin/mask_rcnn_swin_small_patch4_window7.pth',
    init_student = True,
    init_student = '',
    distill_cfg = [ dict(feature_level = 0,
                         methods=[dict(type='MSELoss',
                                       name='loss_mb_0',
......
_base_ = './faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    backbone=dict(
        depth=152,
        init_cfg=dict(type='Pretrained',
                      checkpoint='torchvision://resnet152')))
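A quick way to confirm the override is to load the derived config and inspect the merged values. This is a sketch, assuming an mmcv version where Config is exposed at the top level and the configs directory is reachable from the working directory.

```python
from mmcv import Config

# The derived config inherits everything from the r50 base and only
# overrides the backbone depth and the pretrained checkpoint.
cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r152_fpn_1x_coco.py')
print(cfg.model.backbone.depth)                    # 152, overriding the r50 base
print(cfg.model.backbone.init_cfg['checkpoint'])   # torchvision://resnet152
```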
@@ -2,4 +2,4 @@ _base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
]
\ No newline at end of file
@@ -12,6 +12,9 @@ model = dict(
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
#optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
optimizer_config = dict(grad_clip=None)
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4)
# learning policy
# actual epoch = 3 * 3 = 9
lr_config = dict(
......
@@ -38,9 +38,9 @@ class BackboneDistiller(BaseDetector):
        if init_student:
            assert init_student in ['neck', 'head', 'neck_head']
            def check_key(key, init_student):
                if key.startswith('neck.') and 'neck' in init_student:
                if 'neck' in key and 'neck' in init_student:
                    return True
                elif key.startswith('head.') and 'head' in init_student:
                elif 'head' in key and 'head' in init_student:
                    return True
                else:
                    return False
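The switch from prefix matching to substring matching matters because, for the two-stage detectors used here, state_dict keys are named 'neck.*', 'rpn_head.*' and 'roi_head.*', so key.startswith('head.') never matches a head weight. A small illustration with typical key names (the keys are examples, not read from a real checkpoint):

```python
keys = ['backbone.layer1.0.conv1.weight',
        'neck.lateral_convs.0.conv.weight',
        'rpn_head.rpn_cls.weight',
        'roi_head.bbox_head.fc_cls.weight']

print([k for k in keys if k.startswith('head.')])  # [] -- head weights are skipped
print([k for k in keys if 'head' in k])            # rpn_head.* and roi_head.* are kept
```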
......
@@ -11,5 +11,6 @@ while [[ $num > 0 ]]; do
done
sleep 2
# when $pid finished, run these
PORT=29504 CUDA_VISIBLE_DEVICES=0,1,2,3 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH3_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
#PORT=29502 CUDA_VISIBLE_DEVICES=4,5,6,7 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH2_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
\ No newline at end of file
#PORT=29504 CUDA_VISIBLE_DEVICES=0,1,2,3 tools/dist_train.sh configs/distillers/mimic_fpn/mfpn_trainH3_mask_rcnn_swinS_fpn_3x_distill_mask_rcnn_swinT_fpn_1x_coco.py 4
#PORT=29502 CUDA_VISIBLE_DEVICES=4,5,6,7 tools/dist_train.sh configs/pascal_voc/faster_rcnn_r101_fpn_1x_voc0712.py 4
PORT=29502 tools/dist_train.sh configs/distillers/mimic_backbone/mb_faster_rcnn_r152_fpn_1x_distill_faster_rcnn_r50_fpn_1x_coco.py 8
\ No newline at end of file
@@ -8,9 +8,9 @@ def parse_args():
    parser = argparse.ArgumentParser(
        description='generate model')
    parser.add_argument('--backbone', help='the backbone checkpoint file')
    parser.add_argument('--backbone-neck', help='the backbone-neck checkpoint file')
    parser.add_argument('--neck', help='the neck checkpoint file')
    parser.add_argument('--head', help='the head checkpoint file')
    parser.add_argument('--new-backbone', help='the trained checkpoint file')
    parser.add_argument('--distill', help='the distilled model checkpoint file')
    parser.add_argument('--out', help='output result file in pickle format')
    args = parser.parse_args()
    return args
@@ -23,87 +23,47 @@ def get_sd(filename, return_sd=True):
    else:
        return ck
def merge(backbone, head):
def merge(backbone, neck, head):
    target = dict()
    target['state_dict'] = dict()
    tsd = target['state_dict']
    bsd = backbone['state_dict']
    nsd = neck['state_dict']
    hsd = head['state_dict']
    for key in bsd.keys():
        if 'backbone' in key:
            tsd[key] = bsd[key]
    for key in hsd.keys():
        if 'backbone' not in key:
            tsd[key] = hsd[key]
    return target
def merge_bn_h(backbone, head):
    target = dict()
    target['state_dict'] = dict()
    tsd = target['state_dict']
    bsd = backbone['state_dict']
    hsd = head['state_dict']
    for key in bsd.keys():
        if 'backbone' in key or 'neck' in key:
            tsd[key] = bsd[key]
        else:
            assert 'head' in key
    for key in nsd.keys():
        if 'neck' in key:
            tsd[key] = nsd[key]
    for key in hsd.keys():
        if 'head' in key:
            tsd[key] = hsd[key]
    return target
def gen_backbone(backbone, new_backbone):
    target = backbone.copy()
    tsd = target['model']
    nbsd = new_backbone['state_dict']
    for key in tsd.keys():
        nk = 'backbone.{}'.format(key)
        if nk not in nbsd:
            print("{} not found".format(key))
            continue
        tsd[key] = nbsd[nk]
    return target
def gen_imagenet_h(backbone, head):
def gen_student(distill):
    target = dict()
    target['state_dict'] = dict()
    tsd = target['state_dict']
    bsd = backbone['model']
    hsd = head['state_dict']
    for key in hsd.keys():
        if 'backbone' not in key:
            tsd[key] = hsd[key]
        else:
            bkey = key[9:]
            if bkey not in bsd:
                print("{} not loaded".format(key))
                continue
            tsd[key] = bsd[bkey]
    distill_sd = distill['state_dict']
    for key in distill_sd.keys():
        if key.startswith('student.'):
            tsd[key[8:]] = distill_sd[key]
    return target
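gen_student assumes the distillation run saved a combined checkpoint in which student weights carry a 'student.' prefix; everything else (for example teacher weights) is dropped. A toy illustration of the prefix stripping, with made-up keys:

```python
distill_ckpt = {'state_dict': {
    'student.backbone.conv1.weight': 'kept',
    'student.neck.lateral_convs.0.conv.weight': 'kept',
    'teacher.backbone.conv1.weight': 'dropped',
}}
# Keep only student.* entries and remove the prefix, as gen_student does.
student_sd = {k[len('student.'):]: v
              for k, v in distill_ckpt['state_dict'].items()
              if k.startswith('student.')}
print(sorted(student_sd))
# ['backbone.conv1.weight', 'neck.lateral_convs.0.conv.weight']
```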
def main():
    args = parse_args()
    print("generate checkpoint")
    if args.backbone and args.head:
    if args.distill:
        distill = get_sd(args.distill, return_sd=False)
        target = gen_student(distill)
    else:
        backbone = get_sd(args.backbone, return_sd=False)
        neck = get_sd(args.neck, return_sd=False)
        head = get_sd(args.head, return_sd=False)
        target = merge(backbone, head)
    elif args.backbone_neck and args.head:
        backbone = get_sd(args.backbone_neck, return_sd=False)
        head = get_sd(args.head, return_sd=False)
        print("backbone+neck:{} head:{}".format(args.backbone_neck, args.head))
        target = merge_bn_h(backbone, head)
    elif args.new_backbone and args.head:
        backbone = get_sd(args.new_backbone, return_sd=False)
        head = get_sd(args.head, return_sd=False)
        target = gen_imagenet_h(backbone, head)
    elif args.new_backbone:
        backbone = get_sd(args.backbone, return_sd=False)
        nb = get_sd(args.new_backbone, return_sd=False)
        target = gen_backbone(backbone, nb)
        target = merge(backbone, neck, head)
    #os.makedirs(os.path.basename(args.out), exist_ok=True)
    torch.save(target, args.out)
    print("saved checkpoint in {}".format(args.out))
......