From efd02f86182c5f16531c231ed7e05e0952c78fb3 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Fri, 23 Nov 2018 20:49:53 +0800 Subject: [PATCH] refactoring for retinanet --- MODEL_ZOO.md | 8 ++--- configs/retinanet_r50_fpn_1x.py | 8 ++--- mmdet/core/loss/losses.py | 2 +- .../models/single_stage_heads/retina_head.py | 36 ++++++++++++------- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md index ed526d5e..e581f17f 100644 --- a/MODEL_ZOO.md +++ b/MODEL_ZOO.md @@ -71,13 +71,13 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m | R-50-FPN | pytorch | Mask | 1x | 5.3 | 0.50 | 10.6 | 36.8 | 34.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/fast_mask_rcnn_r50_fpn_1x_20181010-e030a38f.pth) \| [result](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/results/fast_mask_rcnn_r50_fpn_1x_20181010_results.pkl.json) | | R-50-FPN | pytorch | Mask | 2x | 5.3 | 0.50 | 10.6 | 37.9 | 34.8 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/fast_mask_rcnn_r50_fpn_2x_20181010-5048cb03.pth) \| [result](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/results/fast_mask_rcnn_r50_fpn_2x_20181010_results.pkl.json) | -### RetinaNet (coming soon) +### RetinaNet | Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download | |:--------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:| -| R-50-FPN | caffe | 1x | | | | | | -| R-50-FPN | pytorch | 1x | | | | | | -| R-50-FPN | pytorch | 2x | | | | | | +| R-50-FPN | caffe | 1x | 5.8 | 0.459 | 9.4 | | - | +| R-50-FPN | pytorch | 1x | 6.2 | 0.487 | 9.1 | | | +| R-50-FPN | pytorch | 2x | 6.2 | 0.487 | 9.1 | | | ## Comparison with Detectron diff --git a/configs/retinanet_r50_fpn_1x.py b/configs/retinanet_r50_fpn_1x.py index 64423bb3..079d6e53 100644 --- a/configs/retinanet_r50_fpn_1x.py +++ b/configs/retinanet_r50_fpn_1x.py @@ -22,16 +22,16 @@ model = dict( in_channels=256, stacked_convs=4, feat_channels=256, + octave_base_scale=4, scales_per_octave=3, - anchor_scale=4, - anchor_ratios=[1.0, 2.0, 0.5], + anchor_ratios=[0.5, 1.0, 2.0], anchor_strides=[8, 16, 32, 64, 128], target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0])) # training and testing settings train_cfg = dict( assigner=dict( - pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0.4, ignore_iof_thr=-1), + pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1), smoothl1_beta=0.11, gamma=2.0, alpha=0.25, @@ -99,7 +99,7 @@ lr_config = dict( checkpoint_config = dict(interval=1) # yapf:disable log_config = dict( - interval=20, + interval=50, hooks=[ dict(type='TextLoggerHook'), # dict(type='TensorboardLoggerHook') diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py index 14b49f5c..af1ceae7 100644 --- a/mmdet/core/loss/losses.py +++ b/mmdet/core/loss/losses.py @@ -31,7 +31,7 @@ def sigmoid_focal_loss(pred, gamma=2.0, alpha=0.25, reduction='elementwise_mean'): - pred_sigmoid = pred.sigmoid() + pred_sigmoid = pred.sigmoid().detach() pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) weight = (alpha * target + (1 - alpha) * (1 - target)) * weight weight = weight * pt.pow(gamma) diff --git a/mmdet/models/single_stage_heads/retina_head.py b/mmdet/models/single_stage_heads/retina_head.py index 44ddaa6b..cc049239 100644 --- a/mmdet/models/single_stage_heads/retina_head.py +++ b/mmdet/models/single_stage_heads/retina_head.py @@ -23,6 +23,14 @@ class RetinaHead(nn.Module): stacked_convs (int): Number of convolutional layers added for cls and reg branch. feat_channels (int): Number of channels for the RPN feature map. + scales_per_octave (int): Number of anchor scales per octave. + octave_base_scale (int): Base octave scale. Anchor scales are computed + as `s*2^(i/n)`, for i in [0, n-1], where s is `octave_base_scale` + and n is `scales_per_octave`. + anchor_ratios (Iterable): Anchor aspect ratios. + anchor_strides (Iterable): Anchor strides. + target_means (Iterable): Mean values of regression targets. + target_stds (Iterable): Std values of regression targets. """ def __init__(self, @@ -30,30 +38,32 @@ class RetinaHead(nn.Module): num_classes, stacked_convs=4, feat_channels=256, + octave_base_scale=4, scales_per_octave=3, - anchor_scale=4, - anchor_ratios=[1.0, 2.0, 0.5], + anchor_ratios=[0.5, 1.0, 2.0], anchor_strides=[8, 16, 32, 64, 128], - target_means=[.0, .0, .0, .0], - target_stds=[1.0, 1.0, 1.0, 1.0]): + anchor_base_sizes=None, + target_means=(.0, .0, .0, .0), + target_stds=(1.0, 1.0, 1.0, 1.0)): super(RetinaHead, self).__init__() self.in_channels = in_channels self.num_classes = num_classes + self.octave_base_scale = octave_base_scale self.scales_per_octave = scales_per_octave - self.anchor_scale = anchor_scale - self.anchor_strides = anchor_strides self.anchor_ratios = anchor_ratios + self.anchor_strides = anchor_strides + self.anchor_base_sizes = list( + anchor_strides) if anchor_base_sizes is None else anchor_base_sizes self.target_means = target_means self.target_stds = target_stds + self.anchor_generators = [] - for anchor_stride in self.anchor_strides: - octave_scales = np.array([ - 2**(octave / float(scales_per_octave)) - for octave in range(scales_per_octave) - ]) - octave_scales = octave_scales * anchor_scale + for anchor_base in self.anchor_base_sizes: + octave_scales = np.array( + [2**(i / scales_per_octave) for i in range(scales_per_octave)]) + anchor_scales = octave_scales * octave_base_scale self.anchor_generators.append( - AnchorGenerator(anchor_stride, octave_scales, anchor_ratios)) + AnchorGenerator(anchor_base, anchor_scales, anchor_ratios)) self.relu = nn.ReLU(inplace=True) self.num_anchors = int( len(self.anchor_ratios) * self.scales_per_octave) -- GitLab