Skip to content
Snippets Groups Projects
Commit 427c8902 authored by pangjm's avatar pangjm
Browse files

add Faster RCNN & Mask RCNN training API and some test related

parent 65642939
No related branches found
No related tags found
No related merge requests found
from .geometry import bbox_overlaps
from .sampling import (random_choice, bbox_assign, bbox_assign_via_overlaps,
bbox_sampling, sample_positives, sample_negatives,
sample_proposals)
bbox_sampling, sample_positives, sample_negatives)
from .transforms import (bbox_transform, bbox_transform_inv, bbox_flip,
bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox,
bbox2result)
......@@ -12,5 +11,5 @@ __all__ = [
'bbox_assign_via_overlaps', 'bbox_sampling', 'sample_positives',
'sample_negatives', 'bbox_transform', 'bbox_transform_inv', 'bbox_flip',
'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
'bbox_target', 'sample_proposals'
'bbox_target'
]
......@@ -58,7 +58,7 @@ def mask_cross_entropy(pred, target, label):
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, reduction='sum')[None]
pred_slice, target, reduction='elementwise_mean')[None]
def weighted_mask_cross_entropy(pred, target, weight, label):
......
from .segms import (flip_segms, polys_to_mask, mask_to_bbox,
polys_to_mask_wrt_box, polys_to_boxes, rle_mask_voting,
rle_mask_nms, rle_masks_to_boxes)
from .utils import split_combined_gt_polys
from .utils import split_combined_polys
from .mask_target import mask_target
__all__ = [
'flip_segms', 'polys_to_mask', 'mask_to_bbox', 'polys_to_mask_wrt_box',
'polys_to_boxes', 'rle_mask_voting', 'rle_mask_nms', 'rle_masks_to_boxes',
'split_combined_gt_polys', 'mask_target'
'split_combined_polys', 'mask_target'
]
......@@ -4,27 +4,31 @@ import numpy as np
from .segms import polys_to_mask_wrt_box
def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_polys_list,
img_meta, cfg):
def mask_target(pos_proposals_list,
pos_assigned_gt_inds_list,
gt_polys_list,
img_meta,
cfg):
cfg_list = [cfg for _ in range(len(pos_proposals_list))]
img_metas = [img_meta for _ in range(len(pos_proposals_list))]
mask_targets = map(mask_target_single, pos_proposals_list,
pos_assigned_gt_inds_list, gt_polys_list, img_metas,
pos_assigned_gt_inds_list, gt_polys_list, img_meta,
cfg_list)
mask_targets = torch.cat(tuple(mask_targets), dim=0)
return mask_targets
def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_polys,
img_meta, cfg):
def mask_target_single(pos_proposals,
pos_assigned_gt_inds,
gt_polys,
img_meta,
cfg):
mask_size = cfg.mask_size
num_pos = pos_proposals.size(0)
mask_targets = pos_proposals.new_zeros((num_pos, mask_size, mask_size))
if num_pos > 0:
pos_proposals = pos_proposals.cpu().numpy()
pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
scale_factor = img_meta['scale_factor'][0].cpu().numpy()
scale_factor = img_meta['scale_factor']
for i in range(num_pos):
bbox = pos_proposals[i, :] / scale_factor
polys = gt_polys[pos_assigned_gt_inds[i]]
......
import mmcv
def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
def split_combined_polys(polys, poly_lens, polys_per_mask):
"""Split the combined 1-D polys into masks.
A mask is represented as a list of polys, and a poly is represented as
......@@ -9,9 +9,9 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
tensor. Here we need to split the tensor into original representations.
Args:
gt_polys (list): a list (length = image num) of 1-D tensors
gt_poly_lens (list): a list (length = image num) of poly length
num_polys_per_mask (list): a list (length = image num) of poly number
polys (list): a list (length = image num) of 1-D tensors
poly_lens (list): a list (length = image num) of poly length
polys_per_mask (list): a list (length = image num) of poly number
of each mask
Returns:
......@@ -19,13 +19,12 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
list (length = poly num) of numpy array
"""
mask_polys_list = []
for img_id in range(len(gt_polys)):
gt_polys_single = gt_polys[img_id].cpu().numpy()
gt_polys_lens_single = gt_poly_lens[img_id].cpu().numpy().tolist()
num_polys_per_mask_single = num_polys_per_mask[
img_id].cpu().numpy().tolist()
for img_id in range(len(polys)):
polys_single = polys[img_id]
polys_lens_single = poly_lens[img_id].tolist()
polys_per_mask_single = polys_per_mask[img_id].tolist()
split_gt_polys = mmcv.slice_list(gt_polys_single, gt_polys_lens_single)
mask_polys = mmcv.slice_list(split_gt_polys, num_polys_per_mask_single)
split_polys = mmcv.slice_list(polys_single, polys_lens_single)
mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single)
mask_polys_list.append(mask_polys)
return mask_polys_list
from .base import BaseDetector
from .rpn import RPN
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
__all__ = ['BaseDetector', 'RPN']
__all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN']
from .two_stage import TwoStageDetector
class FasterRCNN(TwoStageDetector):
    """Faster R-CNN: a two-stage detector with RPN + bbox head, no mask branch.

    This class adds no behavior of its own; it only fixes the set of
    components handed to the generic two-stage pipeline.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None):
        # Bundle every component config and delegate wholesale to the
        # shared TwoStageDetector implementation.
        components = dict(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
        super(FasterRCNN, self).__init__(**components)
from .two_stage import TwoStageDetector
class MaskRCNN(TwoStageDetector):
    """Mask R-CNN: Faster R-CNN extended with a mask RoI extractor and head.

    Pure configuration subclass — all training/testing logic lives in
    TwoStageDetector.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 mask_roi_extractor,
                 mask_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None):
        # Collect all components (including the mask branch) and forward
        # them unchanged to the shared two-stage implementation.
        components = dict(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
        super(MaskRCNN, self).__init__(**components)
import torch
import torch.nn as nn
from .base import Detector
from .testing_mixins import RPNTestMixin, BBoxTestMixin
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from mmdet.core import bbox2roi, bbox2result, sample_proposals
from mmdet.core import bbox2roi, bbox2result, split_combined_polys, multi_apply
class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin):
def __init__(self,
backbone,
......@@ -15,13 +16,16 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
rpn_head=None,
bbox_roi_extractor=None,
bbox_head=None,
mask_roi_extractor=None,
mask_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(Detector, self).__init__()
super(TwoStageDetector, self).__init__()
self.backbone = builder.build_backbone(backbone)
self.with_neck = True if neck is not None else False
assert self.with_neck, "TwoStageDetector must be implemented with FPN now."
if self.with_neck:
self.neck = builder.build_neck(neck)
......@@ -35,6 +39,12 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head)
self.with_mask = True if mask_head is not None else False
if self.with_mask:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......@@ -68,6 +78,7 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
gt_masks=None,
proposals=None):
losses = dict()
......@@ -80,54 +91,73 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.self.test_cfg.rpn)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
else:
proposal_list = proposals
(pos_inds, neg_inds, pos_proposals, neg_proposals,
pos_assigned_gt_inds,
pos_gt_bboxes, pos_gt_labels) = sample_proposals(
proposal_list, gt_bboxes, gt_bboxes_ignore, gt_labels,
self.train_cfg.rcnn)
labels, label_weights, bbox_targets, bbox_weights = \
self.bbox_head.get_bbox_target(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
if self.with_bbox:
rcnn_train_cfg_list = [
self.train_cfg.rcnn for _ in range(len(proposal_list))
]
(pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes,
pos_gt_labels) = multi_apply(
self.bbox_roi_extractor.sample_proposals, proposal_list,
gt_bboxes, gt_bboxes_ignore, gt_labels, rcnn_train_cfg_list)
labels, label_weights, bbox_targets, bbox_weights = \
self.bbox_head.get_bbox_target(pos_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_labels, self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
# TODO: a more flexible way to configurate feat maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
label_weights, bbox_targets,
bbox_weights)
losses.update(loss_bbox)
if self.with_mask:
gt_polys = split_combined_polys(**gt_masks)
mask_targets = self.mask_head.get_mask_target(
pos_proposals, pos_assigned_gt_inds, gt_polys, img_meta,
self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
# TODO: a more flexible way to configurate feat maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
label_weights, bbox_targets,
bbox_weights)
losses.update(loss_bbox)
pos_rois = bbox2roi(pos_proposals)
mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], pos_rois)
mask_pred = self.mask_head(mask_feats)
loss_mask = self.mask_head.loss(mask_pred, mask_targets,
torch.cat(pos_gt_labels))
losses.update(loss_mask)
return losses
def simple_test(self, img, img_meta, proposals=None, rescale=False):
"""Test without augmentation."""
assert proposals == None, "Fast RCNN hasn't been implemented."
assert self.with_bbox, "Bbox head must be implemented."
x = self.extract_feat(img)
if proposals is None:
proposals = self.simple_test_rpn(x, img_meta)
if self.with_bbox:
# BUG proposals shape?
det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, [proposals], rescale=rescale)
bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes)
return bbox_result
proposal_list = self.simple_test_rpn(
x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
bbox_results = bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes)
if self.with_mask:
segm_results = self.simple_test_mask(
x, img_meta, det_bboxes, det_labels, rescale=rescale)
return bbox_results, segm_results
else:
proposals[:, :4] /= img_meta['scale_factor'].float()
return proposals.cpu().numpy()
return bbox_results
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations.
......@@ -135,15 +165,28 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
proposals = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.rpn_test_cfg)
# recompute self.extract_feats(imgs) because of 'yield' and memory
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
det_bboxes, det_labels = self.aug_test_bboxes(
self.extract_feats(imgs), img_metas, proposals, self.rcnn_test_cfg)
self.extract_feats(imgs), img_metas, proposal_list,
self.test_cfg.rcnn)
if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= img_metas[0]['shape_scale'][0][-1]
bbox_result = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes)
return bbox_result
_det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
bbox_results = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes)
# det_bboxes always keep the original scale
if self.with_mask:
segm_results = self.aug_test_mask(
self.extract_feats(imgs),
img_metas,
det_bboxes,
det_labels)
return bbox_results, segm_results
else:
return bbox_results
......@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
return mask_targets
def loss(self, mask_pred, mask_targets, labels):
loss = dict()
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
return loss_mask
loss['loss_mask'] = loss_mask
return loss
def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_scale):
ori_shape):
"""Get segmentation masks from mask_pred and bboxes
Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
......@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config
rescale (bool): whether rescale masks to original image size
ori_shape: original image size
Returns:
list[list]: encoded masks
"""
......@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
cls_segms = [[] for _ in range(self.num_classes - 1)]
bboxes = det_bboxes.cpu().numpy()[:, :4]
labels = det_labels.cpu().numpy() + 1
img_h = ori_scale[0]
img_w = ori_scale[1]
img_h = ori_shape[0]
img_w = ori_shape[1]
for i in range(bboxes.shape[0]):
bbox = bboxes[i, :].astype(int)
......
......@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
from mmdet import ops
from mmdet.core import bbox_assign, bbox_sampling
class SingleLevelRoI(nn.Module):
......@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls
def sample_proposals(self, proposals, gt_bboxes, gt_crowds, gt_labels,
                     cfg):
    """Assign ground-truth boxes to proposals and sample pos/neg RoIs.

    Args:
        proposals (Tensor): candidate boxes of shape (n, >=4); only the
            first 4 columns (the box coordinates) are used.
        gt_bboxes (Tensor): ground-truth boxes, shape (m, 4).
        gt_crowds: crowd ground-truth boxes, passed through to
            ``bbox_assign`` (used for ignoring, per ``cfg.crowd_thr``).
        gt_labels (Tensor): class labels of the ground-truth boxes.
        cfg: RCNN training config providing ``pos_iou_thr``,
            ``neg_iou_thr``, ``crowd_thr``, ``add_gt_as_proposals``,
            ``roi_batch_size``, ``pos_fraction``, ``neg_pos_ub``,
            ``pos_balance_sampling`` and ``neg_balance_thr``.

    Returns:
        tuple: (pos_proposals, neg_proposals, pos_assigned_gt_inds,
        pos_gt_bboxes, pos_gt_labels), where ``pos_assigned_gt_inds``
        are 0-based indices into ``gt_bboxes``.
    """
    # Strip any extra columns (e.g. an objectness score) before matching.
    proposals = proposals[:, :4]
    # NOTE(review): cfg.pos_iou_thr is passed twice — the third threshold
    # argument is presumably min_pos_iou; confirm against bbox_assign's
    # signature.
    assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
        bbox_assign(proposals, gt_bboxes, gt_crowds, gt_labels,
        cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)
    if cfg.add_gt_as_proposals:
        # Prepend each GT box as a proposal matched to itself; the
        # self-assignments are 1-based (1..len(gt_labels)).
        proposals = torch.cat([gt_bboxes, proposals], dim=0)
        gt_assign_self = torch.arange(
            1,
            len(gt_labels) + 1,
            dtype=torch.long,
            device=proposals.device)
        assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
        assigned_labels = torch.cat([gt_labels, assigned_labels])
    # Sample a fixed-size batch of positive and negative RoI indices.
    pos_inds, neg_inds = bbox_sampling(
        assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction,
        cfg.neg_pos_ub, cfg.pos_balance_sampling, max_overlaps,
        cfg.neg_balance_thr)
    pos_proposals = proposals[pos_inds]
    neg_proposals = proposals[neg_inds]
    # Positive assignments are 1-based (see gt_assign_self above);
    # shift by -1 to index directly into gt_bboxes.
    pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
    pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
    pos_gt_labels = assigned_labels[pos_inds]
    return (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes, pos_gt_labels)
def forward(self, feats, rois):
"""Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment