Skip to content
Snippets Groups Projects
Unverified Commit 204f7514 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Fix bugs in TTA of cascade rcnn & htc (#2944)

parent 0372b3e8
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,6 @@ line_length = 79 ...@@ -3,6 +3,6 @@ line_length = 79
multi_line_output = 0 multi_line_output = 0
known_standard_library = setuptools known_standard_library = setuptools
known_first_party = mmdet known_first_party = mmdet
known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY
...@@ -393,8 +393,7 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin): ...@@ -393,8 +393,7 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
if self.with_mask: if self.with_mask:
if det_bboxes.shape[0] == 0: if det_bboxes.shape[0] == 0:
segm_result = [[] segm_result = [[]
for _ in range(self.mask_head[-1].num_classes - for _ in range(self.mask_head[-1].num_classes)]
1)]
else: else:
aug_masks = [] aug_masks = []
aug_img_metas = [] aug_img_metas = []
......
...@@ -376,7 +376,7 @@ class HybridTaskCascadeRoIHead(CascadeRoIHead): ...@@ -376,7 +376,7 @@ class HybridTaskCascadeRoIHead(CascadeRoIHead):
return results return results
def aug_test(self, img_feats, img_metas, proposals=None, rescale=False): def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
"""Test with augmentations. """Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale If rescale is False, then returned bboxes and masks will fit the scale
...@@ -389,10 +389,6 @@ class HybridTaskCascadeRoIHead(CascadeRoIHead): ...@@ -389,10 +389,6 @@ class HybridTaskCascadeRoIHead(CascadeRoIHead):
else: else:
semantic_feats = [None] * len(img_metas) semantic_feats = [None] * len(img_metas)
# recompute feats to save memory
proposal_list = self.aug_test_rpn(img_feats, img_metas,
self.test_cfg.rpn)
rcnn_test_cfg = self.test_cfg rcnn_test_cfg = self.test_cfg
aug_bboxes = [] aug_bboxes = []
aug_scores = [] aug_scores = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment