Unverified commit 92595ea6, authored by Kai Chen, committed by GitHub

support unsquare RoIs for bbox and mask heads (#1128)

parent f92f64c2
 import mmcv
 import numpy as np
 import torch
+from torch.nn.modules.utils import _pair


 def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
@@ -13,7 +14,7 @@ def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
 def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
-    mask_size = cfg.mask_size
+    mask_size = _pair(cfg.mask_size)
     num_pos = pos_proposals.size(0)
     mask_targets = []
     if num_pos > 0:
@@ -26,11 +27,12 @@ def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
             w = np.maximum(x2 - x1 + 1, 1)
             h = np.maximum(y2 - y1 + 1, 1)
             # mask is uint8 both before and after resizing
+            # mask_size (h, w) to (w, h)
             target = mmcv.imresize(gt_mask[y1:y1 + h, x1:x1 + w],
-                                   (mask_size, mask_size))
+                                   mask_size[::-1])
             mask_targets.append(target)
         mask_targets = torch.from_numpy(np.stack(mask_targets)).float().to(
             pos_proposals.device)
     else:
-        mask_targets = pos_proposals.new_zeros((0, mask_size, mask_size))
+        mask_targets = pos_proposals.new_zeros((0, ) + mask_size)
     return mask_targets
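
Two size conventions meet in the resize call above. The following is a minimal sketch, not part of the commit and with made-up sizes, of how _pair normalizes cfg.mask_size and why the stored (h, w) pair is reversed before being handed to mmcv.imresize, which takes its target size as (w, h):

# Sketch only: _pair normalization and the (h, w) -> (w, h) flip for mmcv.imresize.
# The 28x56 mask size and the dummy mask are invented for illustration.
import numpy as np
import mmcv
from torch.nn.modules.utils import _pair

mask_size = _pair((28, 56))                     # int or tuple -> (h, w) pair
gt_mask = np.zeros((100, 200), dtype=np.uint8)  # a dummy cropped GT mask
# mmcv.imresize expects (w, h), so the stored (h, w) is reversed
target = mmcv.imresize(gt_mask, mask_size[::-1])
print(target.shape)                             # (28, 56), i.e. (h, w) in numpy order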
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.modules.utils import _pair

 from mmdet.core import (auto_fp16, bbox_target, delta2bbox, force_fp32,
                         multiclass_nms)
@@ -35,7 +36,8 @@ class BBoxHead(nn.Module):
         self.with_avg_pool = with_avg_pool
         self.with_cls = with_cls
         self.with_reg = with_reg
-        self.roi_feat_size = roi_feat_size
+        self.roi_feat_size = _pair(roi_feat_size)
+        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.target_means = target_means
@@ -48,9 +50,9 @@
         in_channels = self.in_channels
         if self.with_avg_pool:
-            self.avg_pool = nn.AvgPool2d(roi_feat_size)
+            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
         else:
-            in_channels *= (self.roi_feat_size * self.roi_feat_size)
+            in_channels *= self.roi_feat_area
         if self.with_cls:
             self.fc_cls = nn.Linear(in_channels, num_classes)
         if self.with_reg:
......
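
For the fully connected heads, the essential change is that the flattened input size becomes channels times roi_feat_area rather than channels times roi_feat_size squared. A short illustrative sketch, with an invented 7x14 RoI size, channel count and class count:

# Sketch only: sizing a classification fc for a non-square RoI feature map.
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair

roi_feat_size = _pair((7, 14))                         # (h, w)
roi_feat_area = roi_feat_size[0] * roi_feat_size[1]    # 98, replaces size ** 2
in_channels = 256
fc_cls = nn.Linear(in_channels * roi_feat_area, 81)

x = torch.randn(4, in_channels, *roi_feat_size)        # RoI features: N x C x H x W
cls_score = fc_cls(x.flatten(1))                       # flatten keeps N, merges C*H*W
print(cls_score.shape)                                 # torch.Size([4, 81])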
@@ -67,9 +67,9 @@ class ConvFCBBoxHead(BBoxHead):
         if self.num_shared_fcs == 0 and not self.with_avg_pool:
             if self.num_cls_fcs == 0:
-                self.cls_last_dim *= (self.roi_feat_size * self.roi_feat_size)
+                self.cls_last_dim *= self.roi_feat_area
             if self.num_reg_fcs == 0:
-                self.reg_last_dim *= (self.roi_feat_size * self.roi_feat_size)
+                self.reg_last_dim *= self.roi_feat_area

         self.relu = nn.ReLU(inplace=True)
         # reconstruct fc_cls and fc_reg since input channels are changed
@@ -112,7 +112,7 @@
             # for separated branches, also consider self.num_shared_fcs
             if (is_shared
                     or self.num_shared_fcs == 0) and not self.with_avg_pool:
-                last_layer_dim *= (self.roi_feat_size * self.roi_feat_size)
+                last_layer_dim *= self.roi_feat_area
             for i in range(num_branch_fcs):
                 fc_in_channels = (
                     last_layer_dim if i == 0 else self.fc_out_channels)
......
@@ -134,8 +134,8 @@ class DoubleConvFCBBoxHead(BBoxHead):
         branch_fcs = nn.ModuleList()
         for i in range(self.num_fcs):
             fc_in_channels = (
-                self.in_channels * self.roi_feat_size *
-                self.roi_feat_size if i == 0 else self.fc_out_channels)
+                self.in_channels *
+                self.roi_feat_area if i == 0 else self.fc_out_channels)
             branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))
         return branch_fcs
......
@@ -3,6 +3,7 @@ import numpy as np
 import pycocotools.mask as mask_util
 import torch
 import torch.nn as nn
+from torch.nn.modules.utils import _pair

 from mmdet.core import auto_fp16, force_fp32, mask_target
 from ..builder import build_loss
@@ -33,7 +34,8 @@ class FCNMaskHead(nn.Module):
                 'Invalid upsample method {}, accepted methods '
                 'are "deconv", "nearest", "bilinear"'.format(upsample_method))
         self.num_convs = num_convs
-        self.roi_feat_size = roi_feat_size  # WARN: not used and reserved
+        # WARN: roi_feat_size is reserved and not used
+        self.roi_feat_size = _pair(roi_feat_size)
         self.in_channels = in_channels
         self.conv_kernel_size = conv_kernel_size
         self.conv_out_channels = conv_out_channels
......
@@ -46,6 +46,8 @@ class GridHead(nn.Module):
             raise ValueError('grid_points must be a square number')

         # the predicted heatmap is half of whole_map_size
+        if not isinstance(self.roi_feat_size, int):
+            raise ValueError('Only square RoIs are supported in Grid R-CNN')
         self.whole_map_size = self.roi_feat_size * 4

         # compute point-wise sub-regions
......
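
GridHead keeps requiring a plain int because whole_map_size is still computed as roi_feat_size * 4, which only scales a scalar. A tiny illustration, not from the commit, of what multiplying a tuple would do instead:

# Why GridHead rejects tuples: tuple * int repeats the tuple, it does not scale it.
roi_feat_size = 14
print(roi_feat_size * 4)      # 56 -> the intended whole_map_size

roi_feat_size = (14, 14)
print(roi_feat_size * 4)      # (14, 14, 14, 14, 14, 14, 14, 14) -> meaningless here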
@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from mmcv.cnn import kaiming_init, normal_init
+from torch.nn.modules.utils import _pair

 from mmdet.core import force_fp32
 from ..builder import build_loss
@@ -47,10 +48,13 @@ class MaskIoUHead(nn.Module):
                     stride=stride,
                     padding=1))

+        roi_feat_size = _pair(roi_feat_size)
+        pooled_area = (roi_feat_size[0] // 2) * (roi_feat_size[1] // 2)
         self.fcs = nn.ModuleList()
         for i in range(num_fcs):
-            in_channels = self.conv_out_channels * (
-                roi_feat_size // 2)**2 if i == 0 else self.fc_out_channels
+            in_channels = (
+                self.conv_out_channels *
+                pooled_area if i == 0 else self.fc_out_channels)
             self.fcs.append(nn.Linear(in_channels, self.fc_out_channels))

         self.fc_mask_iou = nn.Linear(self.fc_out_channels, self.num_classes)
......
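
In MaskIoUHead the fc input is now sized from a pooled_area that halves each spatial dimension separately, mirroring the (roi_feat_size // 2) ** 2 term it replaces. A small sketch with invented numbers, assuming the preceding layer halves the RoI feature map as the stride handling above implies:

# Sketch: fc input size for a MaskIoU-style head with a rectangular RoI.
from torch.nn.modules.utils import _pair

conv_out_channels = 256
roi_feat_size = _pair((14, 28))                                   # (h, w)
pooled_area = (roi_feat_size[0] // 2) * (roi_feat_size[1] // 2)   # 7 * 14 = 98
fc_in_channels = conv_out_channels * pooled_area                  # 25088
print(fc_in_channels)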
@@ -94,8 +94,8 @@ class SingleRoIExtractor(nn.Module):
         out_size = self.roi_layers[0].out_size
         num_levels = len(feats)
         target_lvls = self.map_roi_levels(rois, num_levels)
-        roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels,
-                                       out_size, out_size)
+        roi_feats = feats[0].new_zeros(
+            rois.size(0), self.out_channels, *out_size)
         if roi_scale_factor is not None:
             rois = self.roi_rescale(rois, roi_scale_factor)
         for i in range(num_levels):
......
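
Since out_size taken from the RoI layer is now an (h, w) pair, it is unpacked with * when preallocating the output tensor. A self-contained sketch with invented shapes and channel count:

# Sketch: preallocating RoI features when out_size is an (h, w) pair.
import torch

out_size = (7, 14)                       # what self.roi_layers[0].out_size now holds
out_channels = 256
rois = torch.zeros(5, 5)                 # (batch_ind, x1, y1, x2, y2) per RoI
feats = [torch.randn(1, out_channels, 100, 152)]

roi_feats = feats[0].new_zeros(rois.size(0), out_channels, *out_size)
print(roi_feats.shape)                   # torch.Size([5, 256, 7, 14])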
@@ -65,7 +65,7 @@ class RoIAlign(nn.Module):
                  use_torchvision=False):
         super(RoIAlign, self).__init__()
-        self.out_size = out_size
+        self.out_size = _pair(out_size)
         self.spatial_scale = float(spatial_scale)
         self.sample_num = int(sample_num)
         self.use_torchvision = use_torchvision
@@ -73,7 +73,7 @@ class RoIAlign(nn.Module):
     def forward(self, features, rois):
         if self.use_torchvision:
             from torchvision.ops import roi_align as tv_roi_align
-            return tv_roi_align(features, rois, _pair(self.out_size),
+            return tv_roi_align(features, rois, self.out_size,
                                 self.spatial_scale, self.sample_num)
         else:
             return roi_align(features, rois, self.out_size, self.spatial_scale,
......
@@ -55,14 +55,14 @@ class RoIPool(nn.Module):
     def __init__(self, out_size, spatial_scale, use_torchvision=False):
         super(RoIPool, self).__init__()
-        self.out_size = out_size
+        self.out_size = _pair(out_size)
         self.spatial_scale = float(spatial_scale)
         self.use_torchvision = use_torchvision

     def forward(self, features, rois):
         if self.use_torchvision:
             from torchvision.ops import roi_pool as tv_roi_pool
-            return tv_roi_pool(features, rois, _pair(self.out_size),
+            return tv_roi_pool(features, rois, self.out_size,
                                self.spatial_scale)
         else:
             return roi_pool(features, rois, self.out_size, self.spatial_scale)
......
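
After this change RoIAlign and RoIPool store out_size as a pair up front, so the torchvision path passes it through unchanged. A rough usage sketch of the torchvision operator with a rectangular output size; the shapes, scale and sampling ratio are made up for illustration:

# Sketch of the torchvision path with a rectangular output size.
import torch
from torch.nn.modules.utils import _pair
from torchvision.ops import roi_align

out_size = _pair((7, 14))                         # int or tuple -> (h, w)
features = torch.randn(1, 256, 50, 76)
# RoIs in (batch_index, x1, y1, x2, y2) format, absolute input coordinates
rois = torch.tensor([[0., 10., 10., 200., 120.]])

pooled = roi_align(features, rois, out_size, spatial_scale=0.25, sampling_ratio=2)
print(pooled.shape)                               # torch.Size([1, 256, 7, 14])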