diff --git a/configs/swin/swin_base_patch4_window7_mstrain_480-800_adamw_3x_coco.py b/configs/swin/swin_base_patch4_window7_mstrain_480-800_adamw_3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..716085df42b439ffeec553628d61042d2bd797c5
--- /dev/null
+++ b/configs/swin/swin_base_patch4_window7_mstrain_480-800_adamw_3x_coco.py
@@ -0,0 +1,91 @@
+_base_ = [
+    '../_base_/datasets/coco_instance.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+
+model = dict(
+    type='NoneDetector',
+    pretrained=None,
+    backbone=dict(
+        type='SwinTransformer',
+        embed_dim=128,
+        depths=[2, 2, 18, 2],
+        num_heads=[4, 8, 16, 32],
+        window_size=7,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.3,
+        ape=False,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        use_checkpoint=False),
+    
+    # model training and testing settings
+    train_cfg=None,
+    test_cfg=None)
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# augmentation strategy originates from DETR / Sparse RCNN
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='AutoAugment',
+         policies=[
+             [
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                 (736, 1333), (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True)
+             ],
+             [
+                 dict(type='Resize',
+                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True),
+                 dict(type='RandomCrop',
+                      crop_type='absolute_range',
+                      crop_size=(384, 600),
+                      allow_negative_crop=True),
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                 (576, 1333), (608, 1333), (640, 1333),
+                                 (672, 1333), (704, 1333), (736, 1333),
+                                 (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      override=True,
+                      keep_ratio=True)
+             ]
+         ]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+data = dict(train=dict(pipeline=train_pipeline))
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+lr_config = dict(step=[27, 33])
+runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
+
+# do not use mmdet version fp16
+fp16 = None
+optimizer_config = dict(
+    type="DistOptimizerHook",
+    update_interval=1,
+    grad_clip=None,
+    coalesce=True,
+    bucket_size_mb=-1,
+    use_fp16=True,
+)
diff --git a/configs/swin/swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py b/configs/swin/swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3d9e53360bdf58f490679bb230526bcd5d1915b
--- /dev/null
+++ b/configs/swin/swin_small_patch4_window7_mstrain_480-800_adamw_3x_coco.py
@@ -0,0 +1,90 @@
+_base_ = [
+    '../_base_/datasets/coco_instance.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+    type='NoneDetector',
+    pretrained=None,
+    backbone=dict(
+        type='SwinTransformer',
+        embed_dim=96,
+        depths=[2, 2, 18, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.2,
+        ape=False,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        use_checkpoint=False),
+    
+    # model training and testing settings
+    train_cfg=None,
+    test_cfg=None)
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# augmentation strategy originates from DETR / Sparse RCNN
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='AutoAugment',
+         policies=[
+             [
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                 (736, 1333), (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True)
+             ],
+             [
+                 dict(type='Resize',
+                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True),
+                 dict(type='RandomCrop',
+                      crop_type='absolute_range',
+                      crop_size=(384, 600),
+                      allow_negative_crop=True),
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                 (576, 1333), (608, 1333), (640, 1333),
+                                 (672, 1333), (704, 1333), (736, 1333),
+                                 (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      override=True,
+                      keep_ratio=True)
+             ]
+         ]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+data = dict(train=dict(pipeline=train_pipeline))
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+lr_config = dict(step=[27, 33])
+runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
+
+# do not use mmdet version fp16
+fp16 = None
+optimizer_config = dict(
+    type="DistOptimizerHook",
+    update_interval=1,
+    grad_clip=None,
+    coalesce=True,
+    bucket_size_mb=-1,
+    use_fp16=True,
+)
diff --git a/configs/swin/swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py b/configs/swin/swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f5b229c0657b9461b039275d1c84fd2127883b
--- /dev/null
+++ b/configs/swin/swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py
@@ -0,0 +1,91 @@
+_base_ = [
+    '../_base_/datasets/coco_instance.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+    type='NoneDetector',
+    pretrained=None,
+    backbone=dict(
+        type='SwinTransformer',
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.1,
+        ape=False,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        use_checkpoint=False),
+    
+    # model training and testing settings
+    train_cfg=None,
+    test_cfg=None)
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# augmentation strategy originates from DETR / Sparse RCNN
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='AutoAugment',
+         policies=[
+             [
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                 (736, 1333), (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True)
+             ],
+             [
+                 dict(type='Resize',
+                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                      multiscale_mode='value',
+                      keep_ratio=True),
+                 dict(type='RandomCrop',
+                      crop_type='absolute_range',
+                      crop_size=(384, 600),
+                      allow_negative_crop=True),
+                 dict(type='Resize',
+                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                 (576, 1333), (608, 1333), (640, 1333),
+                                 (672, 1333), (704, 1333), (736, 1333),
+                                 (768, 1333), (800, 1333)],
+                      multiscale_mode='value',
+                      override=True,
+                      keep_ratio=True)
+             ]
+         ]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+data = dict(train=dict(pipeline=train_pipeline))
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+lr_config = dict(step=[8, 11])
+#runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+runner = dict(type='EpochBasedRunner', max_epochs=12)
+ 
+# do not use mmdet version fp16
+# fp16 = None
+# optimizer_config = dict(
+#     type="DistOptimizerHook",
+#     update_interval=1,
+#     grad_clip=None,
+#     coalesce=True,
+#     bucket_size_mb=-1,
+#     use_fp16=True,
+# )
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 04011130435cf9fdfadeb821919046b1bddab7d4..8a9635217ef795fd6dae7b6d7724c23a27b0f8a0 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -28,6 +28,7 @@ from .two_stage import TwoStageDetector
 from .vfnet import VFNet
 from .yolact import YOLACT
 from .yolo import YOLOV3
+from .none import NoneDetector
 
 __all__ = [
     'ATSS', 'BaseDetector', 'SingleStageDetector',
@@ -36,5 +37,5 @@ __all__ = [
     'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
     'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
     'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN',
-    'SCNet'
+    'SCNet', 'NoneDetector'
 ]
diff --git a/mmdet/models/detectors/none.py b/mmdet/models/detectors/none.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a3e4316c4dc82ce8d4d6b14118bda15108162c
--- /dev/null
+++ b/mmdet/models/detectors/none.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn as nn
+
+# from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+
+@DETECTORS.register_module()
+class NoneDetector(BaseDetector):
+    """class for none detectors.
+
+    None detectors only extract backbone features.
+    """
+
+    def __init__(self,
+                 backbone,
+                 neck=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(NoneDetector, self).__init__()
+        self.backbone = build_backbone(backbone)
+
+        if neck is not None:
+            self.neck = build_neck(neck)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        self.init_weights(pretrained=pretrained)
+
+    @property
+    def with_rpn(self):
+        """bool: whether the detector has RPN"""
+        return hasattr(self, 'rpn_head') and self.rpn_head is not None
+
+    @property
+    def with_roi_head(self):
+        """bool: whether the detector has a RoI head"""
+        return hasattr(self, 'roi_head') and self.roi_head is not None
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in detector.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        super(NoneDetector, self).init_weights(pretrained)
+        self.backbone.init_weights(pretrained=pretrained)
+        if self.with_neck:
+            if isinstance(self.neck, nn.Sequential):
+                for m in self.neck:
+                    m.init_weights()
+            else:
+                self.neck.init_weights()
+        if self.with_rpn:
+            self.rpn_head.init_weights()
+        if self.with_roi_head:
+            self.roi_head.init_weights(pretrained)
+
+    def extract_feat(self, img):
+        """Directly extract features from the backbone+neck."""
+        x = self.backbone(img)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def forward_dummy(self, img):
+        """Used for computing network flops.
+
+        See `mmdetection/tools/analysis_tools/get_flops.py`
+        """
+        # backbone
+        x = self.extract_feat(img)
+        return x
+
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None,
+                      proposals=None,
+                      **kwargs):
+        """
+        Args:
+            img (Tensor): of shape (N, C, H, W) encoding input images.
+                Typically these should be mean centered and std scaled.
+
+            img_metas (list[dict]): list of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+
+            gt_labels (list[Tensor]): class indices corresponding to each box
+
+            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+                boxes can be ignored when computing the loss.
+
+            gt_masks (None | Tensor) : true segmentation masks for each box
+                used if the architecture supports a segmentation task.
+
+            proposals : override rpn proposals with custom proposals. Use when
+                `with_rpn` is False.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        x = self.extract_feat(img)
+
+        losses = dict()
+
+        return losses
+
+    async def async_simple_test(self,
+                                img,
+                                img_meta,
+                                proposals=None,
+                                rescale=False):
+        """Async test without augmentation."""
+        x = self.extract_feat(img)
+        return x
+
+    def simple_test(self, img, img_metas, proposals=None, rescale=False):
+        """Test without augmentation."""
+        x = self.extract_feat(img)
+
+        # get origin input shape to onnx dynamic input shape
+        if torch.onnx.is_in_onnx_export():
+            img_shape = torch._shape_as_tensor(img)[2:]
+            img_metas[0]['img_shape_for_onnx'] = img_shape
+
+        return x
+
+    def aug_test(self, imgs, img_metas, rescale=False):
+        """Test with augmentations.
+
+        If rescale is False, then returned bboxes and masks will fit the scale
+        of imgs[0].
+        """
+        x = self.extract_feats(imgs)
+        return x