Skip to content
Snippets Groups Projects
Unverified Commit 4e921b2f authored by Chrisfsj2051's avatar Chrisfsj2051 Committed by GitHub
Browse files

Support TridentNet (#3313)


* go

* go

* go

* go

* remove unused

* Update trident_faster_rcnn.py

* fix nms for latest mmcv

* Delete tridentnet_r101_caffe_1x_coco.py

* fix ci

* Update tridentnet_r50_caffe_1x_coco.py

* add unit test

* Update test_backbones.py

* Update trident_roi_head.py

* add mstrain config

* update

* Update trident_roi_head.py

* Update trident_roi_head.py

* update accoding to comment

* Update trident_resnet.py

* update

* Update README.md

* Update README.md

* update

* Update trident_resnet.py

* Update tridentnet_r50_caffe_mstrain_1x_coco.py

* Update mmdet/models/backbones/trident_resnet.py

* Update resnet.py

* reformat to pass CI

Co-authored-by: default avatarJerry Jiarui XU <xvjiarui0826@gmail.com>
parent c8a620db
No related branches found
No related tags found
No related merge requests found
# Scale-Aware Trident Networks for Object Detection
## Introduction
```
@InProceedings{li2019scale,
title={Scale-Aware Trident Networks for Object Detection},
author={Li, Yanghao and Chen, Yuntao and Wang, Naiyan and Zhang, Zhaoxiang},
journal={The International Conference on Computer Vision (ICCV)},
year={2019}
}
```
## Results and models
We reports the test results using only one branch for inference.
| Backbone | Style | mstrain | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download |
| :-------------: | :-----: | :-----: | :-----: | :------: | :------------: | :----: | :------: |
| R-50 | caffe | N | 1x | | | 37.7 | |
| R-50 | caffe | Y | 1x | | | 37.6 | |
| R-50 | caffe | Y | 3x | | | 40.3 | |
_base_ = [
'../_base_/models/faster_rcnn_r50_caffe_c4.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='TridentFasterRCNN',
pretrained='open-mmlab://detectron2/resnet50_caffe',
backbone=dict(
type='TridentResNet',
trident_dilations=(1, 2, 3),
num_branch=3,
test_branch_idx=1),
roi_head=dict(type='TridentRoIHead', num_branch=3, test_branch_idx=1))
train_cfg = dict(
rpn_proposal=dict(nms_post=500, max_num=500),
rcnn=dict(
sampler=dict(num=128, pos_fraction=0.5, add_gt_as_proposals=False)))
# use caffe img_norm
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
_base_ = 'tridentnet_r50_caffe_1x_coco.py'
# use caffe img_norm
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
data = dict(train=dict(pipeline=train_pipeline))
_base_ = 'tridentnet_r50_caffe_mstrain_1x_coco.py'
lr_config = dict(step=[28, 34])
total_epochs = 36
......@@ -9,9 +9,10 @@ from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .trident_resnet import TridentResNet
__all__ = [
'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
'ResNeSt'
'ResNeSt', 'TridentResNet'
]
......@@ -262,7 +262,6 @@ class Bottleneck(nn.Module):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
......
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer, kaiming_init
from torch.nn.modules.utils import _pair
from mmdet.models.backbones.resnet import Bottleneck, ResNet
from mmdet.models.builder import BACKBONES
class TridentConv(nn.Module):
"""Trident Convolution Module.
Args:
in_channels (int): Number of channels in input.
out_channels (int): Number of channels in output.
kernel_size (int): Size of convolution kernel.
stride (int, optional): Convolution stride. Default: 1.
trident_dilations (tuple[int, int, int], optional): Dilations of
different trident branch. Default: (1, 2, 3).
test_branch_idx (int, optional): In inference, all 3 branches will
be used if `test_branch_idx==-1`, otherwise only branch with
index `test_branch_idx` will be used. Default: 1.
bias (bool, optional): Whether to use bias in convolution or not.
Default: False.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
trident_dilations=(1, 2, 3),
test_branch_idx=1,
bias=False):
super(TridentConv, self).__init__()
self.num_branch = len(trident_dilations)
self.with_bias = bias
self.test_branch_idx = test_branch_idx
self.stride = _pair(stride)
self.kernel_size = _pair(kernel_size)
self.paddings = _pair(trident_dilations)
self.dilations = trident_dilations
self.in_channels = in_channels
self.out_channels = out_channels
self.bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.init_weights()
def init_weights(self):
kaiming_init(self, distribution='uniform', mode='fan_in')
def extra_repr(self):
tmpstr = f'in_channels={self.in_channels}'
tmpstr += f', out_channels={self.out_channels}'
tmpstr += f', kernel_size={self.kernel_size}'
tmpstr += f', num_branch={self.num_branch}'
tmpstr += f', test_branch_idx={self.test_branch_idx}'
tmpstr += f', stride={self.stride}'
tmpstr += f', paddings={self.paddings}'
tmpstr += f', dilations={self.dilations}'
tmpstr += f', bias={self.bias}'
return tmpstr
def forward(self, inputs):
if self.training or self.test_branch_idx == -1:
outputs = [
F.conv2d(input, self.weight, self.bias, self.stride, padding,
dilation) for input, dilation, padding in zip(
inputs, self.dilations, self.paddings)
]
else:
assert len(inputs) == 1
outputs = [
F.conv2d(inputs[0], self.weight, self.bias, self.stride,
self.paddings[self.test_branch_idx],
self.dilations[self.test_branch_idx])
]
return outputs
# Since TridentNet is defined over ResNet50 and ResNet101, here we
# only support TridentBottleneckBlock.
class TridentBottleneck(Bottleneck):
"""BottleBlock for TridentResNet.
Args:
trident_dilations (tuple[int, int, int]): Dilations of different
trident branch.
test_branch_idx (int): In inference, all 3 branches will be used
if `test_branch_idx==-1`, otherwise only branch with index
`test_branch_idx` will be used.
concat_output (bool): Whether to concat the output list to a Tensor.
`True` only in the last Block.
"""
def __init__(self, trident_dilations, test_branch_idx, concat_output,
**kwargs):
super(TridentBottleneck, self).__init__(**kwargs)
self.trident_dilations = trident_dilations
self.num_branch = len(trident_dilations)
self.concat_output = concat_output
self.test_branch_idx = test_branch_idx
self.conv2 = TridentConv(
self.planes,
self.planes,
kernel_size=3,
stride=self.conv2_stride,
bias=False,
trident_dilations=self.trident_dilations,
test_branch_idx=test_branch_idx)
def forward(self, x):
def _inner_forward(x):
num_branch = (
self.num_branch
if self.training or self.test_branch_idx == -1 else 1)
identity = x
if not isinstance(x, list):
x = (x, ) * num_branch
identity = x
if self.downsample is not None:
identity = [self.downsample(b) for b in x]
out = [self.conv1(b) for b in x]
out = [self.norm1(b) for b in out]
out = [self.relu(b) for b in out]
if self.with_plugins:
for k in range(len(out)):
out[k] = self.forward_plugin(out[k],
self.after_conv1_plugin_names)
out = self.conv2(out)
out = [self.norm2(b) for b in out]
out = [self.relu(b) for b in out]
if self.with_plugins:
for k in range(len(out)):
out[k] = self.forward_plugin(out[k],
self.after_conv2_plugin_names)
out = [self.conv3(b) for b in out]
out = [self.norm3(b) for b in out]
if self.with_plugins:
for k in range(len(out)):
out[k] = self.forward_plugin(out[k],
self.after_conv3_plugin_names)
out = [
out_b + identity_b for out_b, identity_b in zip(out, identity)
]
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = [self.relu(b) for b in out]
if self.concat_output:
out = torch.cat(out, dim=0)
return out
def make_trident_res_layer(block,
inplanes,
planes,
num_blocks,
stride=1,
trident_dilations=(1, 2, 3),
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None,
test_branch_idx=-1):
"""Build Trident Res Layers."""
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = []
conv_stride = stride
downsample.extend([
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=conv_stride,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1]
])
downsample = nn.Sequential(*downsample)
layers = []
for i in range(num_blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride if i == 0 else 1,
trident_dilations=trident_dilations,
downsample=downsample if i == 0 else None,
style=style,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
plugins=plugins,
test_branch_idx=test_branch_idx,
concat_output=True if i == num_blocks - 1 else False))
inplanes = planes * block.expansion
return nn.Sequential(*layers)
@BACKBONES.register_module()
class TridentResNet(ResNet):
"""The stem layer, stage 1 and stage 2 in Trident ResNet are identical to
ResNet, while in stage 3, Trident BottleBlock is utilized to replace the
normal BottleBlock to yield trident output. Different branch shares the
convolution weight but uses different dilations to achieve multi-scale
output.
/ stage3(b0) \
x - stem - stage1 - stage2 - stage3(b1) - output
\ stage3(b2) /
Args:
depth (int): Depth of resnet, from {50, 101, 152}.
num_branch (int): Number of branches in TridentNet.
test_branch_idx (int): In inference, all 3 branches will be used
if `test_branch_idx==-1`, otherwise only branch with index
`test_branch_idx` will be used.
trident_dilations (tuple[int]): Dilations of different trident branch.
len(trident_dilations) should be equal to num_branch.
""" # noqa
def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
**kwargs):
assert num_branch == len(trident_dilations)
assert depth in (50, 101, 152)
super(TridentResNet, self).__init__(depth, **kwargs)
assert self.num_stages == 3
self.test_branch_idx = test_branch_idx
self.num_branch = num_branch
last_stage_idx = self.num_stages - 1
stride = self.strides[last_stage_idx]
dilation = trident_dilations
dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
if self.plugins is not None:
stage_plugins = self.make_stage_plugins(self.plugins,
last_stage_idx)
else:
stage_plugins = None
planes = self.base_channels * 2**last_stage_idx
res_layer = make_trident_res_layer(
TridentBottleneck,
inplanes=(self.block.expansion * self.base_channels *
2**(last_stage_idx - 1)),
planes=planes,
num_blocks=self.stage_blocks[last_stage_idx],
stride=stride,
trident_dilations=dilation,
style=self.style,
with_cp=self.with_cp,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dcn=dcn,
plugins=stage_plugins,
test_branch_idx=self.test_branch_idx)
layer_name = f'layer{last_stage_idx + 1}'
self.__setattr__(layer_name, res_layer)
self.res_layers.pop(last_stage_idx)
self.res_layers.insert(last_stage_idx, layer_name)
self._freeze_stages()
......@@ -20,6 +20,7 @@ from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
from .single_stage import SingleStageDetector
from .trident_faster_rcnn import TridentFasterRCNN
from .two_stage import TwoStageDetector
from .vfnet import VFNet
from .yolact import YOLACT
......@@ -30,5 +31,5 @@ __all__ = [
'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
'YOLOV3', 'YOLACT', 'VFNet', 'DETR'
'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN'
]
from ..builder import DETECTORS
from .faster_rcnn import FasterRCNN
@DETECTORS.register_module()
class TridentFasterRCNN(FasterRCNN):
"""Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_"""
def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None):
super(TridentFasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
assert self.backbone.num_branch == self.roi_head.num_branch
assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
self.num_branch = self.backbone.num_branch
self.test_branch_idx = self.backbone.test_branch_idx
def simple_test(self, img, img_metas, proposals=None, rescale=False):
"""Test without augmentation."""
assert self.with_bbox, 'Bbox head must be implemented.'
x = self.extract_feat(img)
if proposals is None:
num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
trident_img_metas = img_metas * num_branch
proposal_list = self.rpn_head.simple_test_rpn(x, trident_img_metas)
else:
proposal_list = proposals
return self.roi_head.simple_test(
x, proposal_list, trident_img_metas, rescale=rescale)
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
x = self.extract_feats(imgs)
num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
trident_img_metas = [img_metas * num_branch for img_metas in img_metas]
proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
return self.roi_head.aug_test(
x, proposal_list, img_metas, rescale=rescale)
def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
"""make copies of img and gts to fit multi-branch."""
trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
trident_gt_labels = tuple(gt_labels * self.num_branch)
trident_img_metas = tuple(img_metas * self.num_branch)
return super(TridentFasterRCNN,
self).forward_train(img, trident_img_metas,
trident_gt_bboxes, trident_gt_labels)
......@@ -14,6 +14,7 @@ from .point_rend_roi_head import PointRendRoIHead
from .roi_extractors import SingleRoIExtractor
from .shared_heads import ResLayer
from .standard_roi_head import StandardRoIHead
from .trident_roi_head import TridentRoIHead
__all__ = [
'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead',
......@@ -22,5 +23,5 @@ __all__ = [
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'FCNMaskHead',
'HTCMaskHead', 'FusedSemanticHead', 'GridHead', 'MaskIoUHead',
'SingleRoIExtractor', 'PISARoIHead', 'PointRendRoIHead', 'MaskPointHead',
'CoarseMaskHead', 'DynamicRoIHead'
'CoarseMaskHead', 'DynamicRoIHead', 'TridentRoIHead'
]
import torch
from mmcv.ops import batched_nms
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
multiclass_nms)
from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
from ..builder import HEADS
@HEADS.register_module()
class TridentRoIHead(StandardRoIHead):
"""Trident roi head.
Args:
num_branch (int): Number of branches in TridentNet.
test_branch_idx (int): In inference, all 3 branches will be used
if `test_branch_idx==-1`, otherwise only branch with index
`test_branch_idx` will be used.
"""
def __init__(self, num_branch, test_branch_idx, **kwargs):
self.num_branch = num_branch
self.test_branch_idx = test_branch_idx
super(TridentRoIHead, self).__init__(**kwargs)
def simple_test(self,
x,
proposal_list,
img_metas,
proposals=None,
rescale=False):
"""Test without augmentation as follows:
1. Compute prediction bbox and label per branch.
2. Merge predictions of each branch according to scores of
bboxes, i.e., bboxes with higher score are kept to give
top-k prediction.
"""
assert self.with_bbox, 'Bbox head must be implemented.'
det_bboxes_list, det_labels_list = self.simple_test_bboxes(
x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
for _ in range(len(det_bboxes_list)):
if det_bboxes_list[_].shape[0] == 0:
det_bboxes_list[_] = det_bboxes_list[_].new_empty((0, 5))
trident_det_bboxes = torch.cat(det_bboxes_list, 0)
trident_det_labels = torch.cat(det_labels_list, 0)
if trident_det_bboxes.numel() == 0:
det_bboxes = trident_det_bboxes.new_zeros((0, 5))
det_labels = trident_det_bboxes.new_zeros((0, ), dtype=torch.long)
else:
nms_bboxes = trident_det_bboxes[:, :4]
nms_scores = trident_det_bboxes[:, 4].contiguous()
nms_inds = trident_det_labels
nms_cfg = self.test_cfg['nms']
det_bboxes, keep = batched_nms(nms_bboxes, nms_scores, nms_inds,
nms_cfg)
det_labels = trident_det_labels[keep]
if self.test_cfg['max_per_img'] > 0:
det_labels = det_labels[:self.test_cfg['max_per_img']]
det_bboxes = det_bboxes[:self.test_cfg['max_per_img']]
det_bboxes, det_labels = [det_bboxes], [det_labels]
bbox_results = [
bbox2result(det_bboxes[i], det_labels[i],
self.bbox_head.num_classes)
for i in range(len(det_bboxes))
]
return bbox_results
def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
"""Test det bboxes with test time augmentation."""
aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
flip_direction = img_meta[0]['flip_direction']
trident_bboxes, trident_scores = [], []
for branch_idx in range(len(proposal_list)):
proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip, flip_direction)
rois = bbox2roi([proposals])
bbox_results = self._bbox_forward(x, rois)
bboxes, scores = self.bbox_head.get_bboxes(
rois,
bbox_results['cls_score'],
bbox_results['bbox_pred'],
img_shape,
scale_factor,
rescale=False,
cfg=None)
trident_bboxes.append(bboxes)
trident_scores.append(scores)
aug_bboxes.append(torch.cat(trident_bboxes, 0))
aug_scores.append(torch.cat(trident_scores, 0))
# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)
return det_bboxes, det_labels
......@@ -5,12 +5,13 @@ from torch.nn.modules import AvgPool2d, GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.backbones import (RegNet, Res2Net, ResNeSt, ResNet,
ResNetV1d, ResNeXt)
ResNetV1d, ResNeXt, TridentResNet)
from mmdet.models.backbones.hourglass import HourglassNet
from mmdet.models.backbones.res2net import Bottle2neck
from mmdet.models.backbones.resnest import Bottleneck as BottleneckS
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
from mmdet.models.backbones.trident_resnet import TridentBottleneck
from mmdet.models.utils import ResLayer
......@@ -51,7 +52,6 @@ def check_norm_state(modules, train_state):
def test_resnet_basic_block():
with pytest.raises(AssertionError):
# Not implemented yet.
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
......@@ -101,7 +101,6 @@ def test_resnet_basic_block():
def test_resnet_bottleneck():
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
Bottleneck(64, 64, style='tensorflow')
......@@ -235,6 +234,181 @@ def test_resnet_bottleneck():
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_trident_resnet_bottleneck():
trident_dilations = (1, 2, 3)
test_branch_idx = 1
concat_output = True
trident_build_config = (trident_dilations, test_branch_idx, concat_output)
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
TridentBottleneck(
*trident_build_config, inplanes=64, planes=64, style='tensorflow')
with pytest.raises(AssertionError):
# Allowed positions are 'after_conv1', 'after_conv2', 'after_conv3'
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv4')
]
TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, plugins=plugins)
with pytest.raises(AssertionError):
# Need to specify different postfix to avoid duplicate plugin name
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3'),
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3')
]
TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, plugins=plugins)
with pytest.raises(KeyError):
# Plugin type is not supported
plugins = [dict(cfg=dict(type='WrongPlugin'), position='after_conv3')]
TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, plugins=plugins)
# Test Bottleneck with checkpoint forward
block = TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
# Test Bottleneck style
block = TridentBottleneck(
*trident_build_config,
inplanes=64,
planes=64,
stride=2,
style='pytorch')
assert block.conv1.stride == (1, 1)
assert block.conv2.stride == (2, 2)
block = TridentBottleneck(
*trident_build_config, inplanes=64, planes=64, stride=2, style='caffe')
assert block.conv1.stride == (2, 2)
assert block.conv2.stride == (1, 1)
# Test Bottleneck forward
block = TridentBottleneck(*trident_build_config, inplanes=64, planes=16)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
# Test Bottleneck with 1 ContextBlock after conv3
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3')
]
block = TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, plugins=plugins)
assert block.context_block.in_channels == 64
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
# Test Bottleneck with 1 GeneralizedAttention after conv2
plugins = [
dict(
cfg=dict(
type='GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0010',
kv_stride=2),
position='after_conv2')
]
block = TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, plugins=plugins)
assert block.gen_attention_block.in_channels == 16
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
# Test Bottleneck with 1 GeneralizedAttention after conv2, 1 NonLocal2D
# after conv2, 1 ContextBlock after conv3
plugins = [
dict(
cfg=dict(
type='GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0010',
kv_stride=2),
position='after_conv2'),
dict(cfg=dict(type='NonLocal2d'), position='after_conv2'),
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3')
]
block = TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, plugins=plugins)
assert block.gen_attention_block.in_channels == 16
assert block.nonlocal_block.in_channels == 16
assert block.context_block.in_channels == 64
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
# Test Bottleneck with 1 ContextBlock after conv2, 2 ContextBlock after
# conv3
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1),
position='after_conv2'),
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2),
position='after_conv3'),
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=3),
position='after_conv3')
]
block = TridentBottleneck(
*trident_build_config, inplanes=64, planes=16, plugins=plugins)
assert block.context_block1.in_channels == 16
assert block.context_block2.in_channels == 64
assert block.context_block3.in_channels == 64
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
def test_trident_resnet_backbone():
tridentresnet_config = dict(
num_branch=3,
test_branch_idx=1,
strides=(1, 2, 2),
dilations=(1, 1, 1),
trident_dilations=(1, 2, 3),
out_indices=(2, ),
)
"""Test tridentresnet backbone."""
with pytest.raises(AssertionError):
# TridentResNet depth should be in [50, 101, 152]
TridentResNet(18, **tridentresnet_config)
with pytest.raises(AssertionError):
# In TridentResNet: num_stages == 3
TridentResNet(50, num_stages=4, **tridentresnet_config)
model = TridentResNet(50, num_stages=3, **tridentresnet_config)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size([3, 1024, 14, 14])
def test_resnet_res_layer():
# Test ResLayer of 3 Bottleneck w\o downsample
layer = ResLayer(Bottleneck, 64, 16, 3)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment