Skip to content
Snippets Groups Projects
Commit 420c88ad authored by Guo-Hua Wang's avatar Guo-Hua Wang
Browse files

fix bug

parent 177f8514
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,7 @@ from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
from mmdet.distillation import build_distiller
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
......@@ -155,10 +155,20 @@ def main():
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
distiller_cfg = cfg.get('distiller',None)
if distiller_cfg is None:
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
else:
teacher_cfg = Config.fromfile(cfg.teacher_cfg)
student_cfg = Config.fromfile(cfg.student_cfg)
model = build_distiller(cfg.distiller,teacher_cfg,student_cfg,
train_cfg=student_cfg.get('train_cfg'),
test_cfg=student_cfg.get('test_cfg'))
model.init_weights()
datasets = [build_dataset(cfg.data.train)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment