Skip to content
Snippets Groups Projects
Unverified Commit 046db20e authored by tianyuandu's avatar tianyuandu Committed by GitHub
Browse files

fix bug of gaussian_target, update unittest of heatmap (#3543)

* fix bug of gaussian_target, update unittest

* fix AELoss's weight; fix Conv's init in CornerHead; now the loss and mAP from 1 to 40 epochs are correct, still training

* add some comments

* fix cases

* update comments

* fix yapf
parent 49b3f37d
No related branches found
No related tags found
No related merge requests found
......@@ -23,8 +23,8 @@ model = dict(
type='GaussianFocalLoss', alpha=2.0, gamma=4.0, loss_weight=1),
loss_embedding=dict(
type='AssociativeEmbeddingLoss',
pull_weight=0.25,
push_weight=0.25),
pull_weight=0.10,
push_weight=0.10),
loss_offset=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1)))
# data settings
img_norm_cfg = dict(
......
......@@ -23,8 +23,8 @@ model = dict(
type='GaussianFocalLoss', alpha=2.0, gamma=4.0, loss_weight=1),
loss_embedding=dict(
type='AssociativeEmbeddingLoss',
pull_weight=0.25,
push_weight=0.25),
pull_weight=0.10,
push_weight=0.10),
loss_offset=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1)))
# data settings
img_norm_cfg = dict(
......
......@@ -23,8 +23,8 @@ model = dict(
type='GaussianFocalLoss', alpha=2.0, gamma=4.0, loss_weight=1),
loss_embedding=dict(
type='AssociativeEmbeddingLoss',
pull_weight=0.25,
push_weight=0.25),
pull_weight=0.10,
push_weight=0.10),
loss_offset=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1)))
# data settings
img_norm_cfg = dict(
......
......@@ -216,8 +216,18 @@ class CornerHead(BaseDenseHead):
"""Initialize weights of the head."""
bias_init = bias_init_with_prob(0.1)
for i in range(self.num_feat_levels):
# The initialization of parameters are different between nn.Conv2d
# and ConvModule. Our experiments show that using the original
# initialization of nn.Conv2d increases the final mAP by about 0.2%
self.tl_heat[i][-1].conv.reset_parameters()
self.tl_heat[i][-1].conv.bias.data.fill_(bias_init)
self.br_heat[i][-1].conv.reset_parameters()
self.br_heat[i][-1].conv.bias.data.fill_(bias_init)
self.tl_off[i][-1].conv.reset_parameters()
self.br_off[i][-1].conv.reset_parameters()
if self.with_corner_emb:
self.tl_emb[i][-1].conv.reset_parameters()
self.br_emb[i][-1].conv.reset_parameters()
def forward(self, feats):
"""Forward features from the upstream network.
......@@ -768,7 +778,7 @@ class CornerHead(BaseDenseHead):
feat (Tensor): Gathered feature.
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
ind = ind.unsqueeze(2).repeat(1, 1, dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
......@@ -898,10 +908,14 @@ class CornerHead(BaseDenseHead):
tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = self._topk(tl_heat, k=k)
br_scores, br_inds, br_clses, br_ys, br_xs = self._topk(br_heat, k=k)
tl_ys = tl_ys.view(batch, k, 1).expand(batch, k, k)
tl_xs = tl_xs.view(batch, k, 1).expand(batch, k, k)
br_ys = br_ys.view(batch, 1, k).expand(batch, k, k)
br_xs = br_xs.view(batch, 1, k).expand(batch, k, k)
# We use repeat instead of expand here because expand is a
# shallow-copy function. Thus it could cause unexpected testing result
# sometimes. Using expand will decrease about 10% mAP during testing
# compared to repeat.
tl_ys = tl_ys.view(batch, k, 1).repeat(1, 1, k)
tl_xs = tl_xs.view(batch, k, 1).repeat(1, 1, k)
br_ys = br_ys.view(batch, 1, k).repeat(1, k, 1)
br_xs = br_xs.view(batch, 1, k).repeat(1, k, 1)
tl_off = self._transpose_and_gather_feat(tl_off, tl_inds)
tl_off = tl_off.view(batch, k, 1, 2)
......@@ -1002,14 +1016,14 @@ class CornerHead(BaseDenseHead):
br_emb = br_emb.view(batch, 1, k)
dists = torch.abs(tl_emb - br_emb)
tl_scores = tl_scores.view(batch, k, 1).expand(batch, k, k)
br_scores = br_scores.view(batch, 1, k).expand(batch, k, k)
tl_scores = tl_scores.view(batch, k, 1).repeat(1, 1, k)
br_scores = br_scores.view(batch, 1, k).repeat(1, k, 1)
scores = (tl_scores + br_scores) / 2 # scores for all possible boxes
# tl and br should have same class
tl_clses = tl_clses.view(batch, k, 1).expand(batch, k, k)
br_clses = br_clses.view(batch, 1, k).expand(batch, k, k)
tl_clses = tl_clses.view(batch, k, 1).repeat(1, 1, k)
br_clses = br_clses.view(batch, 1, k).repeat(1, k, 1)
cls_inds = (tl_clses != br_clses)
# reject boxes based on distances
......
......@@ -54,7 +54,7 @@ def gen_gaussian_target(heatmap, center, radius, k=1):
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian_kernel[radius - top:radius + bottom,
radius - left:radius + right]
out_heatmap = torch.zeros_like(heatmap)
out_heatmap = heatmap
torch.max(
masked_heatmap,
masked_gaussian * k,
......
......@@ -2,6 +2,7 @@ import mmcv
import torch
from mmdet.core import bbox2roi, build_assigner, build_sampler
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from mmdet.models.dense_heads import (AnchorHead, CornerHead, FCOSHead,
FSAFHead, GuidedAnchorHead)
from mmdet.models.roi_heads.bbox_heads import BBoxHead
......@@ -714,8 +715,11 @@ def test_corner_head_encode_and_decode_heatmap():
'border': (0, 0, 0, 0)
}]
gt_bboxes = [torch.Tensor([[10, 20, 200, 240]])]
gt_labels = [torch.LongTensor([1])]
gt_bboxes = [
torch.Tensor([[10, 20, 200, 240], [40, 50, 100, 200],
[10, 20, 200, 240]])
]
gt_labels = [torch.LongTensor([1, 1, 2])]
self = CornerHead(num_classes=4, in_channels=1, corner_emb_channels=1)
......@@ -762,5 +766,14 @@ def test_corner_head_encode_and_decode_heatmap():
scores = scores[idx].view(-1)
clses = clses[idx].view(-1)
assert bboxes[torch.where(scores > 0.05)].equal(gt_bboxes[0])
assert clses[torch.where(scores > 0.05)].equal(gt_labels[0].float())
valid_bboxes = bboxes[torch.where(scores > 0.05)]
valid_labels = clses[torch.where(scores > 0.05)]
max_coordinate = valid_bboxes.max()
offsets = valid_labels.to(valid_bboxes) * (max_coordinate + 1)
gt_offsets = gt_labels[0].to(gt_bboxes[0]) * (max_coordinate + 1)
offset_bboxes = valid_bboxes + offsets[:, None]
offset_gtbboxes = gt_bboxes[0] + gt_offsets[:, None]
iou_matrix = bbox_overlaps(offset_bboxes.numpy(), offset_gtbboxes.numpy())
assert (iou_matrix == 1).sum() == 3
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment