Commit be792b95 authored by chenbohua3's avatar chenbohua3 Committed by Facebook GitHub Bot
Browse files

Make 'ROIAlign' & 'ROIAlignV2' version of ROIPooler scriptable.

Summary: Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/1835

Reviewed By: rbgirshick

Differential Revision: D22819550

Pulled By: ppwwyyxx

fbshipit-source-id: 85cd2198676289e0ab02678f221b97887e543395
parent af866c42
......@@ -7,12 +7,25 @@ from torch import nn
from torchvision.ops import RoIPool
from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple
from detectron2.structures import Boxes
"""
To export ROIPooler to torchscript, in this file, variables that should be annotated with
`Union[List[Boxes], List[RotatedBoxes]]` are only annotated with `List[Boxes]`.
TODO: Correct these annotations when torchscript support `Union`.
https://github.com/pytorch/pytorch/issues/41412
"""
__all__ = ["ROIPooler"]
def assign_boxes_to_levels(
box_lists, min_level: int, max_level: int, canonical_box_size: int, canonical_level: int
box_lists: List[Boxes],
min_level: int,
max_level: int,
canonical_box_size: int,
canonical_level: int,
):
"""
Map each box in `box_lists` to a feature map level index and return the assignment
......@@ -35,11 +48,10 @@ def assign_boxes_to_levels(
`self.min_level`, for the corresponding box (so value i means the box is at
`self.min_level + i`).
"""
eps = sys.float_info.epsilon
box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists]))
# Eqn.(1) in FPN paper
level_assignments = torch.floor(
canonical_level + torch.log2(box_sizes / canonical_box_size + eps)
canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
)
# clamp level to (min, max), in case the box size is too large or too small
# for the available feature maps
......@@ -47,7 +59,14 @@ def assign_boxes_to_levels(
return level_assignments.to(torch.int64) - min_level
def convert_boxes_to_pooler_format(box_lists):
def _fmt_box_list(box_tensor, batch_index: int):
repeated_index = torch.full(
(len(box_tensor), 1), batch_index, dtype=box_tensor.dtype, device=box_tensor.device
)
return cat((repeated_index, box_tensor), dim=1)
def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
"""
Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops
(see description under Returns).
......@@ -70,15 +89,8 @@ def convert_boxes_to_pooler_format(box_lists):
where batch index is the index in [0, N) identifying which batch image the
rotated box (x_ctr, y_ctr, width, height, angle_degrees) comes from.
"""
def fmt_box_list(box_tensor, batch_index):
repeated_index = torch.full(
(len(box_tensor), 1), batch_index, dtype=box_tensor.dtype, device=box_tensor.device
)
return cat((repeated_index, box_tensor), dim=1)
pooler_fmt_boxes = cat(
[fmt_box_list(box_list.tensor, i) for i, box_list in enumerate(box_lists)], dim=0
[_fmt_box_list(box_list.tensor, i) for i, box_list in enumerate(box_lists)], dim=0
)
return pooler_fmt_boxes
......@@ -176,7 +188,7 @@ class ROIPooler(nn.Module):
assert canonical_box_size > 0
self.canonical_box_size = canonical_box_size
def forward(self, x: List[torch.Tensor], box_lists):
def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
"""
Args:
x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
......@@ -226,9 +238,9 @@ class ROIPooler(nn.Module):
(num_boxes, num_channels, output_size, output_size), dtype=dtype, device=device
)
for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
for level, pooler in enumerate(self.level_poolers):
inds = nonzero_tuple(level_assignments == level)[0]
pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
output[inds] = pooler(x_level, pooler_fmt_boxes_level)
output[inds] = pooler(x[level], pooler_fmt_boxes_level)
return output
......@@ -5,6 +5,7 @@ import torch
from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes, RotatedBoxes
from detectron2.utils.env import TORCH_VERSION
logger = logging.getLogger(__name__)
......@@ -80,6 +81,49 @@ class TestROIPooler(unittest.TestCase):
def test_roialignv2_roialignrotated_match_cuda(self):
self._test_roialignv2_roialignrotated_match(device="cuda")
def _test_scriptability(self, device):
pooler_resolution = 14
canonical_level = 4
canonical_scale_factor = 2 ** canonical_level
pooler_scales = (1.0 / canonical_scale_factor,)
sampling_ratio = 0
N, C, H, W = 2, 4, 10, 8
N_rois = 10
std = 11
mean = 0
feature = (torch.rand(N, C, H, W) - 0.5) * 2 * std + mean
features = [feature.to(device)]
rois = []
for _ in range(N):
boxes = self._rand_boxes(
num_boxes=N_rois, x_max=W * canonical_scale_factor, y_max=H * canonical_scale_factor
)
rois.append(Boxes(boxes).to(device))
roialignv2_pooler = ROIPooler(
output_size=pooler_resolution,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type="ROIAlignV2",
)
roialignv2_out = roialignv2_pooler(features, rois)
scripted_roialignv2_out = torch.jit.script(roialignv2_pooler)(features, rois)
self.assertTrue(torch.equal(roialignv2_out, scripted_roialignv2_out))
@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
def test_scriptability_cpu(self):
self._test_scriptability(device="cpu")
@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_scriptability_gpu(self):
self._test_scriptability(device="cuda")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment