diff --git a/.isort.cfg b/.isort.cfg
index 06ae39a2a1c5059808dbc41dbe3b1f6f6986375f..9c4f8afdfce9632cfa143cf83aa4ea737f3b787f 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -3,6 +3,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmdet
-known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
+known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py
index 6fb56d63c8898a83f7b779c8c22cf52880f202c5..e96a224e3f5d45bca7ca1da3f7ec517b3db34dff 100644
--- a/mmdet/models/backbones/__init__.py
+++ b/mmdet/models/backbones/__init__.py
@@ -1,6 +1,6 @@
 from .hrnet import HRNet
-from .resnet import ResNet, make_res_layer
+from .resnet import ResNet, ResNetV1d
 from .resnext import ResNeXt
 from .ssd_vgg import SSDVGG
 
-__all__ = ['ResNet', 'make_res_layer', 'ResNeXt', 'SSDVGG', 'HRNet']
+__all__ = ['ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet']
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index 5a3862ae004c3f32cab6654858e9881babb21bf4..f4008e1178b68102270b6f2b3ea79a9227a8d2c2 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -8,6 +8,7 @@ from mmdet.ops import (ContextBlock, GeneralizedAttention, build_conv_layer,
                        build_norm_layer)
 from mmdet.utils import get_root_logger
 from ..registry import BACKBONES
+from ..utils import ResLayer
 
 
 class BasicBlock(nn.Module):
@@ -156,7 +157,7 @@ class Bottleneck(nn.Module):
                 dilation=dilation,
                 bias=False)
         else:
-            assert self.conv_cfg is None, 'conv_cfg cannot be None for DCN'
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
             self.conv2 = build_conv_layer(
                 dcn,
                 planes,
@@ -239,69 +240,6 @@ class Bottleneck(nn.Module):
         return out
 
 
-def make_res_layer(block,
-                   inplanes,
-                   planes,
-                   blocks,
-                   stride=1,
-                   dilation=1,
-                   style='pytorch',
-                   with_cp=False,
-                   conv_cfg=None,
-                   norm_cfg=dict(type='BN'),
-                   dcn=None,
-                   gcb=None,
-                   gen_attention=None,
-                   gen_attention_blocks=[]):
-    downsample = None
-    if stride != 1 or inplanes != planes * block.expansion:
-        downsample = nn.Sequential(
-            build_conv_layer(
-                conv_cfg,
-                inplanes,
-                planes * block.expansion,
-                kernel_size=1,
-                stride=stride,
-                bias=False),
-            build_norm_layer(norm_cfg, planes * block.expansion)[1],
-        )
-
-    layers = []
-    layers.append(
-        block(
-            inplanes=inplanes,
-            planes=planes,
-            stride=stride,
-            dilation=dilation,
-            downsample=downsample,
-            style=style,
-            with_cp=with_cp,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            dcn=dcn,
-            gcb=gcb,
-            gen_attention=gen_attention if
-            (0 in gen_attention_blocks) else None))
-    inplanes = planes * block.expansion
-    for i in range(1, blocks):
-        layers.append(
-            block(
-                inplanes=inplanes,
-                planes=planes,
-                stride=1,
-                dilation=dilation,
-                style=style,
-                with_cp=with_cp,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                dcn=dcn,
-                gcb=gcb,
-                gen_attention=gen_attention if
-                (i in gen_attention_blocks) else None))
-
-    return nn.Sequential(*layers)
-
-
 @BACKBONES.register_module
 class ResNet(nn.Module):
     """ResNet backbone.
@@ -316,6 +254,9 @@ class ResNet(nn.Module):
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
             layer is the 3x3 conv layer, otherwise the stride-two layer is
             the first 1x1 conv layer.
+        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck.
         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
             -1 means not freezing any parameters.
         norm_cfg (dict): dictionary to construct and config norm layer.
@@ -353,11 +294,14 @@ class ResNet(nn.Module):
     def __init__(self,
                  depth,
                  in_channels=3,
+                 base_channels=64,
                  num_stages=4,
                  strides=(1, 2, 2, 2),
                  dilations=(1, 1, 1, 1),
                  out_indices=(0, 1, 2, 3),
                  style='pytorch',
+                 deep_stem=False,
+                 avg_down=False,
                  frozen_stages=-1,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN', requires_grad=True),
@@ -374,6 +318,7 @@ class ResNet(nn.Module):
         if depth not in self.arch_settings:
             raise KeyError('invalid depth {} for resnet'.format(depth))
         self.depth = depth
+        self.base_channels = base_channels
         self.num_stages = num_stages
         assert num_stages >= 1 and num_stages <= 4
         self.strides = strides
@@ -382,6 +327,8 @@ class ResNet(nn.Module):
         self.out_indices = out_indices
         assert max(out_indices) < num_stages
         self.style = style
+        self.deep_stem = deep_stem
+        self.avg_down = avg_down
         self.frozen_stages = frozen_stages
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
@@ -399,9 +346,9 @@ class ResNet(nn.Module):
         self.zero_init_residual = zero_init_residual
         self.block, stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
-        self.inplanes = 64
+        self.inplanes = base_channels
 
-        self._make_stem_layer(in_channels)
+        self._make_stem_layer(in_channels, base_channels)
 
         self.res_layers = []
         for i, num_blocks in enumerate(self.stage_blocks):
@@ -409,15 +356,16 @@ class ResNet(nn.Module):
             dilation = dilations[i]
             dcn = self.dcn if self.stage_with_dcn[i] else None
             gcb = self.gcb if self.stage_with_gcb[i] else None
-            planes = 64 * 2**i
-            res_layer = make_res_layer(
-                self.block,
-                self.inplanes,
-                planes,
-                num_blocks,
+            planes = base_channels * 2**i
+            res_layer = self.make_res_layer(
+                block=self.block,
+                inplanes=self.inplanes,
+                planes=planes,
+                num_blocks=num_blocks,
                 stride=stride,
                 dilation=dilation,
                 style=self.style,
+                avg_down=self.avg_down,
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
@@ -432,33 +380,75 @@ class ResNet(nn.Module):
 
         self._freeze_stages()
 
-        self.feat_dim = self.block.expansion * 64 * 2**(
+        self.feat_dim = self.block.expansion * base_channels * 2**(
             len(self.stage_blocks) - 1)
 
+    def make_res_layer(self, **kwargs):
+        return ResLayer(**kwargs)
+
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
 
-    def _make_stem_layer(self, in_channels):
-        self.conv1 = build_conv_layer(
-            self.conv_cfg,
-            in_channels,
-            64,
-            kernel_size=7,
-            stride=2,
-            padding=3,
-            bias=False)
-        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
-        self.add_module(self.norm1_name, norm1)
-        self.relu = nn.ReLU(inplace=True)
+    def _make_stem_layer(self, in_channels, base_channels):
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    base_channels // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, base_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    base_channels // 2,
+                    base_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, base_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    base_channels // 2,
+                    base_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, base_channels)[1],
+                nn.ReLU(inplace=True))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                base_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, base_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+            self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
     def _freeze_stages(self):
         if self.frozen_stages >= 0:
-            self.norm1.eval()
-            for m in [self.conv1, self.norm1]:
-                for param in m.parameters():
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
                     param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
 
         for i in range(1, self.frozen_stages + 1):
             m = getattr(self, 'layer{}'.format(i))
@@ -493,9 +483,12 @@ class ResNet(nn.Module):
             raise TypeError('pretrained must be a str or None')
 
     def forward(self, x):
-        x = self.conv1(x)
-        x = self.norm1(x)
-        x = self.relu(x)
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
         x = self.maxpool(x)
         outs = []
         for i, layer_name in enumerate(self.res_layers):
@@ -513,3 +506,21 @@ class ResNet(nn.Module):
                 # trick: eval have effect on BatchNorm only
                 if isinstance(m, _BatchNorm):
                     m.eval()
+
+
+@BACKBONES.register_module
+class ResNetV1d(ResNet):
+    """ResNetV1d variant described in [1]_.
+
+    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv
+    in the input stem with three 3x3 convs. And in the downsampling block,
+    a 2x2 avg_pool with stride 2 is added before conv, whose stride is
+    changed to 1.
+
+    References:
+        .. [1] https://arxiv.org/pdf/1812.01187.pdf
+    """
+
+    def __init__(self, **kwargs):
+        super(ResNetV1d, self).__init__(
+            deep_stem=True, avg_down=True, **kwargs)
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
index 69364f77915f06140b8ee290d1cdadf80e6fe71a..86fe14c08c1a39e0e11450e205233421ead65bc6 100644
--- a/mmdet/models/backbones/resnext.py
+++ b/mmdet/models/backbones/resnext.py
@@ -1,16 +1,21 @@
 import math
 
-import torch.nn as nn
-
 from mmdet.ops import build_conv_layer, build_norm_layer
 from ..registry import BACKBONES
+from ..utils import ResLayer
 from .resnet import Bottleneck as _Bottleneck
 from .resnet import ResNet
 
 
 class Bottleneck(_Bottleneck):
 
-    def __init__(self, inplanes, planes, groups=1, base_width=4, **kwargs):
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 groups=1,
+                 base_width=4,
+                 base_channels=64,
+                 **kwargs):
         """Bottleneck block for ResNeXt.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
@@ -20,7 +25,8 @@ class Bottleneck(_Bottleneck):
         if groups == 1:
             width = self.planes
         else:
-            width = math.floor(self.planes * (base_width / 64)) * groups
+            width = math.floor(self.planes *
+                               (base_width / base_channels)) * groups
 
         self.norm1_name, norm1 = build_norm_layer(
             self.norm_cfg, width, postfix=1)
@@ -75,69 +81,6 @@ class Bottleneck(_Bottleneck):
         self.add_module(self.norm3_name, norm3)
 
 
-def make_res_layer(block,
-                   inplanes,
-                   planes,
-                   blocks,
-                   stride=1,
-                   dilation=1,
-                   groups=1,
-                   base_width=4,
-                   style='pytorch',
-                   with_cp=False,
-                   conv_cfg=None,
-                   norm_cfg=dict(type='BN'),
-                   dcn=None,
-                   gcb=None):
-    downsample = None
-    if stride != 1 or inplanes != planes * block.expansion:
-        downsample = nn.Sequential(
-            build_conv_layer(
-                conv_cfg,
-                inplanes,
-                planes * block.expansion,
-                kernel_size=1,
-                stride=stride,
-                bias=False),
-            build_norm_layer(norm_cfg, planes * block.expansion)[1],
-        )
-
-    layers = []
-    layers.append(
-        block(
-            inplanes=inplanes,
-            planes=planes,
-            stride=stride,
-            dilation=dilation,
-            downsample=downsample,
-            groups=groups,
-            base_width=base_width,
-            style=style,
-            with_cp=with_cp,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            dcn=dcn,
-            gcb=gcb))
-    inplanes = planes * block.expansion
-    for i in range(1, blocks):
-        layers.append(
-            block(
-                inplanes=inplanes,
-                planes=planes,
-                stride=1,
-                dilation=dilation,
-                groups=groups,
-                base_width=base_width,
-                style=style,
-                with_cp=with_cp,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                dcn=dcn,
-                gcb=gcb))
-
-    return nn.Sequential(*layers)
-
-
 @BACKBONES.register_module
 class ResNeXt(ResNet):
     """ResNeXt backbone.
@@ -187,36 +130,13 @@ class ResNeXt(ResNet):
     }
 
     def __init__(self, groups=1, base_width=4, **kwargs):
-        super(ResNeXt, self).__init__(**kwargs)
         self.groups = groups
         self.base_width = base_width
+        super(ResNeXt, self).__init__(**kwargs)
 
-        self.inplanes = 64
-        self.res_layers = []
-        for i, num_blocks in enumerate(self.stage_blocks):
-            stride = self.strides[i]
-            dilation = self.dilations[i]
-            dcn = self.dcn if self.stage_with_dcn[i] else None
-            gcb = self.gcb if self.stage_with_gcb[i] else None
-            planes = 64 * 2**i
-            res_layer = make_res_layer(
-                self.block,
-                self.inplanes,
-                planes,
-                num_blocks,
-                stride=stride,
-                dilation=dilation,
-                groups=self.groups,
-                base_width=self.base_width,
-                style=self.style,
-                with_cp=self.with_cp,
-                conv_cfg=self.conv_cfg,
-                norm_cfg=self.norm_cfg,
-                dcn=dcn,
-                gcb=gcb)
-            self.inplanes = planes * self.block.expansion
-            layer_name = 'layer{}'.format(i + 1)
-            self.add_module(layer_name, res_layer)
-            self.res_layers.append(layer_name)
-
-        self._freeze_stages()
+    def make_res_layer(self, **kwargs):
+        return ResLayer(
+            groups=self.groups,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            **kwargs)
diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py
index e1a1ba0d76b34d6199ba397916e6b29ade8e0a74..a5178c334243a6da5f09edd77ec8104a8b6c1e3a 100644
--- a/mmdet/models/shared_heads/res_layer.py
+++ b/mmdet/models/shared_heads/res_layer.py
@@ -4,8 +4,9 @@ from mmcv.runner import load_checkpoint
 
 from mmdet.core import auto_fp16
 from mmdet.utils import get_root_logger
-from ..backbones import ResNet, make_res_layer
+from ..backbones import ResNet
 from ..registry import SHARED_HEADS
+from ..utils import ResLayer as _ResLayer
 
 
 @SHARED_HEADS.register_module
@@ -31,7 +32,7 @@ class ResLayer(nn.Module):
         planes = 64 * 2**stage
         inplanes = 64 * 2**(stage - 1) * block.expansion
 
-        res_layer = make_res_layer(
+        res_layer = _ResLayer(
             block,
             inplanes,
             planes,
diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py
index 83614f71954ca16f2f1dee91d94845e335d21a80..47213871f05e1656eb7f09fd986add8135ef1cff 100644
--- a/mmdet/models/utils/__init__.py
+++ b/mmdet/models/utils/__init__.py
@@ -1,3 +1,4 @@
+from .res_layer import ResLayer
 from .weight_init import bias_init_with_prob
 
-__all__ = ['bias_init_with_prob']
+__all__ = ['bias_init_with_prob', 'ResLayer']
diff --git a/mmdet/models/utils/res_layer.py b/mmdet/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b8c61f285cce027b6349d2b167d2c2c7b32d141
--- /dev/null
+++ b/mmdet/models/utils/res_layer.py
@@ -0,0 +1,85 @@
+from torch import nn as nn
+
+from mmdet.ops import build_conv_layer, build_norm_layer
+
+
+class ResLayer(nn.Sequential):
+    """ResLayer to build ResNet style backbone.
+
+    Args:
+        block (nn.Module): block used to build ResLayer.
+        inplanes (int): inplanes of block.
+        planes (int): planes of block.
+        num_blocks (int): number of blocks.
+        stride (int): stride of the first block. Default: 1
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck. Default: False
+        conv_cfg (dict): dictionary to construct and config conv layer.
+            Default: None
+        norm_cfg (dict): dictionary to construct and config norm layer.
+            Default: dict(type='BN')
+    """
+
+    def __init__(self,
+                 block,
+                 inplanes,
+                 planes,
+                 num_blocks,
+                 stride=1,
+                 avg_down=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 **kwargs):
+        self.block = block
+
+        downsample = None
+        if stride != 1 or inplanes != planes * block.expansion:
+            downsample = []
+            conv_stride = stride
+            if avg_down and stride != 1:
+                conv_stride = 1
+                downsample.append(
+                    nn.AvgPool2d(
+                        kernel_size=stride,
+                        stride=stride,
+                        ceil_mode=True,
+                        count_include_pad=False))
+            downsample.extend([
+                build_conv_layer(
+                    conv_cfg,
+                    inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=conv_stride,
+                    bias=False),
+                build_norm_layer(norm_cfg, planes * block.expansion)[1]
+            ])
+            downsample = nn.Sequential(*downsample)
+
+        gen_attention = kwargs.pop('gen_attention', None)
+        gen_attention_blocks = kwargs.pop('gen_attention_blocks', tuple())
+        layers = []
+        layers.append(
+            block(
+                inplanes=inplanes,
+                planes=planes,
+                stride=stride,
+                downsample=downsample,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                gen_attention=gen_attention
+                if 0 in gen_attention_blocks else None,
+                **kwargs))
+        inplanes = planes * block.expansion
+        for i in range(1, num_blocks):
+            layers.append(
+                block(
+                    inplanes=inplanes,
+                    planes=planes,
+                    stride=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    gen_attention=gen_attention if
+                    (i in gen_attention_blocks) else None,
+                    **kwargs))
+        super(ResLayer, self).__init__(*layers)
diff --git a/tests/test_backbone.py b/tests/test_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f8ab685a6fc9a1ac47f5ca85583f1f3059fc6e3
--- /dev/null
+++ b/tests/test_backbone.py
@@ -0,0 +1,401 @@
+import pytest
+import torch
+from torch.nn.modules import AvgPool2d, GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.models.backbones import ResNet, ResNetV1d, ResNeXt
+from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
+from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
+from mmdet.models.utils import ResLayer
+from mmdet.ops import DeformConvPack
+
+
+def is_block(modules):
+    """Check if is ResNet building block."""
+    if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX)):
+        return True
+    return False
+
+
+def is_norm(modules):
+    """Check if is one of the norms."""
+    if isinstance(modules, (GroupNorm, _BatchNorm)):
+        return True
+    return False
+
+
+def all_zeros(modules):
+    weight_zero = torch.allclose(modules.weight.data,
+                                 torch.zeros_like(modules.weight.data))
+    if hasattr(modules, 'bias'):
+        bias_zero = torch.allclose(modules.bias.data,
+                                   torch.zeros_like(modules.bias.data))
+    else:
+        bias_zero = True
+
+    return weight_zero and bias_zero
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_resnet_basic_block():
+
+    with pytest.raises(AssertionError):
+        BasicBlock(64, 64, with_cp=True)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False)
+        BasicBlock(64, 64, dcn=dcn)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        gcb = dict(ratio=1. / 4., )
+        BasicBlock(64, 64, gcb=gcb)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet
+        gen_attention = dict(
+            spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2)
+        BasicBlock(64, 64, gen_attention=gen_attention)
+
+    block = BasicBlock(64, 64)
+    assert block.conv1.in_channels == 64
+    assert block.conv1.out_channels == 64
+    assert block.conv1.kernel_size == (3, 3)
+    assert block.conv2.in_channels == 64
+    assert block.conv2.out_channels == 64
+    assert block.conv2.kernel_size == (3, 3)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+
+def test_resnet_bottleneck():
+
+    with pytest.raises(AssertionError):
+        # style must be in ['pytorch', 'caffe']
+        Bottleneck(64, 64, style='tensorflow')
+
+    block = Bottleneck(64, 16, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    block = Bottleneck(64, 64, stride=2, style='pytorch')
+    assert block.conv1.stride == (1, 1)
+    assert block.conv2.stride == (2, 2)
+    block = Bottleneck(64, 64, stride=2, style='caffe')
+    assert block.conv1.stride == (2, 2)
+    assert block.conv2.stride == (1, 1)
+
+    dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False)
+    with pytest.raises(AssertionError):
+        Bottleneck(64, 64, dcn=dcn, conv_cfg=dict(type='Conv'))
+    block = Bottleneck(64, 64, dcn=dcn)
+    assert isinstance(block.conv2, DeformConvPack)
+
+    block = Bottleneck(64, 16)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    gcb = dict(ratio=1. / 4., )
+    block = Bottleneck(64, 16, gcb=gcb)
+    assert hasattr(block, 'context_block')
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    gen_attention = dict(
+        spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2)
+    block = Bottleneck(64, 16, gen_attention=gen_attention)
+    assert hasattr(block, 'gen_attention')
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+
+def test_resnet_res_layer():
+    layer = ResLayer(Bottleneck, 64, 16, 3)
+    assert len(layer) == 3
+    assert layer[0].conv1.in_channels == 64
+    assert layer[0].conv1.out_channels == 16
+    for i in range(1, len(layer)):
+        assert layer[i].conv1.in_channels == 64
+        assert layer[i].conv1.out_channels == 16
+    for i in range(len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    layer = ResLayer(Bottleneck, 64, 64, 3)
+    assert layer[0].downsample[0].out_channels == 256
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 56, 56])
+
+    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2)
+    assert layer[0].downsample[0].out_channels == 256
+    assert layer[0].downsample[0].stride == (2, 2)
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 28, 28])
+
+    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True)
+    assert isinstance(layer[0].downsample[0], AvgPool2d)
+    assert layer[0].downsample[1].out_channels == 256
+    assert layer[0].downsample[1].stride == (1, 1)
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 28, 28])
+
+
+def test_resnet_backbone():
+    """Test resnet backbone"""
+    with pytest.raises(KeyError):
+        # ResNet depth should be in [18, 34, 50, 101, 152]
+        ResNet(20)
+
+    with pytest.raises(AssertionError):
+        # In ResNet: 1 <= num_stages <= 4
+        ResNet(50, num_stages=0)
+
+    with pytest.raises(AssertionError):
+        ResNet(18, with_cp=True)
+
+    with pytest.raises(AssertionError):
+        # len(stage_with_dcn) == num_stages
+        dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False)
+        ResNet(50, dcn=dcn, stage_with_dcn=(True, ))
+
+    with pytest.raises(AssertionError):
+        # len(stage_with_gcb) == num_stages
+        gcb = dict(ratio=1. / 4., )
+        ResNet(50, gcb=gcb, stage_with_gcb=(True, ))
+
+    with pytest.raises(AssertionError):
+        # In ResNet: 1 <= num_stages <= 4
+        ResNet(50, num_stages=5)
+
+    with pytest.raises(AssertionError):
+        # len(strides) == len(dilations) == num_stages
+        ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
+
+    with pytest.raises(TypeError):
+        model = ResNet(50)
+        model.init_weights(pretrained=0)
+
+    with pytest.raises(AssertionError):
+        # style must be in ['pytorch', 'caffe']
+        ResNet(50, style='tensorflow')
+
+    with pytest.raises(AssertionError):
+        # assert not with_cp
+        ResNet(18, with_cp=True)
+
+    model = ResNet(18)
+    model.init_weights()
+
+    model = ResNet(50, norm_eval=True)
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    model = ResNet(depth=50, norm_eval=True)
+    model.init_weights('torchvision://resnet50')
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    frozen_stages = 1
+    model = ResNet(50, frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+    assert model.norm1.training is False
+    for layer in [model.conv1, model.norm1]:
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, 'layer{}'.format(i))
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
+    assert len(model.stem) == 9
+    model.init_weights()
+    model.train()
+    check_norm_state(model.stem, False)
+    for param in model.stem.parameters():
+        assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, 'layer{}'.format(i))
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    model = ResNet(18)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 64, 56, 56])
+    assert feat[1].shape == torch.Size([1, 128, 28, 28])
+    assert feat[2].shape == torch.Size([1, 256, 14, 14])
+    assert feat[3].shape == torch.Size([1, 512, 7, 7])
+
+    model = ResNet(50)
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    model = ResNet(50, out_indices=(0, 1, 2))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+
+    model = ResNet(50, with_cp=True)
+    for m in model.modules():
+        if is_block(m):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    model = ResNet(
+        50, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, GroupNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    model = ResNet(50, zero_init_residual=True)
+    model.init_weights()
+    for m in model.modules():
+        if isinstance(m, Bottleneck):
+            assert all_zeros(m.norm3)
+        elif isinstance(m, BasicBlock):
+            assert all_zeros(m.norm2)
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    model = ResNetV1d(depth=50)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+
+def test_renext_bottleneck():
+    with pytest.raises(AssertionError):
+        # style must be in ['pytorch', 'caffe']
+        BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
+
+    block = BottleneckX(
+        64, 64, groups=32, base_width=4, stride=2, style='pytorch')
+    assert block.conv2.stride == (2, 2)
+    assert block.conv2.groups == 32
+    assert block.conv2.out_channels == 128
+
+    dcn = dict(type='DCN', deformable_groups=1, fallback_on_stride=False)
+    with pytest.raises(AssertionError):
+        BottleneckX(
+            64,
+            64,
+            groups=32,
+            base_width=4,
+            dcn=dcn,
+            conv_cfg=dict(type='Conv'))
+    BottleneckX(64, 64, dcn=dcn)
+
+    block = BottleneckX(64, 16, groups=32, base_width=4)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+
+def test_resnext_backbone():
+    with pytest.raises(KeyError):
+        # ResNeXt depth should be in [50, 101, 152]
+        ResNeXt(depth=18)
+
+    model = ResNeXt(depth=50, groups=32, base_width=4)
+    for m in model.modules():
+        if is_block(m):
+            assert m.conv2.groups == 32
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])