From efd02f86182c5f16531c231ed7e05e0952c78fb3 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 23 Nov 2018 20:49:53 +0800
Subject: [PATCH] refactoring for retinanet

---
 MODEL_ZOO.md                                  |  8 ++---
 configs/retinanet_r50_fpn_1x.py               |  8 ++---
 mmdet/core/loss/losses.py                     |  2 +-
 .../models/single_stage_heads/retina_head.py  | 36 ++++++++++++-------
 4 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md
index ed526d5e..e581f17f 100644
--- a/MODEL_ZOO.md
+++ b/MODEL_ZOO.md
@@ -71,13 +71,13 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m
 | R-50-FPN | pytorch | Mask   | 1x      | 5.3      | 0.50                | 10.6           | 36.8   | 34.1    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/fast_mask_rcnn_r50_fpn_1x_20181010-e030a38f.pth) \| [result](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/results/fast_mask_rcnn_r50_fpn_1x_20181010_results.pkl.json) |
 | R-50-FPN | pytorch | Mask   | 2x      | 5.3      | 0.50                | 10.6           | 37.9   | 34.8    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/fast_mask_rcnn_r50_fpn_2x_20181010-5048cb03.pth) \| [result](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/results/fast_mask_rcnn_r50_fpn_2x_20181010_results.pkl.json) |
 
-### RetinaNet (coming soon)
+### RetinaNet
 
 | Backbone | Style   | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
 |:--------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
-| R-50-FPN | caffe   | 1x      |          |                     |                |        |          |
-| R-50-FPN | pytorch | 1x      |          |                     |                |        |          |
-| R-50-FPN | pytorch | 2x      |          |                     |                |        |          |
+| R-50-FPN | caffe   | 1x      | 5.8      | 0.459               | 9.4            |        | -        |
+| R-50-FPN | pytorch | 1x      | 6.2      | 0.487               | 9.1            |        |          |
+| R-50-FPN | pytorch | 2x      | 6.2      | 0.487               | 9.1            |        |          |
 
 
 ## Comparison with Detectron
diff --git a/configs/retinanet_r50_fpn_1x.py b/configs/retinanet_r50_fpn_1x.py
index 64423bb3..079d6e53 100644
--- a/configs/retinanet_r50_fpn_1x.py
+++ b/configs/retinanet_r50_fpn_1x.py
@@ -22,16 +22,16 @@ model = dict(
         in_channels=256,
         stacked_convs=4,
         feat_channels=256,
+        octave_base_scale=4,
         scales_per_octave=3,
-        anchor_scale=4,
-        anchor_ratios=[1.0, 2.0, 0.5],
+        anchor_ratios=[0.5, 1.0, 2.0],
         anchor_strides=[8, 16, 32, 64, 128],
         target_means=[.0, .0, .0, .0],
         target_stds=[1.0, 1.0, 1.0, 1.0]))
 # training and testing settings
 train_cfg = dict(
     assigner=dict(
-        pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0.4, ignore_iof_thr=-1),
+        pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1),
     smoothl1_beta=0.11,
     gamma=2.0,
     alpha=0.25,
@@ -99,7 +99,7 @@ lr_config = dict(
 checkpoint_config = dict(interval=1)
 # yapf:disable
 log_config = dict(
-    interval=20,
+    interval=50,
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook')
diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py
index 14b49f5c..af1ceae7 100644
--- a/mmdet/core/loss/losses.py
+++ b/mmdet/core/loss/losses.py
@@ -31,7 +31,7 @@ def sigmoid_focal_loss(pred,
                        gamma=2.0,
                        alpha=0.25,
                        reduction='elementwise_mean'):
-    pred_sigmoid = pred.sigmoid()
+    pred_sigmoid = pred.sigmoid().detach()
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
diff --git a/mmdet/models/single_stage_heads/retina_head.py b/mmdet/models/single_stage_heads/retina_head.py
index 44ddaa6b..cc049239 100644
--- a/mmdet/models/single_stage_heads/retina_head.py
+++ b/mmdet/models/single_stage_heads/retina_head.py
@@ -23,6 +23,14 @@ class RetinaHead(nn.Module):
         stacked_convs (int): Number of convolutional layers added for cls and
             reg branch.
         feat_channels (int): Number of channels for the RPN feature map.
+        scales_per_octave (int): Number of anchor scales per octave.
+        octave_base_scale (int): Base octave scale. Anchor scales are computed
+            as `s*2^(i/n)`, for i in [0, n-1], where s is `octave_base_scale`
+            and n is `scales_per_octave`.
+        anchor_ratios (Iterable): Anchor aspect ratios.
+        anchor_strides (Iterable): Anchor strides.
+        target_means (Iterable): Mean values of regression targets.
+        target_stds (Iterable): Std values of regression targets.
     """
 
     def __init__(self,
@@ -30,30 +38,32 @@ class RetinaHead(nn.Module):
                  num_classes,
                  stacked_convs=4,
                  feat_channels=256,
+                 octave_base_scale=4,
                  scales_per_octave=3,
-                 anchor_scale=4,
-                 anchor_ratios=[1.0, 2.0, 0.5],
+                 anchor_ratios=[0.5, 1.0, 2.0],
                  anchor_strides=[8, 16, 32, 64, 128],
-                 target_means=[.0, .0, .0, .0],
-                 target_stds=[1.0, 1.0, 1.0, 1.0]):
+                 anchor_base_sizes=None,
+                 target_means=(.0, .0, .0, .0),
+                 target_stds=(1.0, 1.0, 1.0, 1.0)):
         super(RetinaHead, self).__init__()
         self.in_channels = in_channels
         self.num_classes = num_classes
+        self.octave_base_scale = octave_base_scale
         self.scales_per_octave = scales_per_octave
-        self.anchor_scale = anchor_scale
-        self.anchor_strides = anchor_strides
         self.anchor_ratios = anchor_ratios
+        self.anchor_strides = anchor_strides
+        self.anchor_base_sizes = list(
+            anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
         self.target_means = target_means
         self.target_stds = target_stds
+
         self.anchor_generators = []
-        for anchor_stride in self.anchor_strides:
-            octave_scales = np.array([
-                2**(octave / float(scales_per_octave))
-                for octave in range(scales_per_octave)
-            ])
-            octave_scales = octave_scales * anchor_scale
+        for anchor_base in self.anchor_base_sizes:
+            octave_scales = np.array(
+                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
+            anchor_scales = octave_scales * octave_base_scale
             self.anchor_generators.append(
-                AnchorGenerator(anchor_stride, octave_scales, anchor_ratios))
+                AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
         self.relu = nn.ReLU(inplace=True)
         self.num_anchors = int(
             len(self.anchor_ratios) * self.scales_per_octave)
-- 
GitLab