diff --git a/.isort.cfg b/.isort.cfg index 06ae39a2a1c5059808dbc41dbe3b1f6f6986375f..9c4f8afdfce9632cfa143cf83aa4ea737f3b787f 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,6 +3,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision +known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py index 6fb56d63c8898a83f7b779c8c22cf52880f202c5..e96a224e3f5d45bca7ca1da3f7ec517b3db34dff 100644 --- a/mmdet/models/backbones/__init__.py +++ b/mmdet/models/backbones/__init__.py @@ -1,6 +1,6 @@ from .hrnet import HRNet -from .resnet import ResNet, make_res_layer +from .resnet import ResNet, ResNetV1d from .resnext import ResNeXt from .ssd_vgg import SSDVGG -__all__ = ['ResNet', 'make_res_layer', 'ResNeXt', 'SSDVGG', 'HRNet'] +__all__ = ['ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet'] diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py index 5a3862ae004c3f32cab6654858e9881babb21bf4..f4008e1178b68102270b6f2b3ea79a9227a8d2c2 100644 --- a/mmdet/models/backbones/resnet.py +++ b/mmdet/models/backbones/resnet.py @@ -8,6 +8,7 @@ from mmdet.ops import (ContextBlock, GeneralizedAttention, build_conv_layer, build_norm_layer) from mmdet.utils import get_root_logger from ..registry import BACKBONES +from ..utils import ResLayer class BasicBlock(nn.Module): @@ -156,7 +157,7 @@ class Bottleneck(nn.Module): dilation=dilation, bias=False) else: - assert self.conv_cfg is None, 'conv_cfg cannot be None for DCN' + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' self.conv2 = build_conv_layer( dcn, planes, @@ -239,69 +240,6 @@ class Bottleneck(nn.Module): return out -def make_res_layer(block, - inplanes, - planes, - blocks, - stride=1, - dilation=1, - style='pytorch', - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - dcn=None, - gcb=None, - gen_attention=None, - gen_attention_blocks=[]): - downsample = None - if stride != 1 or inplanes != planes * block.expansion: - downsample = nn.Sequential( - build_conv_layer( - conv_cfg, - inplanes, - planes * block.expansion, - kernel_size=1, - stride=stride, - bias=False), - build_norm_layer(norm_cfg, planes * block.expansion)[1], - ) - - layers = [] - layers.append( - block( - inplanes=inplanes, - planes=planes, - stride=stride, - dilation=dilation, - downsample=downsample, - style=style, - with_cp=with_cp, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - dcn=dcn, - gcb=gcb, - gen_attention=gen_attention if - (0 in gen_attention_blocks) else None)) - inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append( - block( - inplanes=inplanes, - planes=planes, - stride=1, - dilation=dilation, - style=style, - with_cp=with_cp, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - dcn=dcn, - gcb=gcb, - gen_attention=gen_attention if - (i in gen_attention_blocks) else None)) - - return nn.Sequential(*layers) - - @BACKBONES.register_module class ResNet(nn.Module): """ResNet backbone. @@ -316,6 +254,9 @@ class ResNet(nn.Module): style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two layer is the 3x3 conv layer, otherwise the stride-two layer is the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. norm_cfg (dict): dictionary to construct and config norm layer. @@ -353,11 +294,14 @@ class ResNet(nn.Module): def __init__(self, depth, in_channels=3, + base_channels=64, num_stages=4, strides=(1, 2, 2, 2), dilations=(1, 1, 1, 1), out_indices=(0, 1, 2, 3), style='pytorch', + deep_stem=False, + avg_down=False, frozen_stages=-1, conv_cfg=None, norm_cfg=dict(type='BN', requires_grad=True), @@ -374,6 +318,7 @@ class ResNet(nn.Module): if depth not in self.arch_settings: raise KeyError('invalid depth {} for resnet'.format(depth)) self.depth = depth + self.base_channels = base_channels self.num_stages = num_stages assert num_stages >= 1 and num_stages <= 4 self.strides = strides @@ -382,6 +327,8 @@ class ResNet(nn.Module): self.out_indices = out_indices assert max(out_indices) < num_stages self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down self.frozen_stages = frozen_stages self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg @@ -399,9 +346,9 @@ class ResNet(nn.Module): self.zero_init_residual = zero_init_residual self.block, stage_blocks = self.arch_settings[depth] self.stage_blocks = stage_blocks[:num_stages] - self.inplanes = 64 + self.inplanes = base_channels - self._make_stem_layer(in_channels) + self._make_stem_layer(in_channels, base_channels) self.res_layers = [] for i, num_blocks in enumerate(self.stage_blocks): @@ -409,15 +356,16 @@ class ResNet(nn.Module): dilation = dilations[i] dcn = self.dcn if self.stage_with_dcn[i] else None gcb = self.gcb if self.stage_with_gcb[i] else None - planes = 64 * 2**i - res_layer = make_res_layer( - self.block, - self.inplanes, - planes, - num_blocks, + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, stride=stride, dilation=dilation, style=self.style, + avg_down=self.avg_down, with_cp=with_cp, conv_cfg=conv_cfg, norm_cfg=norm_cfg, @@ -432,33 +380,75 @@ class ResNet(nn.Module): self._freeze_stages() - self.feat_dim = self.block.expansion * 64 * 2**( + self.feat_dim = self.block.expansion * base_channels * 2**( len(self.stage_blocks) - 1) + def make_res_layer(self, **kwargs): + return ResLayer(**kwargs) + @property def norm1(self): return getattr(self, self.norm1_name) - def _make_stem_layer(self, in_channels): - self.conv1 = build_conv_layer( - self.conv_cfg, - in_channels, - 64, - kernel_size=7, - stride=2, - padding=3, - bias=False) - self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) - self.add_module(self.norm1_name, norm1) - self.relu = nn.ReLU(inplace=True) + def _make_stem_layer(self, in_channels, base_channels): + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + base_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, base_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + base_channels // 2, + base_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, base_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + base_channels // 2, + base_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, base_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) def _freeze_stages(self): if self.frozen_stages >= 0: - self.norm1.eval() - for m in [self.conv1, self.norm1]: - for param in m.parameters(): + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False for i in range(1, self.frozen_stages + 1): m = getattr(self, 'layer{}'.format(i)) @@ -493,9 +483,12 @@ class ResNet(nn.Module): raise TypeError('pretrained must be a str or None') def forward(self, x): - x = self.conv1(x) - x = self.norm1(x) - x = self.relu(x) + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) x = self.maxpool(x) outs = [] for i, layer_name in enumerate(self.res_layers): @@ -513,3 +506,21 @@ class ResNet(nn.Module): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() + + +@BACKBONES.register_module +class ResNetV1d(ResNet): + """ResNetV1d variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv + in the input stem with three 3x3 convs. And in the downsampling block, + a 2x2 avg_pool with stride 2 is added before conv, whose stride is + changed to 1. + + References: + .. [1] https://arxiv.org/pdf/1812.01187.pdf + """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py index 69364f77915f06140b8ee290d1cdadf80e6fe71a..86fe14c08c1a39e0e11450e205233421ead65bc6 100644 --- a/mmdet/models/backbones/resnext.py +++ b/mmdet/models/backbones/resnext.py @@ -1,16 +1,21 @@ import math -import torch.nn as nn - from mmdet.ops import build_conv_layer, build_norm_layer from ..registry import BACKBONES +from ..utils import ResLayer from .resnet import Bottleneck as _Bottleneck from .resnet import ResNet class Bottleneck(_Bottleneck): - def __init__(self, inplanes, planes, groups=1, base_width=4, **kwargs): + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): """Bottleneck block for ResNeXt. If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is "caffe", the stride-two layer is the first 1x1 conv layer. @@ -20,7 +25,8 @@ class Bottleneck(_Bottleneck): if groups == 1: width = self.planes else: - width = math.floor(self.planes * (base_width / 64)) * groups + width = math.floor(self.planes * + (base_width / base_channels)) * groups self.norm1_name, norm1 = build_norm_layer( self.norm_cfg, width, postfix=1) @@ -75,69 +81,6 @@ class Bottleneck(_Bottleneck): self.add_module(self.norm3_name, norm3) -def make_res_layer(block, - inplanes, - planes, - blocks, - stride=1, - dilation=1, - groups=1, - base_width=4, - style='pytorch', - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - dcn=None, - gcb=None): - downsample = None - if stride != 1 or inplanes != planes * block.expansion: - downsample = nn.Sequential( - build_conv_layer( - conv_cfg, - inplanes, - planes * block.expansion, - kernel_size=1, - stride=stride, - bias=False), - build_norm_layer(norm_cfg, planes * block.expansion)[1], - ) - - layers = [] - layers.append( - block( - inplanes=inplanes, - planes=planes, - stride=stride, - dilation=dilation, - downsample=downsample, - groups=groups, - base_width=base_width, - style=style, - with_cp=with_cp, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - dcn=dcn, - gcb=gcb)) - inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append( - block( - inplanes=inplanes, - planes=planes, - stride=1, - dilation=dilation, - groups=groups, - base_width=base_width, - style=style, - with_cp=with_cp, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - dcn=dcn, - gcb=gcb)) - - return nn.Sequential(*layers) - - @BACKBONES.register_module class ResNeXt(ResNet): """ResNeXt backbone. @@ -187,36 +130,13 @@ class ResNeXt(ResNet): } def __init__(self, groups=1, base_width=4, **kwargs): - super(ResNeXt, self).__init__(**kwargs) self.groups = groups self.base_width = base_width + super(ResNeXt, self).__init__(**kwargs) - self.inplanes = 64 - self.res_layers = [] - for i, num_blocks in enumerate(self.stage_blocks): - stride = self.strides[i] - dilation = self.dilations[i] - dcn = self.dcn if self.stage_with_dcn[i] else None - gcb = self.gcb if self.stage_with_gcb[i] else None - planes = 64 * 2**i - res_layer = make_res_layer( - self.block, - self.inplanes, - planes, - num_blocks, - stride=stride, - dilation=dilation, - groups=self.groups, - base_width=self.base_width, - style=self.style, - with_cp=self.with_cp, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - dcn=dcn, - gcb=gcb) - self.inplanes = planes * self.block.expansion - layer_name = 'layer{}'.format(i + 1) - self.add_module(layer_name, res_layer) - self.res_layers.append(layer_name) - - self._freeze_stages() + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py index e1a1ba0d76b34d6199ba397916e6b29ade8e0a74..a5178c334243a6da5f09edd77ec8104a8b6c1e3a 100644 --- a/mmdet/models/shared_heads/res_layer.py +++ b/mmdet/models/shared_heads/res_layer.py @@ -4,8 +4,9 @@ from mmcv.runner import load_checkpoint from mmdet.core import auto_fp16 from mmdet.utils import get_root_logger -from ..backbones import ResNet, make_res_layer +from ..backbones import ResNet from ..registry import SHARED_HEADS +from ..utils import ResLayer as _ResLayer @SHARED_HEADS.register_module @@ -31,7 +32,7 @@ class ResLayer(nn.Module): planes = 64 * 2**stage inplanes = 64 * 2**(stage - 1) * block.expansion - res_layer = make_res_layer( + res_layer = _ResLayer( block, inplanes, planes, diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py index 83614f71954ca16f2f1dee91d94845e335d21a80..47213871f05e1656eb7f09fd986add8135ef1cff 100644 --- a/mmdet/models/utils/__init__.py +++ b/mmdet/models/utils/__init__.py @@ -1,3 +1,4 @@ +from .res_layer import ResLayer from .weight_init import bias_init_with_prob -__all__ = ['bias_init_with_prob'] +__all__ = ['bias_init_with_prob', 'ResLayer'] diff --git a/mmdet/models/utils/res_layer.py b/mmdet/models/utils/res_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8c61f285cce027b6349d2b167d2c2c7b32d141 --- /dev/null +++ b/mmdet/models/utils/res_layer.py @@ -0,0 +1,85 @@ +from torch import nn as nn + +from mmdet.ops import build_conv_layer, build_norm_layer + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + gen_attention = kwargs.pop('gen_attention', None) + gen_attention_blocks = kwargs.pop('gen_attention_blocks', tuple()) + layers = [] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + gen_attention=gen_attention + if 0 in gen_attention_blocks else None, + **kwargs)) + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + gen_attention=gen_attention if + (i in gen_attention_blocks) else None, + **kwargs)) + super(ResLayer, self).__init__(*layers) diff --git a/tests/test_backbone.py b/tests/test_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..9f8ab685a6fc9a1ac47f5ca85583f1f3059fc6e3 --- /dev/null +++ b/tests/test_backbone.py @@ -0,0 +1,401 @@ +import pytest +import torch +from torch.nn.modules import AvgPool2d, GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm + +from mmdet.models.backbones import ResNet, ResNetV1d, ResNeXt +from mmdet.models.backbones.resnet import BasicBlock, Bottleneck +from mmdet.models.backbones.resnext import Bottleneck as BottleneckX +from mmdet.models.utils import ResLayer +from mmdet.ops import DeformConvPack + + +def is_block(modules): + """Check if is ResNet building block.""" + if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX)): + return True + return False + + +def is_norm(modules): + """Check if is one of the norms.""" + if isinstance(modules, (GroupNorm, _BatchNorm)): + return True + return False + + +def all_zeros(modules): + weight_zero = torch.allclose(modules.weight.data, + torch.zeros_like(modules.weight.data)) + if hasattr(modules, 'bias'): + bias_zero = torch.allclose(modules.bias.data, + torch.zeros_like(modules.bias.data)) + else: + bias_zero = True + + return weight_zero and bias_zero + + +def check_norm_state(modules, train_state): + """Check if norm layer is in correct train state.""" + for mod in modules: + if isinstance(mod, _BatchNorm): + if mod.training != train_state: + return False + return True + + +def test_resnet_basic_block(): + + with pytest.raises(AssertionError): + BasicBlock(64, 64, with_cp=True) + + with pytest.raises(AssertionError): + # Not implemented yet. + dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False) + BasicBlock(64, 64, dcn=dcn) + + with pytest.raises(AssertionError): + # Not implemented yet. + gcb = dict(ratio=1. / 4., ) + BasicBlock(64, 64, gcb=gcb) + + with pytest.raises(AssertionError): + # Not implemented yet + gen_attention = dict( + spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2) + BasicBlock(64, 64, gen_attention=gen_attention) + + block = BasicBlock(64, 64) + assert block.conv1.in_channels == 64 + assert block.conv1.out_channels == 64 + assert block.conv1.kernel_size == (3, 3) + assert block.conv2.in_channels == 64 + assert block.conv2.out_channels == 64 + assert block.conv2.kernel_size == (3, 3) + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + +def test_resnet_bottleneck(): + + with pytest.raises(AssertionError): + # style must be in ['pytorch', 'caffe'] + Bottleneck(64, 64, style='tensorflow') + + block = Bottleneck(64, 16, with_cp=True) + assert block.with_cp + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + block = Bottleneck(64, 64, stride=2, style='pytorch') + assert block.conv1.stride == (1, 1) + assert block.conv2.stride == (2, 2) + block = Bottleneck(64, 64, stride=2, style='caffe') + assert block.conv1.stride == (2, 2) + assert block.conv2.stride == (1, 1) + + dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False) + with pytest.raises(AssertionError): + Bottleneck(64, 64, dcn=dcn, conv_cfg=dict(type='Conv')) + block = Bottleneck(64, 64, dcn=dcn) + assert isinstance(block.conv2, DeformConvPack) + + block = Bottleneck(64, 16) + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + gcb = dict(ratio=1. / 4., ) + block = Bottleneck(64, 16, gcb=gcb) + assert hasattr(block, 'context_block') + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + gen_attention = dict( + spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2) + block = Bottleneck(64, 16, gen_attention=gen_attention) + assert hasattr(block, 'gen_attention') + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + +def test_resnet_res_layer(): + layer = ResLayer(Bottleneck, 64, 16, 3) + assert len(layer) == 3 + assert layer[0].conv1.in_channels == 64 + assert layer[0].conv1.out_channels == 16 + for i in range(1, len(layer)): + assert layer[i].conv1.in_channels == 64 + assert layer[i].conv1.out_channels == 16 + for i in range(len(layer)): + assert layer[i].downsample is None + x = torch.randn(1, 64, 56, 56) + x_out = layer(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + layer = ResLayer(Bottleneck, 64, 64, 3) + assert layer[0].downsample[0].out_channels == 256 + for i in range(1, len(layer)): + assert layer[i].downsample is None + x = torch.randn(1, 64, 56, 56) + x_out = layer(x) + assert x_out.shape == torch.Size([1, 256, 56, 56]) + + layer = ResLayer(Bottleneck, 64, 64, 3, stride=2) + assert layer[0].downsample[0].out_channels == 256 + assert layer[0].downsample[0].stride == (2, 2) + for i in range(1, len(layer)): + assert layer[i].downsample is None + x = torch.randn(1, 64, 56, 56) + x_out = layer(x) + assert x_out.shape == torch.Size([1, 256, 28, 28]) + + layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True) + assert isinstance(layer[0].downsample[0], AvgPool2d) + assert layer[0].downsample[1].out_channels == 256 + assert layer[0].downsample[1].stride == (1, 1) + for i in range(1, len(layer)): + assert layer[i].downsample is None + x = torch.randn(1, 64, 56, 56) + x_out = layer(x) + assert x_out.shape == torch.Size([1, 256, 28, 28]) + + +def test_resnet_backbone(): + """Test resnet backbone""" + with pytest.raises(KeyError): + # ResNet depth should be in [18, 34, 50, 101, 152] + ResNet(20) + + with pytest.raises(AssertionError): + # In ResNet: 1 <= num_stages <= 4 + ResNet(50, num_stages=0) + + with pytest.raises(AssertionError): + ResNet(18, with_cp=True) + + with pytest.raises(AssertionError): + # len(stage_with_dcn) == num_stages + dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False) + ResNet(50, dcn=dcn, stage_with_dcn=(True, )) + + with pytest.raises(AssertionError): + # len(stage_with_gcb) == num_stages + gcb = dict(ratio=1. / 4., ) + ResNet(50, gcb=gcb, stage_with_gcb=(True, )) + + with pytest.raises(AssertionError): + # In ResNet: 1 <= num_stages <= 4 + ResNet(50, num_stages=5) + + with pytest.raises(AssertionError): + # len(strides) == len(dilations) == num_stages + ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3) + + with pytest.raises(TypeError): + model = ResNet(50) + model.init_weights(pretrained=0) + + with pytest.raises(AssertionError): + # style must be in ['pytorch', 'caffe'] + ResNet(50, style='tensorflow') + + with pytest.raises(AssertionError): + # assert not with_cp + ResNet(18, with_cp=True) + + model = ResNet(18) + model.init_weights() + + model = ResNet(50, norm_eval=True) + model.init_weights() + model.train() + assert check_norm_state(model.modules(), False) + + model = ResNet(depth=50, norm_eval=True) + model.init_weights('torchvision://resnet50') + model.train() + assert check_norm_state(model.modules(), False) + + frozen_stages = 1 + model = ResNet(50, frozen_stages=frozen_stages) + model.init_weights() + model.train() + assert model.norm1.training is False + for layer in [model.conv1, model.norm1]: + for param in layer.parameters(): + assert param.requires_grad is False + for i in range(1, frozen_stages + 1): + layer = getattr(model, 'layer{}'.format(i)) + for mod in layer.modules(): + if isinstance(mod, _BatchNorm): + assert mod.training is False + for param in layer.parameters(): + assert param.requires_grad is False + + model = ResNetV1d(depth=50, frozen_stages=frozen_stages) + assert len(model.stem) == 9 + model.init_weights() + model.train() + check_norm_state(model.stem, False) + for param in model.stem.parameters(): + assert param.requires_grad is False + for i in range(1, frozen_stages + 1): + layer = getattr(model, 'layer{}'.format(i)) + for mod in layer.modules(): + if isinstance(mod, _BatchNorm): + assert mod.training is False + for param in layer.parameters(): + assert param.requires_grad is False + + model = ResNet(18) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 64, 56, 56]) + assert feat[1].shape == torch.Size([1, 128, 28, 28]) + assert feat[2].shape == torch.Size([1, 256, 14, 14]) + assert feat[3].shape == torch.Size([1, 512, 7, 7]) + + model = ResNet(50) + for m in model.modules(): + if is_norm(m): + assert isinstance(m, _BatchNorm) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + assert feat[3].shape == torch.Size([1, 2048, 7, 7]) + + model = ResNet(50, out_indices=(0, 1, 2)) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 3 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + + model = ResNet(50, with_cp=True) + for m in model.modules(): + if is_block(m): + assert m.with_cp + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + assert feat[3].shape == torch.Size([1, 2048, 7, 7]) + + model = ResNet( + 50, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)) + for m in model.modules(): + if is_norm(m): + assert isinstance(m, GroupNorm) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + assert feat[3].shape == torch.Size([1, 2048, 7, 7]) + + model = ResNet(50, zero_init_residual=True) + model.init_weights() + for m in model.modules(): + if isinstance(m, Bottleneck): + assert all_zeros(m.norm3) + elif isinstance(m, BasicBlock): + assert all_zeros(m.norm2) + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + assert feat[3].shape == torch.Size([1, 2048, 7, 7]) + + model = ResNetV1d(depth=50) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + assert feat[3].shape == torch.Size([1, 2048, 7, 7]) + + +def test_renext_bottleneck(): + with pytest.raises(AssertionError): + # style must be in ['pytorch', 'caffe'] + BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow') + + block = BottleneckX( + 64, 64, groups=32, base_width=4, stride=2, style='pytorch') + assert block.conv2.stride == (2, 2) + assert block.conv2.groups == 32 + assert block.conv2.out_channels == 128 + + dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False) + with pytest.raises(AssertionError): + BottleneckX( + 64, + 64, + groups=32, + base_width=4, + dcn=dcn, + conv_cfg=dict(type='Conv')) + BottleneckX(64, 64, dcn=dcn) + + block = BottleneckX(64, 16, groups=32, base_width=4) + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + +def test_resnext_backbone(): + with pytest.raises(KeyError): + # ResNeXt depth should be in [50, 101, 152] + ResNeXt(depth=18) + + model = ResNeXt(depth=50, groups=32, base_width=4) + for m in model.modules(): + if is_block(m): + assert m.conv2.groups == 32 + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 56, 56]) + assert feat[1].shape == torch.Size([1, 512, 28, 28]) + assert feat[2].shape == torch.Size([1, 1024, 14, 14]) + assert feat[3].shape == torch.Size([1, 2048, 7, 7])