From 58adc05b8afa1f475e84a92f79dff1cefdb4d2b6 Mon Sep 17 00:00:00 2001
From: Guo-Hua Wang <wangguohua_key@163.com>
Date: Mon, 29 Nov 2021 20:28:28 +0800
Subject: [PATCH] fix optimizer construction for distiller training; add CBFPN neck and FGD dist-train script

---
 mmdet/apis/train.py             |  3 ++-
 mmdet/models/necks/cbnet_fpn.py | 26 ++++++++++++++++++++++++++
 tools/dist_fgd_train.sh         |  9 +++++++++
 3 files changed, 37 insertions(+), 1 deletion(-)
 create mode 100755 tools/dist_fgd_train.sh

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 4b1dee03..6465ab32 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -80,7 +80,8 @@ def train_detector(model,
     if distiller_cfg is None:
         optimizer = build_optimizer(model, cfg.optimizer)
     else:
-        optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer)
+        #optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer)
+        optimizer = build_optimizer(model.base_parameters(), cfg.optimizer)
 
     # use apex fp16 optimizer
     if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
diff --git a/mmdet/models/necks/cbnet_fpn.py b/mmdet/models/necks/cbnet_fpn.py
index e69de29b..d7b395fd 100644
--- a/mmdet/models/necks/cbnet_fpn.py
+++ b/mmdet/models/necks/cbnet_fpn.py
@@ -0,0 +1,26 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+
+from ..builder import NECKS
+from .fpn import FPN
+
+@NECKS.register_module()
+class CBFPN(FPN):
+    '''
+    FPN with weight sharing
+    which supports multiple outputs from cbnet
+    '''
+    def forward(self, inputs):
+        if not isinstance(inputs[0], (list, tuple)):
+            inputs = [inputs]
+
+        if self.training:
+            outs = []
+            for x in inputs:
+                out = super().forward(x)
+                outs.append(out)
+            return outs
+        else:
+            out = super().forward(inputs[-1])
+            return out
\ No newline at end of file
diff --git a/tools/dist_fgd_train.sh b/tools/dist_fgd_train.sh
new file mode 100755
index 00000000..97a9edcf
--- /dev/null
+++ b/tools/dist_fgd_train.sh
@@ -0,0 +1,9 @@
+#!/usr/bin/env bash
+
+CONFIG=$1
+GPUS=$2
+PORT=${PORT:-29500}
+
+PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+    $(dirname "$0")/fgd_train.py $CONFIG --launcher pytorch ${@:3}
-- 
GitLab