Skip to content
Snippets Groups Projects
Commit cb0dd8ee authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

support fp16 for maskiou_head (#986)

parent 713e98bc
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init from mmcv.cnn import kaiming_init, normal_init
from mmdet.core import force_fp32
from ..builder import build_loss from ..builder import build_loss
from ..registry import HEADS from ..registry import HEADS
...@@ -28,6 +29,7 @@ class MaskIoUHead(nn.Module): ...@@ -28,6 +29,7 @@ class MaskIoUHead(nn.Module):
self.conv_out_channels = conv_out_channels self.conv_out_channels = conv_out_channels
self.fc_out_channels = fc_out_channels self.fc_out_channels = fc_out_channels
self.num_classes = num_classes self.num_classes = num_classes
self.fp16_enabled = False
self.convs = nn.ModuleList() self.convs = nn.ModuleList()
for i in range(num_convs): for i in range(num_convs):
...@@ -82,6 +84,7 @@ class MaskIoUHead(nn.Module): ...@@ -82,6 +84,7 @@ class MaskIoUHead(nn.Module):
mask_iou = self.fc_mask_iou(x) mask_iou = self.fc_mask_iou(x)
return mask_iou return mask_iou
@force_fp32(apply_to=('mask_iou_pred', ))
def loss(self, mask_iou_pred, mask_iou_targets): def loss(self, mask_iou_pred, mask_iou_targets):
pos_inds = mask_iou_targets > 0 pos_inds = mask_iou_targets > 0
if pos_inds.sum() > 0: if pos_inds.sum() > 0:
...@@ -91,6 +94,7 @@ class MaskIoUHead(nn.Module): ...@@ -91,6 +94,7 @@ class MaskIoUHead(nn.Module):
loss_mask_iou = mask_iou_pred * 0 loss_mask_iou = mask_iou_pred * 0
return dict(loss_mask_iou=loss_mask_iou) return dict(loss_mask_iou=loss_mask_iou)
@force_fp32(apply_to=('mask_pred', ))
def get_target(self, sampling_results, gt_masks, mask_pred, mask_targets, def get_target(self, sampling_results, gt_masks, mask_pred, mask_targets,
rcnn_train_cfg): rcnn_train_cfg):
"""Compute target of mask IoU. """Compute target of mask IoU.
...@@ -166,6 +170,7 @@ class MaskIoUHead(nn.Module): ...@@ -166,6 +170,7 @@ class MaskIoUHead(nn.Module):
area_ratios = pos_proposals.new_zeros((0, )) area_ratios = pos_proposals.new_zeros((0, ))
return area_ratios return area_ratios
@force_fp32(apply_to=('mask_iou_pred', ))
def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels): def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels):
"""Get the mask scores. """Get the mask scores.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment