diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py index 4cf8ac87d7ba23f076edc17f2022932778cd4c4e..1c6bc36e04851620abd2366851e1269eb96160cc 100644 --- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py @@ -199,9 +199,27 @@ def delta2bbox(rois, x2 = gx + gw * 0.5 y2 = gy + gh * 0.5 if clip_border and max_shape is not None: - x1 = x1.clamp(min=0, max=max_shape[1]) - y1 = y1.clamp(min=0, max=max_shape[0]) - x2 = x2.clamp(min=0, max=max_shape[1]) - y2 = y2.clamp(min=0, max=max_shape[0]) + # use where() to replace clip(), + # because clip()'s attr min/max do not support dynamic in onnx + if torch.onnx.is_in_onnx_export(): + zero = torch.tensor(0, dtype=torch.float32) + zero = zero.expand(x1.size()) + width = torch.tensor(max_shape[1], dtype=torch.float32) + width = width.expand(x1.size()) + height = torch.tensor(max_shape[0], dtype=torch.float32) + height = height.expand(x1.size()) + x1 = torch.where(x1 < zero, zero, x1) + x1 = torch.where(x1 > width, width, x1) + y1 = torch.where(y1 < zero, zero, y1) + y1 = torch.where(y1 > height, height, y1) + x2 = torch.where(x2 < zero, zero, x2) + x2 = torch.where(x2 > width, width, x2) + y2 = torch.where(y2 < zero, zero, y2) + y2 = torch.where(y2 > height, height, y2) + else: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) return bboxes diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py index 8ae50d38dd980152da402ae4c9a78b879d16c940..87e601368ca6799f19c309e4a8a10c7686078342 100644 --- a/mmdet/models/dense_heads/anchor_head.py +++ b/mmdet/models/dense_heads/anchor_head.py @@ -567,7 +567,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): bbox_pred_list = [ bbox_preds[i][img_id].detach() for i in range(num_levels) ] - img_shape = img_metas[img_id]['img_shape'] + # get origin input shape to support onnx dynamic shape + if torch.onnx.is_in_onnx_export(): + img_shape = img_metas[img_id]['img_shape_for_onnx'] + else: + img_shape = img_metas[img_id]['img_shape'] scale_factor = img_metas[img_id]['scale_factor'] if with_nms: # some heads don't support with_nms argument diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index ce6b26ea96692f1746a239e57fdfec9119f135bf..c2bace620de5d7d372211f93fb03797d230a4772 100644 --- a/mmdet/models/detectors/single_stage.py +++ b/mmdet/models/detectors/single_stage.py @@ -111,6 +111,9 @@ class SingleStageDetector(BaseDetector): """ x = self.extract_feat(img) outs = self.bbox_head(x) + # get origin input shape to support onnx dynamic shape + if torch.onnx.is_in_onnx_export(): + img_metas[0]['img_shape_for_onnx'] = img.shape[2:] bbox_list = self.bbox_head.get_bboxes( *outs, img_metas, rescale=rescale) # skip post-processing when exporting to ONNX