Skip to content
Snippets Groups Projects
Commit cb0dd8ee authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

support fp16 for maskiou_head (#986)

parent 713e98bc
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init
from mmdet.core import force_fp32
from ..builder import build_loss
from ..registry import HEADS
......@@ -28,6 +29,7 @@ class MaskIoUHead(nn.Module):
self.conv_out_channels = conv_out_channels
self.fc_out_channels = fc_out_channels
self.num_classes = num_classes
self.fp16_enabled = False
self.convs = nn.ModuleList()
for i in range(num_convs):
......@@ -82,6 +84,7 @@ class MaskIoUHead(nn.Module):
mask_iou = self.fc_mask_iou(x)
return mask_iou
@force_fp32(apply_to=('mask_iou_pred', ))
def loss(self, mask_iou_pred, mask_iou_targets):
pos_inds = mask_iou_targets > 0
if pos_inds.sum() > 0:
......@@ -91,6 +94,7 @@ class MaskIoUHead(nn.Module):
loss_mask_iou = mask_iou_pred * 0
return dict(loss_mask_iou=loss_mask_iou)
@force_fp32(apply_to=('mask_pred', ))
def get_target(self, sampling_results, gt_masks, mask_pred, mask_targets,
rcnn_train_cfg):
"""Compute target of mask IoU.
......@@ -166,6 +170,7 @@ class MaskIoUHead(nn.Module):
area_ratios = pos_proposals.new_zeros((0, ))
return area_ratios
@force_fp32(apply_to=('mask_iou_pred', ))
def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels):
"""Get the mask scores.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment