diff --git a/configs/_base_/models/retinanet_r50_fpn.py b/configs/_base_/models/retinanet_r50_fpn.py
index dfb70a4356c9b809e0a9f1b995c3ff911b0e1c23..f51f0863cedde51c1d280d3d3b59365f99608ab0 100644
--- a/configs/_base_/models/retinanet_r50_fpn.py
+++ b/configs/_base_/models/retinanet_r50_fpn.py
@@ -16,7 +16,7 @@ model = dict(
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         start_level=1,
-        add_extra_convs=True,
+        add_extra_convs='on_input',
         num_outs=5),
     bbox_head=dict(
         type='RetinaHead',
diff --git a/configs/atss/atss_r50_fpn_1x_coco.py b/configs/atss/atss_r50_fpn_1x_coco.py
index 134ba07bf11c5be7b87d32a80a7fb69499f63511..f359f0bb9b4cc06f283c54c5de7e0eefc9eae99f 100644
--- a/configs/atss/atss_r50_fpn_1x_coco.py
+++ b/configs/atss/atss_r50_fpn_1x_coco.py
@@ -19,8 +19,7 @@ model = dict(
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         start_level=1,
-        add_extra_convs=True,
-        extra_convs_on_inputs=False,
+        add_extra_convs='on_output',
         num_outs=5),
     bbox_head=dict(
         type='ATSSHead',
diff --git a/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py b/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py
index e44982df69ccf50e085c191f91d15b570a8eaebc..59203d632d93188fcce13fced25444f92a4409c0 100644
--- a/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py
+++ b/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py
@@ -20,8 +20,7 @@ model = dict(
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         start_level=1,
-        add_extra_convs=True,
-        extra_convs_on_inputs=False,  # use P5
+        add_extra_convs='on_output',  # use P5
         num_outs=5,
         relu_before_extra_convs=True),
     bbox_head=dict(
diff --git a/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py b/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py
index 6c7f3e5dd9e8f0fd8fc75dc26f9383c2c24ca0d1..a432780ab5c5adfe6e71d06e47d490e44e38cddf 100644
--- a/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py
+++ b/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py
@@ -20,8 +20,7 @@ model = dict(
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         start_level=1,
-        add_extra_convs=True,
-        extra_convs_on_inputs=False,  # use P5
+        add_extra_convs='on_output',  # use P5
         num_outs=5,
         relu_before_extra_convs=True),
     bbox_head=dict(
diff --git a/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py b/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py
index 099204d4dc35163f91fe610b04e48091bdc7b4b4..9bafedd3d91ecb8010fc0d781cbc713a3b5aff45 100644
--- a/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py
+++ b/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py
@@ -21,7 +21,7 @@ model = dict(
         out_channels=256,
         start_level=1,
         num_outs=5,
-        add_extra_convs=True),
+        add_extra_convs='on_input'),
     bbox_head=dict(
         type='FoveaHead',
         num_classes=80,
diff --git a/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py b/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py
index d590969b644a15c8cf502be62d07adff357069b4..955b27a91daa56da028179454269979bf3e32747 100644
--- a/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py
+++ b/configs/guided_anchoring/ga_retinanet_r101_caffe_fpn_mstrain_2x.py
@@ -16,7 +16,7 @@ model = dict(
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         start_level=1,
-        add_extra_convs=True,
+        add_extra_convs='on_input',
         num_outs=5),
     bbox_head=dict(
         type='GARetinaHead',
diff --git a/configs/libra_rcnn/libra_retinanet_r50_fpn_1x_coco.py b/configs/libra_rcnn/libra_retinanet_r50_fpn_1x_coco.py
index 174be9f5961df166dbbfa009f693f7c1934b14be..be2742098fb8f1e46bbb16c9d3e2e20c2e3083aa 100644
--- a/configs/libra_rcnn/libra_retinanet_r50_fpn_1x_coco.py
+++ b/configs/libra_rcnn/libra_retinanet_r50_fpn_1x_coco.py
@@ -7,7 +7,7 @@ model = dict(
             in_channels=[256, 512, 1024, 2048],
             out_channels=256,
             start_level=1,
-            add_extra_convs=True,
+            add_extra_convs='on_input',
             num_outs=5),
         dict(
             type='BFP',
diff --git a/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py b/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py
index 34e82f4c8f2f134bd41c04035ed807ac4f7c7761..a1a0c23b1e2944037f5f8b72aa2159c2cd871a1c 100644
--- a/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py
+++ b/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py
@@ -19,7 +19,7 @@ model = dict(
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         start_level=1,
-        add_extra_convs=True,
+        add_extra_convs='on_input',
         num_outs=5),
     bbox_head=dict(
         type='RepPointsHead',
diff --git a/demo/mmdet_inference_colab.ipynb b/demo/mmdet_inference_colab.ipynb
index f85be9f76223fb6d987869f75a75ec59dfa3f888..82130323d1b89731df628b2d829ef5575c024906 100644
--- a/demo/mmdet_inference_colab.ipynb
+++ b/demo/mmdet_inference_colab.ipynb
@@ -534,4 +534,4 @@
       ]
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/docs/api.rst b/docs/api.rst
index 3b531f1cf54e3feedbbf944db7890a88043eea0f..fe826d54213812efedf417be0d85e817f568a565 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -75,6 +75,11 @@ backbones
 .. automodule:: mmdet.models.backbones
     :members:
 
+necks
+^^^^^^^^^^^^
+.. automodule:: mmdet.models.necks
+    :members:
+
 dense_heads
 ^^^^^^^^^^^^
 .. automodule:: mmdet.models.dense_heads
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
index da508652885bfc0d9a7b677b3cfa1599bb3adad8..7c0f901abb21a441ae7a142cedf2eda857bdc8b5 100644
--- a/mmdet/models/necks/fpn.py
+++ b/mmdet/models/necks/fpn.py
@@ -22,10 +22,19 @@ class FPN(nn.Module):
             build the feature pyramid. Default: 0.
         end_level (int): Index of the end input backbone level (exclusive) to
             build the feature pyramid. Default: -1, which means the last level.
-        add_extra_convs (bool): Whether to add conv layers on top of the
-            original feature maps. Default: False.
-        extra_convs_on_inputs (bool): Whether to apply extra conv on
-            the original feature from the backbone. Default: False.
+        add_extra_convs (bool | str): If bool, it decides whether to add conv
+            layers on top of the original feature maps. Default to False.
+            If True, its actual mode is specified by `extra_convs_on_inputs`.
+            If str, it specifies the source feature map of the extra convs.
+            Only the following options are allowed
+
+            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+            - 'on_lateral':  Last feature map after lateral convs.
+            - 'on_output': The last output feature map after fpn convs.
+        extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+            on the original feature from the backbone. If True,
+            it is equivalent to `add_extra_convs='on_input'`. If False, it is
+            equivalent to set `add_extra_convs='on_output'`. Default to True.
         relu_before_extra_convs (bool): Whether to apply relu before the extra
             conv. Default: False.
         no_norm_on_lateral (bool): Whether to apply norm on lateral.
@@ -34,6 +43,8 @@ class FPN(nn.Module):
         norm_cfg (dict): Config dict for normalization layer. Default: None.
         act_cfg (str): Config dict for activation layer in ConvModule.
             Default: None.
+        upsample_cfg (dict): Config dict for interpolate layer.
+            Default: `dict(mode='nearest')`
 
     Example:
         >>> import torch
@@ -63,7 +74,8 @@ class FPN(nn.Module):
                  no_norm_on_lateral=False,
                  conv_cfg=None,
                  norm_cfg=None,
-                 act_cfg=None):
+                 act_cfg=None,
+                 upsample_cfg=dict(mode='nearest')):
         super(FPN, self).__init__()
         assert isinstance(in_channels, list)
         self.in_channels = in_channels
@@ -73,6 +85,7 @@ class FPN(nn.Module):
         self.relu_before_extra_convs = relu_before_extra_convs
         self.no_norm_on_lateral = no_norm_on_lateral
         self.fp16_enabled = False
+        self.upsample_cfg = upsample_cfg.copy()
 
         if end_level == -1:
             self.backbone_end_level = self.num_ins
@@ -85,7 +98,17 @@ class FPN(nn.Module):
         self.start_level = start_level
         self.end_level = end_level
         self.add_extra_convs = add_extra_convs
-        self.extra_convs_on_inputs = extra_convs_on_inputs
+        assert isinstance(add_extra_convs, (str, bool))
+        if isinstance(add_extra_convs, str):
+            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+        elif add_extra_convs:  # True
+            if extra_convs_on_inputs:
+                # For compatibility with previous release
+                # TODO: deprecate `extra_convs_on_inputs`
+                self.add_extra_convs = 'on_input'
+            else:
+                self.add_extra_convs = 'on_output'
 
         self.lateral_convs = nn.ModuleList()
         self.fpn_convs = nn.ModuleList()
@@ -114,9 +137,9 @@ class FPN(nn.Module):
 
         # add extra conv layers (e.g., RetinaNet)
         extra_levels = num_outs - self.backbone_end_level + self.start_level
-        if add_extra_convs and extra_levels >= 1:
+        if self.add_extra_convs and extra_levels >= 1:
             for i in range(extra_levels):
-                if i == 0 and self.extra_convs_on_inputs:
+                if i == 0 and self.add_extra_convs == 'on_input':
                     in_channels = self.in_channels[self.backbone_end_level - 1]
                 else:
                     in_channels = out_channels
@@ -151,9 +174,15 @@ class FPN(nn.Module):
         # build top-down path
         used_backbone_levels = len(laterals)
         for i in range(used_backbone_levels - 1, 0, -1):
-            prev_shape = laterals[i - 1].shape[2:]
-            laterals[i - 1] += F.interpolate(
-                laterals[i], size=prev_shape, mode='nearest')
+            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+            #  it cannot co-exist with `size` in `F.interpolate`.
+            if 'scale_factor' in self.upsample_cfg:
+                laterals[i - 1] += F.interpolate(laterals[i],
+                                                 **self.upsample_cfg)
+            else:
+                prev_shape = laterals[i - 1].shape[2:]
+                laterals[i - 1] += F.interpolate(
+                    laterals[i], size=prev_shape, **self.upsample_cfg)
 
         # build outputs
         # part 1: from original levels
@@ -169,11 +198,15 @@ class FPN(nn.Module):
                     outs.append(F.max_pool2d(outs[-1], 1, stride=2))
             # add conv layers on top of original feature maps (RetinaNet)
             else:
-                if self.extra_convs_on_inputs:
-                    orig = inputs[self.backbone_end_level - 1]
-                    outs.append(self.fpn_convs[used_backbone_levels](orig))
+                if self.add_extra_convs == 'on_input':
+                    extra_source = inputs[self.backbone_end_level - 1]
+                elif self.add_extra_convs == 'on_lateral':
+                    extra_source = laterals[-1]
+                elif self.add_extra_convs == 'on_output':
+                    extra_source = outs[-1]
                 else:
-                    outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
+                    raise NotImplementedError
+                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                 for i in range(used_backbone_levels + 1, self.num_outs):
                     if self.relu_before_extra_convs:
                         outs.append(self.fpn_convs[i](F.relu(outs[-1])))
diff --git a/tests/test_necks.py b/tests/test_necks.py
new file mode 100644
index 0000000000000000000000000000000000000000..528bfb90aa95e01ecbc39ca70cd7e84c3317fb23
--- /dev/null
+++ b/tests/test_necks.py
@@ -0,0 +1,201 @@
+import pytest
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.models.necks import FPN
+
+
+def test_fpn():
+    """Test FPN constructor validation, extra-conv modes and output shapes."""
+    s = 64
+    in_channels = [8, 16, 32, 64]
+    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
+    out_channels = 8
+    # `num_outs` is not equal to len(in_channels) - start_level
+    with pytest.raises(AssertionError):
+        FPN(in_channels=in_channels,
+            out_channels=out_channels,
+            start_level=1,
+            num_outs=2)
+
+    # `end_level` is larger than len(in_channels) - 1
+    with pytest.raises(AssertionError):
+        FPN(in_channels=in_channels,
+            out_channels=out_channels,
+            start_level=1,
+            end_level=4,
+            num_outs=2)
+
+    # `num_outs` is not equal to end_level - start_level
+    with pytest.raises(AssertionError):
+        FPN(in_channels=in_channels,
+            out_channels=out_channels,
+            start_level=1,
+            end_level=3,
+            num_outs=1)
+
+    # Invalid `add_extra_convs` option
+    with pytest.raises(AssertionError):
+        FPN(in_channels=in_channels,
+            out_channels=out_channels,
+            start_level=1,
+            add_extra_convs='on_xxx',
+            num_outs=5)
+
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        start_level=1,
+        add_extra_convs=True,
+        num_outs=5)
+
+    # FPN expects multiple levels of features per image. With start_level=1
+    # feats[0] is skipped, so output level i has spatial size s // 2**(i + 1).
+    feats = [
+        torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
+        for i in range(len(in_channels))
+    ]
+    outs = fpn_model(feats)
+    assert fpn_model.add_extra_convs == 'on_input'
+    assert len(outs) == fpn_model.num_outs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # Tests for fpn with no extra convs (pooling is used instead)
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        start_level=1,
+        add_extra_convs=False,
+        num_outs=5)
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    assert not fpn_model.add_extra_convs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # Tests for fpn with lateral bns
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        start_level=1,
+        add_extra_convs=True,
+        no_norm_on_lateral=False,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        num_outs=5)
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    assert fpn_model.add_extra_convs == 'on_input'
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+    bn_exist = False
+    for m in fpn_model.modules():
+        if isinstance(m, _BatchNorm):
+            bn_exist = True
+    assert bn_exist
+
+    # Bilinear upsample
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        start_level=1,
+        add_extra_convs=True,
+        upsample_cfg=dict(mode='bilinear', align_corners=True),
+        num_outs=5)
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    assert fpn_model.add_extra_convs == 'on_input'
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # Scale factor instead of fixed upsample size upsample
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        start_level=1,
+        add_extra_convs=True,
+        upsample_cfg=dict(scale_factor=2),
+        num_outs=5)
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # Extra convs source is 'inputs'
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        add_extra_convs='on_input',
+        start_level=1,
+        num_outs=5)
+    assert fpn_model.add_extra_convs == 'on_input'
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # Extra convs source is 'laterals'
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        add_extra_convs='on_lateral',
+        start_level=1,
+        num_outs=5)
+    assert fpn_model.add_extra_convs == 'on_lateral'
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # Extra convs source is 'outputs'
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        add_extra_convs='on_output',
+        start_level=1,
+        num_outs=5)
+    assert fpn_model.add_extra_convs == 'on_output'
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # extra_convs_on_inputs=False is equal to extra convs source is 'on_output'
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        add_extra_convs=True,
+        extra_convs_on_inputs=False,
+        start_level=1,
+        num_outs=5,
+    )
+    assert fpn_model.add_extra_convs == 'on_output'
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))
+
+    # extra_convs_on_inputs=True is equal to extra convs source is 'on_input'
+    fpn_model = FPN(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        add_extra_convs=True,
+        extra_convs_on_inputs=True,
+        start_level=1,
+        num_outs=5,
+    )
+    assert fpn_model.add_extra_convs == 'on_input'
+    outs = fpn_model(feats)
+    assert len(outs) == fpn_model.num_outs
+    for i in range(fpn_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == outs[i].shape[3] == s // (2**(i + 1))