Skip to content
Snippets Groups Projects
Commit af55b977 authored by Kai Chen's avatar Kai Chen
Browse files

use dict to save multi-stage results

parent 0bb723a5
No related branches found
No related tags found
No related merge requests found
......@@ -142,7 +142,7 @@ train_cfg = dict(
pos_weight=-1,
debug=False)
],
loss_weight=[1, 0.5, 0.4])
stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
......
......@@ -128,7 +128,7 @@ train_cfg = dict(
pos_weight=-1,
debug=False)
],
loss_weight=[1, 0.5, 0.4])
stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
......
from __future__ import division
import torch
import torch.nn as nn
......@@ -127,7 +129,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for i in range(self.num_stages):
rcnn_train_cfg = self.train_cfg.rcnn[i]
lw = self.train_cfg.loss_weight[i]
lw = self.train_cfg.stage_loss_weights[i]
# assign gts and sample proposals
assign_results, sampling_results = multi_apply(
......@@ -193,8 +195,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
scale_factor = img_meta[0]['scale_factor']
# "ms" in variable names means multi-stage
ms_bbox_result = []
ms_segm_result = []
ms_bbox_result = {}
ms_segm_result = {}
ms_scores = []
rcnn_test_cfg = self.test_cfg.rcnn
......@@ -219,11 +221,11 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
nms_cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels,
bbox_head.num_classes)
ms_bbox_result.append(bbox_result)
ms_bbox_result['stage{}'.format(i)] = bbox_result
if self.with_mask:
mask_block = self.mask_blocks[i]
mask_head = self.mask_heads[i]
mask_roi_extractor = self.mask_roi_extractor[i]
mask_head = self.mask_head[i]
if det_bboxes.shape[0] == 0:
segm_result = [
[] for _ in range(mask_head.num_classes - 1)
......@@ -232,20 +234,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
_bboxes = (det_bboxes[:, :4] * scale_factor
if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes])
mask_feats = mask_block(
x[:len(mask_block.featmap_strides)], mask_rois)
mask_feats = mask_roi_extractor(
x[:len(mask_roi_extractor.featmap_strides)],
mask_rois)
mask_pred = mask_head(mask_feats)
segm_result = mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale)
ms_segm_result.append(segm_result)
ms_segm_result['stage{}'.format(i)] = segm_result
if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
img_meta[0])
cls_score = sum(ms_scores) / float(len(ms_scores))
cls_score = sum(ms_scores) / len(ms_scores)
det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
rois,
cls_score,
......@@ -256,7 +259,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
nms_cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes)
ms_bbox_result.append(bbox_result)
ms_bbox_result['ensemble'] = bbox_result
if self.with_mask:
if det_bboxes.shape[0] == 0:
......@@ -280,12 +283,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
segm_result = self.mask_head[-1].get_seg_masks(
merged_masks, _bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale)
ms_segm_result.append(segm_result)
ms_segm_result['ensemble'] = segm_result
if not self.test_cfg.keep_all_stages:
ms_bbox_result = ms_bbox_result[0]
ms_bbox_result = ms_bbox_result['ensemble']
if self.with_mask:
ms_segm_result = ms_segm_result[0]
ms_segm_result = ms_segm_result['ensemble']
if not self.with_mask:
return ms_bbox_result
......@@ -301,5 +304,9 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
ms_bbox_result, ms_segm_result = result
else:
ms_bbox_result = result
super(CascadeRCNN, self).show_result(data, ms_bbox_result[-1],
img_norm_cfg, **kwargs)
if isinstance(ms_bbox_result, dict):
bbox_result = ms_bbox_result['ensemble']
else:
bbox_result = ms_bbox_result
super(CascadeRCNN, self).show_result(data, bbox_result, img_norm_cfg,
**kwargs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment