From 58adc05b8afa1f475e84a92f79dff1cefdb4d2b6 Mon Sep 17 00:00:00 2001 From: Guo-Hua Wang <wangguohua_key@163.com> Date: Mon, 29 Nov 2021 20:28:28 +0800 Subject: [PATCH] fix bug --- mmdet/apis/train.py | 3 ++- mmdet/models/necks/cbnet_fpn.py | 26 ++++++++++++++++++++++++++ tools/dist_fgd_train.sh | 9 +++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) create mode 100755 tools/dist_fgd_train.sh diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 4b1dee03..6465ab32 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -80,7 +80,8 @@ def train_detector(model, if distiller_cfg is None: optimizer = build_optimizer(model, cfg.optimizer) else: - optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer) + #optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer) + optimizer = build_optimizer(model.base_parameters(), cfg.optimizer) # use apex fp16 optimizer if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook": diff --git a/mmdet/models/necks/cbnet_fpn.py b/mmdet/models/necks/cbnet_fpn.py index e69de29b..d7b395fd 100644 --- a/mmdet/models/necks/cbnet_fpn.py +++ b/mmdet/models/necks/cbnet_fpn.py @@ -0,0 +1,26 @@ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import xavier_init + +from ..builder import NECKS +from .fpn import FPN + +@NECKS.register_module() +class CBFPN(FPN): + ''' + FPN with weight sharing + which support mutliple outputs from cbnet + ''' + def forward(self, inputs): + if not isinstance(inputs[0], (list, tuple)): + inputs = [inputs] + + if self.training: + outs = [] + for x in inputs: + out = super().forward(x) + outs.append(out) + return outs + else: + out = super().forward(inputs[-1]) + return out \ No newline at end of file diff --git a/tools/dist_fgd_train.sh b/tools/dist_fgd_train.sh new file mode 100755 index 00000000..97a9edcf --- /dev/null +++ b/tools/dist_fgd_train.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +PORT=${PORT:-29500} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/fgd_train.py $CONFIG --launcher pytorch ${@:3} -- GitLab