From 204f7514afedb41f69d08f91a3ca264cf7c5b4e4 Mon Sep 17 00:00:00 2001
From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Date: Mon, 8 Jun 2020 21:40:09 +0800
Subject: [PATCH] Fix bugs in TTA of cascade rcnn & htc (#2944)

---
 .isort.cfg                                 | 2 +-
 mmdet/models/roi_heads/cascade_roi_head.py | 3 +--
 mmdet/models/roi_heads/htc_roi_head.py     | 6 +-----
 3 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/.isort.cfg b/.isort.cfg
index 0fff944e..285b7000 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -3,6 +3,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmdet
-known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
+known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
index c528eb2a..fb40293c 100644
--- a/mmdet/models/roi_heads/cascade_roi_head.py
+++ b/mmdet/models/roi_heads/cascade_roi_head.py
@@ -393,8 +393,7 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
         if self.with_mask:
             if det_bboxes.shape[0] == 0:
                 segm_result = [[]
-                               for _ in range(self.mask_head[-1].num_classes -
-                                              1)]
+                               for _ in range(self.mask_head[-1].num_classes)]
             else:
                 aug_masks = []
                 aug_img_metas = []
diff --git a/mmdet/models/roi_heads/htc_roi_head.py b/mmdet/models/roi_heads/htc_roi_head.py
index 95c4345a..e7918243 100644
--- a/mmdet/models/roi_heads/htc_roi_head.py
+++ b/mmdet/models/roi_heads/htc_roi_head.py
@@ -376,7 +376,7 @@ class HybridTaskCascadeRoIHead(CascadeRoIHead):
 
         return results
 
-    def aug_test(self, img_feats, img_metas, proposals=None, rescale=False):
+    def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
         """Test with augmentations.
 
         If rescale is False, then returned bboxes and masks will fit the scale
@@ -389,10 +389,6 @@ class HybridTaskCascadeRoIHead(CascadeRoIHead):
         else:
             semantic_feats = [None] * len(img_metas)
 
-        # recompute feats to save memory
-        proposal_list = self.aug_test_rpn(img_feats, img_metas,
-                                          self.test_cfg.rpn)
-
         rcnn_test_cfg = self.test_cfg
         aug_bboxes = []
         aug_scores = []
-- 
GitLab