diff --git a/README.md b/README.md
index e9dd0d5d009dd48c555c175dff149ffb50b1f001..cc8363b886b3905e9c01a6dbb6f0f4512874ffa7 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,7 @@ Results and models are available in the [model zoo](docs/model_zoo.md).
 | PAFPN              | 鉁�        | 鉁�        | 鈽�        | 鉁�        | 鉁�     | 鈽�        | 鈽�     |
 | NAS-FCOS           | 鉁�        | 鉁�        | 鈽�        | 鉁�        | 鉁�     | 鈽�        | 鈽�     |
 | PISA               | 鉁�        | 鉁�        | 鈽�        | 鉁�        | 鉁�     | 鈽�        | 鈽�     |
+| Dynamic R-CNN      | 鉁�        | 鉁�        | 鈽�        | 鉁�        | 鉁�     | 鈽�        | 鈽�     |
 
 Other features
 - [x] [CARAFE](configs/carafe/README.md)
diff --git a/configs/dynamic_rcnn/README.md b/configs/dynamic_rcnn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f1b0bab957e780ebac59d5f069125b916cd9179c
--- /dev/null
+++ b/configs/dynamic_rcnn/README.md
@@ -0,0 +1,18 @@
+# Dynamic R-CNN: Towards High Quality Object Detection via Dynamic Training
+
+## Introduction
+
+```
+@article{DynamicRCNN,
+    author = {Hongkai Zhang and Hong Chang and Bingpeng Ma and Naiyan Wang and Xilin Chen},
+    title = {Dynamic {R-CNN}: Towards High Quality Object Detection via Dynamic Training},
+    journal = {arXiv preprint arXiv:2004.06002},
+    year = {2020}
+}
+```
+
+## Results and Models
+
+| Backbone  | Style   | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download |
+|:---------:|:-------:|:-------:|:--------:|:--------------:|:------:|:--------:|
+| R-50      | pytorch | 1x      | 3.8      |                |  38.9  | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x/dynamic_rcnn_r50_fpn_1x-62a3f276.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x/dynamic_rcnn_r50_fpn_1x_20200618_095048.log.json) |
diff --git a/configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py b/configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..60f9c5043a6d8e7da0c6038aca868ad7e966c534
--- /dev/null
+++ b/configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py
@@ -0,0 +1,28 @@
+_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
+model = dict(
+    roi_head=dict(
+        type='DynamicRoIHead',
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=80,
+            bbox_coder=dict(
+                type='DeltaXYWHBBoxCoder',
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.1, 0.1, 0.2, 0.2]),
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))))
+train_cfg = dict(
+    rpn_proposal=dict(nms_thr=0.85),
+    rcnn=dict(
+        dynamic_rcnn=dict(
+            iou_topk=75,
+            beta_topk=10,
+            update_iter_interval=100,
+            initial_iou=0.4,
+            initial_beta=1.0)))
+test_cfg = dict(rpn=dict(nms_thr=0.85))
diff --git a/docs/model_zoo.md b/docs/model_zoo.md
index f216f054536576b118e60e7ea2d0b1dfaf7c87c6..266ccac5a75ca14e39987b04d9bd8b7607cd47df 100644
--- a/docs/model_zoo.md
+++ b/docs/model_zoo.md
@@ -132,6 +132,9 @@ Please refer to [Res2Net](https://github.com/open-mmlab/mmdetection/blob/master/
 ### GRoIE
 Please refer to [GRoIE](https://github.com/open-mmlab/mmdetection/blob/master/configs/groie) for details.
 
+### Dynamic R-CNN
+Please refer to [Dynamic R-CNN](https://github.com/open-mmlab/mmdetection/blob/master/configs/dynamic_rcnn) for details.
+
 ### Other datasets
 
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
diff --git a/mmdet/models/roi_heads/__init__.py b/mmdet/models/roi_heads/__init__.py
index 0e05b5a57fbe1081248bb1df185ee6c021ac2f02..6d50b202684ed4a06fcdf51c287cf55c430a90ba 100644
--- a/mmdet/models/roi_heads/__init__.py
+++ b/mmdet/models/roi_heads/__init__.py
@@ -3,6 +3,7 @@ from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead,
                          Shared2FCBBoxHead, Shared4Conv1FCBBoxHead)
 from .cascade_roi_head import CascadeRoIHead
 from .double_roi_head import DoubleHeadRoIHead
+from .dynamic_roi_head import DynamicRoIHead
 from .grid_roi_head import GridRoIHead
 from .htc_roi_head import HybridTaskCascadeRoIHead
 from .mask_heads import (FCNMaskHead, FusedSemanticHead, GridHead, HTCMaskHead,
@@ -17,5 +18,6 @@ __all__ = [
     'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead',
     'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'Shared4Conv1FCBBoxHead',
     'DoubleConvFCBBoxHead', 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead',
-    'GridHead', 'MaskIoUHead', 'SingleRoIExtractor', 'PISARoIHead'
+    'GridHead', 'MaskIoUHead', 'SingleRoIExtractor', 'PISARoIHead',
+    'DynamicRoIHead'
 ]
diff --git a/mmdet/models/roi_heads/dynamic_roi_head.py b/mmdet/models/roi_heads/dynamic_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bcbd9a296a3278d2270c08effdc3500f5298141
--- /dev/null
+++ b/mmdet/models/roi_heads/dynamic_roi_head.py
@@ -0,0 +1,152 @@
+import numpy as np
+import torch
+
+from mmdet.core import bbox2roi
+from mmdet.models.losses import SmoothL1Loss
+from ..builder import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class DynamicRoIHead(StandardRoIHead):
+    """RoI head for `Dynamic R-CNN <https://arxiv.org/abs/2004.06002>`_."""
+
+    def __init__(self, **kwargs):
+        super(DynamicRoIHead, self).__init__(**kwargs)
+        assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
+        self.iou_topk = self.train_cfg.dynamic_rcnn.iou_topk
+        self.beta_topk = self.train_cfg.dynamic_rcnn.beta_topk
+        self.update_iter_interval = \
+            self.train_cfg.dynamic_rcnn.update_iter_interval
+        # warm-up values for IoU and beta
+        self.initial_iou = self.train_cfg.dynamic_rcnn.initial_iou
+        self.initial_beta = self.train_cfg.dynamic_rcnn.initial_beta
+        # the IoU history of the past `update_iter_interval` iterations
+        self.iou_history = []
+        # the beta history of the past `update_iter_interval` iterations
+        self.beta_history = []
+
+    def forward_train(self,
+                      x,
+                      img_metas,
+                      proposal_list,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None):
+        """
+        Args:
+            x (list[Tensor]): list of multi-level img features.
+
+            img_metas (list[dict]): list of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+            proposals (list[Tensors]): list of region proposals.
+
+            gt_bboxes (list[Tensor]): each item are the truth boxes for each
+                image in [tl_x, tl_y, br_x, br_y] format.
+
+            gt_labels (list[Tensor]): class indices corresponding to each box
+
+            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+                boxes can be ignored when computing the loss.
+
+            gt_masks (None | Tensor) : true segmentation masks for each box
+                used if the architecture supports a segmentation task.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        # assign gts and sample proposals
+        if self.with_bbox or self.with_mask:
+            num_imgs = len(img_metas)
+            if gt_bboxes_ignore is None:
+                gt_bboxes_ignore = [None for _ in range(num_imgs)]
+            sampling_results = []
+            cur_iou = []
+            for i in range(num_imgs):
+                assign_result = self.bbox_assigner.assign(
+                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+                    gt_labels[i])
+                sampling_result = self.bbox_sampler.sample(
+                    assign_result,
+                    proposal_list[i],
+                    gt_bboxes[i],
+                    gt_labels[i],
+                    feats=[lvl_feat[i][None] for lvl_feat in x])
+                # record the `iou_topk`-th largest IoU in an image
+                iou_topk = min(self.iou_topk, len(assign_result.max_overlaps))
+                ious, _ = torch.topk(assign_result.max_overlaps, iou_topk)
+                cur_iou.append(ious[-1].item())
+                sampling_results.append(sampling_result)
+            # average the current IoUs over images
+            cur_iou = np.mean(cur_iou)
+            self.iou_history.append(cur_iou)
+
+        losses = dict()
+        # bbox head forward and loss
+        if self.with_bbox:
+            bbox_results = self._bbox_forward_train(x, sampling_results,
+                                                    gt_bboxes, gt_labels,
+                                                    img_metas)
+            losses.update(bbox_results['loss_bbox'])
+
+        # mask head forward and loss
+        if self.with_mask:
+            mask_results = self._mask_forward_train(x, sampling_results,
+                                                    bbox_results['bbox_feats'],
+                                                    gt_masks, img_metas)
+            # TODO: Support empty tensor input. #2280
+            if mask_results['loss_mask'] is not None:
+                losses.update(mask_results['loss_mask'])
+
+        # update IoU threshold and SmoothL1 beta
+        if len(self.iou_history) % self.update_iter_interval == 0:
+            new_iou_thr, new_beta = self.update_hyperparameters()
+
+        return losses
+
+    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+                            img_metas):
+        num_imgs = len(img_metas)
+        rois = bbox2roi([res.bboxes for res in sampling_results])
+        bbox_results = self._bbox_forward(x, rois)
+
+        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
+                                                  gt_labels, self.train_cfg)
+        # record the `beta_topk`-th smallest target
+        # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
+        # and bbox_weights, respectively
+        pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
+        num_pos = len(pos_inds)
+        cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
+        beta_topk = min(self.beta_topk * num_imgs, num_pos)
+        cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
+        self.beta_history.append(cur_target)
+        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
+                                        bbox_results['bbox_pred'], rois,
+                                        *bbox_targets)
+
+        bbox_results.update(loss_bbox=loss_bbox)
+        return bbox_results
+
+    def update_hyperparameters(self):
+        """
+        Update hyperparameters like `iou_thr` and `SmoothL1 beta` based
+        on the training statistics.
+
+        Returns:
+            tuple[float]: the updated `iou_thr` and `SmoothL1 beta`
+        """
+        new_iou_thr = max(self.initial_iou, np.mean(self.iou_history))
+        self.iou_history = []
+        self.bbox_assigner.pos_iou_thr = new_iou_thr
+        self.bbox_assigner.neg_iou_thr = new_iou_thr
+        self.bbox_assigner.min_pos_iou = new_iou_thr
+        new_beta = min(self.initial_beta, np.median(self.beta_history))
+        self.beta_history = []
+        self.bbox_head.loss_bbox.beta = new_beta
+        return new_iou_thr, new_beta