Unverified commit 754ec972, authored by Chrisfsj2051, committed by GitHub

Implement NASFCOS (#2682)

* Implement NAS-FCOS

1. Supports NASFCOS
2. Added NASFCOS_FPN
3. Added NASFCOS_Head
4. Added r50 config

* Revert "Implement NAS-FCOS"

This reverts commit ebae11cbbdc5b4f3cc7226c763e9254c94a9c92b.

* Revert "Revert "Implement NAS-FCOS""

This reverts commit 32bca597859a447de62ddf0f2c2f482809730da9.

* Implement NASFCOS

1. Supports NASFCOS
2. Added NASFCOS_FPN
3. Added NASFCOS_Head
4. Added r50 config

* reformatted & use list to build convs in head

1. code reformatted
2. add docstring in nasfcos_head.py
3. use a list of conv config to build convs, rather than names

* cells merged

1. merge cells in nasfpn and nasfcos_fpn
2. add nasfpn_cell_factory.py

* reformatted

* fixed

* fixed

* Update nasfpn_cell_factory.py

* pre-commit

* bug fixed

* renamed & moved

* not done

* fixed

* fixed

* Delete remove_parawise_cfg_nasfcos_r50_nashead_4x4_coco.py

* fixed

* Create 1x_bias_lr.py

* Revert "Create 1x_bias_lr.py"

This reverts commit 54602ded283a8f646aebfa0f216350fd2ce82e6d.

* fixed

* fixed

* fixed docstring

* init modified

* formatted

* fixed

* updated

* fixed

* fixed

* fixed

* fixed

* fixed

* fixed

* pass test

* fixed

* fixed

* fixed

* rename

* 123

* nasfcos_readme

* Update README.md

* rename & leave empty

* GN->GN-head in README

* fixed
parent 50ffa248
Showing with 720 additions and 78 deletions
@@ -74,6 +74,7 @@ Results and models are available in the [model zoo](docs/model_zoo.md).
| ATSS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| FSAF | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| PAFPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| NAS-FCOS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| PISA | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
Other features
......
# NAS-FCOS: Fast Neural Architecture Search for Object Detection
## Introduction
```
@article{wang2019fcos,
title={Nas-fcos: Fast neural architecture search for object detection},
author={Wang, Ning and Gao, Yang and Chen, Hao and Wang, Peng and Tian, Zhi and Shen, Chunhua},
journal={arXiv preprint arXiv:1906.04423},
year={2019}
}
```
## Results and Models
| Head | Backbone | Style | GN-head | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download |
|:---------:|:---------:|:-------:|:-------:|:-------:|:--------:|:--------------:|:------:|:--------:|
| NAS-FCOSHead | R-50 | caffe | Y | 1x | | | 39.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fcos/nas_fcos_nashead_r50_caffe_fpn_gn-head_4x4_1x_coco/nas_fcos_nashead_r50_caffe_fpn_gn-head_4x4_1x_coco_20200520-1bdba3ce.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fcos/nas_fcos_nashead_r50_caffe_fpn_gn-head_4x4_1x_coco/nas_fcos_nashead_r50_caffe_fpn_gn-head_4x4_1x_coco_20200520.log.json) |
| FCOSHead | R-50 | caffe | Y | 1x | | | 38.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fcos/nas_fcos_fcoshead_r50_caffe_fpn_gn-head_4x4_1x_coco/nas_fcos_fcoshead_r50_caffe_fpn_gn-head_4x4_1x_coco_20200521-7fdcbce0.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/nas_fcos/nas_fcos_fcoshead_r50_caffe_fpn_gn-head_4x4_1x_coco/nas_fcos_fcoshead_r50_caffe_fpn_gn-head_4x4_1x_coco_20200521.log.json) |
**Notes:**
- To be consistent with the author's implementation, we use 4 GPUs with 4 images per GPU (total batch size 16).
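
For reference, a minimal sketch of building the NAS-FCOS detector from the config below using mmdetection's standard config and build APIs; the config path is an assumption and may need adjusting to your checkout:

```python
from mmcv import Config

from mmdet.models import build_detector

# Assumed config path inside an mmdetection checkout (adjust as needed).
cfg = Config.fromfile(
    'configs/nas_fcos/nas_fcos_nashead_r50_caffe_fpn_gn-head_4x4_1x_coco.py')
# In configs of this era, train_cfg/test_cfg live at the top level of the file.
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
```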
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='NASFCOS',
pretrained='open-mmlab://resnet50_caffe_bgr',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False, eps=0),
style='caffe'),
neck=dict(
type='NASFCOS_FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
num_outs=5,
norm_cfg=dict(type='BN'),
conv_cfg=dict(type='DCNv2', deformable_groups=2)),
bbox_head=dict(
type='FCOSHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
norm_cfg=dict(type='GN', num_groups=32),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
train_cfg = dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.6),
max_per_img=100)
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=2,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(
lr=0.01, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.))
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='NASFCOS',
pretrained='open-mmlab://resnet50_caffe_bgr',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False, eps=0),
style='caffe'),
neck=dict(
type='NASFCOS_FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
num_outs=5,
norm_cfg=dict(type='BN'),
conv_cfg=dict(type='DCNv2', deformable_groups=2)),
bbox_head=dict(
type='NASFCOSHead',
num_classes=80,
in_channels=256,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
norm_cfg=dict(type='GN', num_groups=32),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
train_cfg = dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.6),
max_per_img=100)
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=2,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
optimizer = dict(
lr=0.01, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.))
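
As a usage note, a minimal inference sketch with mmdetection's high-level API; the file paths and demo image below are placeholders, not part of this commit:

```python
from mmdet.apis import inference_detector, init_detector

# Placeholder paths: the config from this commit and a checkpoint
# downloaded from the README table above.
config_file = 'configs/nas_fcos/nas_fcos_nashead_r50_caffe_fpn_gn-head_4x4_1x_coco.py'
checkpoint_file = 'nas_fcos_nashead_r50_caffe_fpn_gn-head_4x4_1x_coco_20200520-1bdba3ce.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
# Returns a list of per-class arrays of detected bboxes for the image.
result = inference_detector(model, 'demo/demo.jpg')
```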
@@ -7,6 +7,7 @@ from .fsaf_head import FSAFHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .nasfcos_head import NASFCOSHead
from .pisa_retinanet_head import PISARetinaHead
from .pisa_ssd_head import PISASSDHead
from .reppoints_head import RepPointsHead
@@ -19,5 +20,5 @@ __all__ = [
'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
'ATSSHead', 'FSAFHead', 'PISARetinaHead', 'PISASSDHead'
'ATSSHead', 'FSAFHead', 'NASFCOSHead', 'PISARetinaHead', 'PISASSDHead'
]
import copy
import torch.nn as nn
from mmcv.cnn import (ConvModule, Scale, bias_init_with_prob,
caffe2_xavier_init, normal_init)
from mmdet.models.dense_heads.fcos_head import FCOSHead
from ..builder import HEADS
@HEADS.register_module()
class NASFCOSHead(FCOSHead):
"""Anchor-free head used in `NASFCOS <https://arxiv.org/abs/1906.04423>`_.
It is quite similar to the FCOS head, except that the classification
branch and the bbox regression branch use a searched structure of
"dconv3x3, conv3x3, dconv3x3, conv1x1" instead.
"""
def _init_layers(self):
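# Searched branch: dconv3x3 -> conv3x3 -> dconv3x3 -> conv1x1 (see the
# class docstring); 'DCNv2' builds a modulated deformable convolution.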
dconv3x3_config = dict(
type='DCNv2',
kernel_size=3,
use_bias=True,
deformable_groups=2,
padding=1)
conv3x3_config = dict(type='Conv', kernel_size=3, padding=1)
conv1x1_config = dict(type='Conv', kernel_size=1)
self.arch_config = [
dconv3x3_config, conv3x3_config, dconv3x3_config, conv1x1_config
]
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i, op_ in enumerate(self.arch_config):
op = copy.deepcopy(op_)
chn = self.in_channels if i == 0 else self.feat_channels
assert isinstance(op, dict)
use_bias = op.pop('use_bias', False)
padding = op.pop('padding', 0)
kernel_size = op.pop('kernel_size')
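# The keys popped above are ConvModule arguments; the remaining keys
# (e.g. type='DCNv2', deformable_groups=2) are forwarded as conv_cfg.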
module = ConvModule(
chn,
self.feat_channels,
kernel_size,
stride=1,
padding=padding,
norm_cfg=self.norm_cfg,
bias=use_bias,
conv_cfg=op)
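# Deep copies give the cls and reg branches independent weights.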
self.cls_convs.append(copy.deepcopy(module))
self.reg_convs.append(copy.deepcopy(module))
self.fcos_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
self.fcos_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
self.fcos_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
def init_weights(self):
# retinanet_bias_init
bias_cls = bias_init_with_prob(0.01)
normal_init(self.fcos_reg, std=0.01)
normal_init(self.fcos_centerness, std=0.01)
normal_init(self.fcos_cls, std=0.01, bias=bias_cls)
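# Caffe2-style Xavier init for the plain convs inside the searched branches
# (DCNv2 offset convs are zero-initialized in ModulatedDeformConvPack).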
for branch in [self.cls_convs, self.reg_convs]:
for module in branch.modules():
if isinstance(module, ConvModule) \
and isinstance(module.conv, nn.Conv2d):
caffe2_xavier_init(module.conv)
@@ -10,6 +10,7 @@ from .grid_rcnn import GridRCNN
from .htc import HybridTaskCascade
from .mask_rcnn import MaskRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .nasfcos import NASFCOS
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
@@ -20,5 +21,5 @@ __all__ = [
'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
'FOVEA', 'FSAF'
'FOVEA', 'FSAF', 'NASFCOS'
]
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class NASFCOS(SingleStageDetector):
"""NAS-FCOS: Fast Neural Architecture Search for Object Detection.
https://arxiv.org/abs/1906.04423
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(NASFCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
@@ -3,6 +3,9 @@ from .fpn import FPN
from .fpn_carafe import FPN_CARAFE
from .hrfpn import HRFPN
from .nas_fpn import NASFPN
from .nasfcos_fpn import NASFCOS_FPN
from .pafpn import PAFPN
__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN']
__all__ = [
'FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', 'NASFCOS_FPN'
]
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, caffe2_xavier_init
from mmdet.ops.merge_cells import GlobalPoolingCell, SumCell
from ..builder import NECKS
class MergingCell(nn.Module):
def __init__(self, channels=256, with_conv=True, norm_cfg=None):
super(MergingCell, self).__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv_out = ConvModule(
channels,
channels,
3,
padding=1,
norm_cfg=norm_cfg,
order=('act', 'conv', 'norm'))
def _binary_op(self, x1, x2):
raise NotImplementedError
def _resize(self, x, size):
if x.shape[-2:] == size:
return x
elif x.shape[-2:] < size:
return F.interpolate(x, size=size, mode='nearest')
else:
assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
kernel_size = x.shape[-1] // size[-1]
x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
return x
def forward(self, x1, x2, out_size):
assert x1.shape[:2] == x2.shape[:2]
assert len(out_size) == 2
x1 = self._resize(x1, out_size)
x2 = self._resize(x2, out_size)
x = self._binary_op(x1, x2)
if self.with_conv:
x = self.conv_out(x)
return x
class SumCell(MergingCell):
def _binary_op(self, x1, x2):
return x1 + x2
class GPCell(MergingCell):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
def _binary_op(self, x1, x2):
x2_att = self.global_pool(x2).sigmoid()
return x2 + x2_att * x1
@NECKS.register_module()
class NASFPN(nn.Module):
"""NAS-FPN.
@@ -126,21 +68,42 @@
for _ in range(self.stack_times):
stage = nn.ModuleDict()
# gp(p6, p4) -> p4_1
stage['gp_64_4'] = GPCell(out_channels, norm_cfg=norm_cfg)
stage['gp_64_4'] = GlobalPoolingCell(
in_channels=out_channels,
out_channels=out_channels,
out_norm_cfg=norm_cfg)
# sum(p4_1, p4) -> p4_2
stage['sum_44_4'] = SumCell(out_channels, norm_cfg=norm_cfg)
stage['sum_44_4'] = SumCell(
in_channels=out_channels,
out_channels=out_channels,
out_norm_cfg=norm_cfg)
# sum(p4_2, p3) -> p3_out
stage['sum_43_3'] = SumCell(out_channels, norm_cfg=norm_cfg)
stage['sum_43_3'] = SumCell(
in_channels=out_channels,
out_channels=out_channels,
out_norm_cfg=norm_cfg)
# sum(p3_out, p4_2) -> p4_out
stage['sum_34_4'] = SumCell(out_channels, norm_cfg=norm_cfg)
stage['sum_34_4'] = SumCell(
in_channels=out_channels,
out_channels=out_channels,
out_norm_cfg=norm_cfg)
# sum(p5, gp(p4_out, p3_out)) -> p5_out
stage['gp_43_5'] = GPCell(with_conv=False)
stage['sum_55_5'] = SumCell(out_channels, norm_cfg=norm_cfg)
stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
stage['sum_55_5'] = SumCell(
in_channels=out_channels,
out_channels=out_channels,
out_norm_cfg=norm_cfg)
# sum(p7, gp(p5_out, p4_2)) -> p7_out
stage['gp_54_7'] = GPCell(with_conv=False)
stage['sum_77_7'] = SumCell(out_channels, norm_cfg=norm_cfg)
stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
stage['sum_77_7'] = SumCell(
in_channels=out_channels,
out_channels=out_channels,
out_norm_cfg=norm_cfg)
# gp(p7_out, p5_out) -> p6_out
stage['gp_75_6'] = GPCell(out_channels, norm_cfg=norm_cfg)
stage['gp_75_6'] = GlobalPoolingCell(
in_channels=out_channels,
out_channels=out_channels,
out_norm_cfg=norm_cfg)
self.fpn_stages.append(stage)
def init_weights(self):
......
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, caffe2_xavier_init
from mmdet.ops.merge_cells import ConcatCell
from ..builder import NECKS
@NECKS.register_module()
class NASFCOS_FPN(nn.Module):
"""FPN structure in NASFPN
NAS-FCOS: Fast Neural Architecture Search for Object Detection
<https://arxiv.org/abs/1906.04423>
"""
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=1,
end_level=-1,
add_extra_convs=False,
conv_cfg=None,
norm_cfg=None):
super(NASFCOS_FPN, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.norm_cfg = norm_cfg
self.conv_cfg = conv_cfg
if end_level == -1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
self.adapt_convs = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
adapt_conv = ConvModule(
in_channels[i],
out_channels,
1,
stride=1,
padding=0,
bias=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU', inplace=False))
self.adapt_convs.append(adapt_conv)
# C2 is omitted according to the paper
extra_levels = num_outs - self.backbone_end_level + self.start_level
def build_concat_cell(with_input1_conv, with_input2_conv):
cell_conv_cfg = dict(
kernel_size=1, padding=0, bias=False, groups=out_channels)
return ConcatCell(
in_channels=out_channels,
out_channels=out_channels,
with_out_conv=True,
out_conv_cfg=cell_conv_cfg,
out_norm_cfg=dict(type='BN'),
out_conv_order=('norm', 'act', 'conv'),
with_input1_conv=with_input1_conv,
with_input2_conv=with_input2_conv,
input_conv_cfg=conv_cfg,
input_norm_cfg=norm_cfg,
upsample_mode='nearest')
# Denote c3=f0, c4=f1, c5=f2 for convenience
self.fpn = nn.ModuleDict()
self.fpn['c22_1'] = build_concat_cell(True, True)
self.fpn['c22_2'] = build_concat_cell(True, True)
self.fpn['c32'] = build_concat_cell(True, False)
self.fpn['c02'] = build_concat_cell(True, False)
self.fpn['c42'] = build_concat_cell(True, True)
self.fpn['c36'] = build_concat_cell(True, True)
self.fpn['c61'] = build_concat_cell(True, True) # f9
self.extra_downsamples = nn.ModuleList()
for i in range(extra_levels):
extra_act_cfg = None if i == 0 \
else dict(type='ReLU', inplace=False)
self.extra_downsamples.append(
ConvModule(
out_channels,
out_channels,
3,
stride=2,
padding=1,
act_cfg=extra_act_cfg,
order=('act', 'norm', 'conv')))
def forward(self, inputs):
feats = [
adapt_conv(inputs[i + self.start_level])
for i, adapt_conv in enumerate(self.adapt_convs)
]
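# Each key in self.fpn encodes its two input indices (module_name[1] and
# module_name[2]), e.g. 'c02' merges feats[0] with feats[2]; every cell
# output is appended to feats so later cells can consume earlier outputs.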
for (i, module_name) in enumerate(self.fpn):
idx_1, idx_2 = int(module_name[1]), int(module_name[2])
res = self.fpn[module_name](feats[idx_1], feats[idx_2])
feats.append(res)
ret = []
for (idx, input_idx) in zip([9, 8, 7], [1, 2, 3]): # add P3, P4, P5
feats1, feats2 = feats[idx], feats[5]
feats2_resize = F.interpolate(
feats2,
size=feats1.size()[2:],
mode='bilinear',
align_corners=False)
feats_sum = feats1 + feats2_resize
ret.append(
F.interpolate(
feats_sum,
size=inputs[input_idx].size()[2:],
mode='bilinear',
align_corners=False))
for submodule in self.extra_downsamples:
ret.append(submodule(ret[-1]))
return tuple(ret)
def init_weights(self):
for module in self.fpn.values():
if hasattr(module, 'out_conv'):
caffe2_xavier_init(module.out_conv.conv)
for modules in [
self.adapt_convs.modules(),
self.extra_downsamples.modules()
]:
for module in modules:
if isinstance(module, nn.Conv2d):
caffe2_xavier_init(module)
@@ -356,9 +356,9 @@ class ModulatedDeformConv(nn.Module):
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self.init_weights()
def reset_parameters(self):
def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
@@ -404,11 +404,13 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()
self.init_weights()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def init_weights(self):
super(ModulatedDeformConvPack, self).init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
out = self.conv_offset(x)
......
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
class BaseMergeCell(nn.Module):
"""The basic class for cells used in NAS-FPN and NAS-FCOS.
BaseMergeCell takes two inputs. After applying convolution to them,
they are resized to the target size, then they go through a binary
operation which depends on the type of cell. If with_out_conv is
True, the result goes through another convolution layer.
Args:
fused_channels (int): Number of input channels of the out_conv layer.
out_channels (int): Number of output channels of the out_conv layer.
with_out_conv (bool): Whether to use the out_conv layer.
out_conv_cfg (dict): Config dict for convolution layer, which should
contain "groups", "kernel_size", "padding", "bias" to build
out_conv layer.
out_norm_cfg (dict): Config dict for normalization layer in out_conv.
out_conv_order (tuple): The order of conv/norm/activation layers in
out_conv.
with_input1_conv (bool): Whether to use convolution on input1.
with_input2_conv (bool): Whether to use convolution on input2.
input_conv_cfg (dict): Config dict for building input1_conv layer and
input2_conv layer, which is expected to contain the type of
convolution.
Default: None, which means using conv2d.
input_norm_cfg (dict): Config dict for normalization layer in
input1_conv and input2_conv layer. Default: None.
upsample_mode (str): Interpolation method used to resize the output
of input1_conv and input2_conv to target size. Currently, we
support ['nearest', 'bilinear']. Default: 'nearest'.
"""
def __init__(self,
fused_channels=256,
out_channels=256,
with_out_conv=True,
out_conv_cfg=dict(
groups=1, kernel_size=3, padding=1, bias=True),
out_norm_cfg=None,
out_conv_order=('act', 'conv', 'norm'),
with_input1_conv=False,
with_input2_conv=False,
input_conv_cfg=None,
input_norm_cfg=None,
upsample_mode='nearest'):
super(BaseMergeCell, self).__init__()
assert upsample_mode in ['nearest', 'bilinear']
self.with_out_conv = with_out_conv
self.with_input1_conv = with_input1_conv
self.with_input2_conv = with_input2_conv
self.upsample_mode = upsample_mode
if self.with_out_conv:
self.out_conv = ConvModule(
fused_channels,
out_channels,
**out_conv_cfg,
norm_cfg=out_norm_cfg,
order=out_conv_order)
self.input1_conv = self._build_input_conv(
out_channels, input_conv_cfg,
input_norm_cfg) if with_input1_conv else nn.Sequential()
self.input2_conv = self._build_input_conv(
out_channels, input_conv_cfg,
input_norm_cfg) if with_input2_conv else nn.Sequential()
def _build_input_conv(self, channel, conv_cfg, norm_cfg):
return ConvModule(
channel,
channel,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True)
@abstractmethod
def _binary_op(self, x1, x2):
pass
def _resize(self, x, size):
if x.shape[-2:] == size:
return x
elif x.shape[-2:] < size:
return F.interpolate(x, size=size, mode=self.upsample_mode)
else:
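# Downsampling path: the input size must be an integer multiple of the
# target size so max-pooling with matching kernel/stride lands on it exactly.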
assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
kernel_size = x.shape[-1] // size[-1]
x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
return x
def forward(self, x1, x2, out_size=None):
assert x1.shape[:2] == x2.shape[:2]
assert out_size is None or len(out_size) == 2
if out_size is None: # resize to larger one
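# torch.Size compares like a tuple, so max() picks the larger spatial
# size; in practice one input's size is a multiple of the other's
# (see the assert in _resize).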
out_size = max(x1.size()[2:], x2.size()[2:])
x1 = self.input1_conv(x1)
x2 = self.input2_conv(x2)
x1 = self._resize(x1, out_size)
x2 = self._resize(x2, out_size)
x = self._binary_op(x1, x2)
if self.with_out_conv:
x = self.out_conv(x)
return x
class SumCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
def _binary_op(self, x1, x2):
return x1 + x2
class ConcatCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
super(ConcatCell, self).__init__(in_channels * 2, out_channels,
**kwargs)
def _binary_op(self, x1, x2):
ret = torch.cat([x1, x2], dim=1)
return ret
class GlobalPoolingCell(BaseMergeCell):
def __init__(self, in_channels=None, out_channels=None, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
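# The globally pooled, sigmoid-gated response of x2 acts as per-channel
# attention over x1 before the residual sum with x2.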
def _binary_op(self, x1, x2):
x2_att = self.global_pool(x2).sigmoid()
return x2 + x2_att * x1
matplotlib
mmcv>=0.5.1
mmcv>=0.5.3
numpy
# need older pillow until torchvision is fixed
Pillow<=6.2.2
......
"""
CommandLine:
pytest tests/test_merge_cells.py
"""
import torch
import torch.nn.functional as F
from mmdet.ops.merge_cells import (BaseMergeCell, ConcatCell,
GlobalPoolingCell, SumCell)
def test_sum_cell():
inputs_x = torch.randn([2, 256, 32, 32])
inputs_y = torch.randn([2, 256, 16, 16])
sum_cell = SumCell(256, 256)
output = sum_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert output.size() == inputs_x.size()
output = sum_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
assert output.size() == inputs_y.size()
output = sum_cell(inputs_x, inputs_y)
assert output.size() == inputs_x.size()
def test_concat_cell():
inputs_x = torch.randn([2, 256, 32, 32])
inputs_y = torch.randn([2, 256, 16, 16])
concat_cell = ConcatCell(256, 256)
output = concat_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert output.size() == inputs_x.size()
output = concat_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
assert output.size() == inputs_y.size()
output = concat_cell(inputs_x, inputs_y)
assert output.size() == inputs_x.size()
def test_global_pool_cell():
inputs_x = torch.randn([2, 256, 32, 32])
inputs_y = torch.randn([2, 256, 32, 32])
gp_cell = GlobalPoolingCell(with_out_conv=False)
gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert (gp_cell_out.size() == inputs_x.size())
gp_cell = GlobalPoolingCell(256, 256)
gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
assert (gp_cell_out.size() == inputs_x.size())
def test_resize_methods():
inputs_x = torch.randn([2, 256, 128, 128])
target_resize_sizes = [(128, 128), (256, 256)]
resize_methods_list = ['nearest', 'bilinear']
for method in resize_methods_list:
merge_cell = BaseMergeCell(upsample_mode=method)
for target_size in target_resize_sizes:
merge_cell_out = merge_cell._resize(inputs_x, target_size)
gt_out = F.interpolate(inputs_x, size=target_size, mode=method)
assert merge_cell_out.equal(gt_out)
target_size = (64, 64) # resize to a smaller size
merge_cell = BaseMergeCell()
merge_cell_out = merge_cell._resize(inputs_x, target_size)
kernel_size = inputs_x.shape[-1] // target_size[-1]
gt_out = F.max_pool2d(
inputs_x, kernel_size=kernel_size, stride=kernel_size)
assert (merge_cell_out == gt_out).all()