diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 7f2f1f95c0a8e7c9232f7aa490e8104f8e37c4f5..6465ab328e3bc559da33e1d55b1a58bca61c6b13 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -76,7 +76,12 @@ def train_detector(model, ] # build optimizer - optimizer = build_optimizer(model, cfg.optimizer) + distiller_cfg = cfg.get('distiller',None) + if distiller_cfg is None: + optimizer = build_optimizer(model, cfg.optimizer) + else: + #optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer) + optimizer = build_optimizer(model.base_parameters(), cfg.optimizer) # use apex fp16 optimizer if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":