Skip to content
Snippets Groups Projects
Commit ac8d34e7 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'master' of github.com:open-mmlab/mmdetection

parents b88561db a054aef4
No related branches found
No related tags found
No related merge requests found
......@@ -159,16 +159,6 @@ def anchor_target_single(flat_anchors,
neg_inds)
def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
allowed_border=0):
img_h, img_w = img_shape[:2]
......
......@@ -23,6 +23,8 @@ def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
if pred.dim() != label.dim():
label, weight = _expand_binary_labels(label, weight, pred.size(-1))
if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
return F.binary_cross_entropy_with_logits(
......@@ -115,3 +117,13 @@ def accuracy(pred, target, topk=1):
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res
def _expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment