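"""FGFI distiller for two-stage detectors.

A teacher and a student detector are built side by side. During training,
the student's neck features are compared with the teacher's through
configurable distillation losses. These losses are weighted by masks
derived from RPN anchors that overlap the ground-truth boxes.
"""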
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmdet.models.detectors.base import BaseDetector
from mmdet.models import build_detector
from mmcv.runner import  load_checkpoint, _load_checkpoint, load_state_dict
from ..builder import DISTILLER,build_distill_loss
from collections import OrderedDict




@DISTILLER.register_module()
class FGFI(BaseDetector):
    """FGFI distiller for detectors.

    Wraps a teacher and a student detector. During training the student's
    neck features are compared with the teacher's neck features by the
    distillation losses listed in ``distill_cfg``. Each comparison is weighted
    by a mask built from RPN anchors that overlap the ground-truth boxes.
    Testing and feature extraction are delegated to the student.

    Args:
        teacher_cfg: Config with a ``model`` entry (plus optional
            ``train_cfg`` / ``test_cfg``) describing the teacher detector.
        student_cfg: Same layout as ``teacher_cfg``, for the student.
        distill_cfg (list, optional): One entry per distilled feature level.
            Each entry has ``feature_level`` and ``methods``; every method has
            a ``name`` and is passed to ``build_distill_loss``. ``None`` means
            no distillation losses.
        teacher_pretrained (str, optional): Checkpoint loaded into the
            teacher. It is also the source of the weights copied into the
            student when ``load_teacher_part`` is set.
        load_teacher_part (str, optional): ``'neck'``, ``'head'`` or
            ``'neck_head'``. Copies the matching teacher weights into the
            student after the student's own initialization.
    """

    def __init__(self,
                 teacher_cfg,
                 student_cfg,
                 distill_cfg=None,
                 teacher_pretrained=None,
                 load_teacher_part=None):
        super(FGFI, self).__init__()

        self.teacher = build_detector(teacher_cfg.model,
                                      train_cfg=teacher_cfg.get('train_cfg'),
                                      test_cfg=teacher_cfg.get('test_cfg'))
        self.init_weights_teacher(teacher_pretrained)
        # The teacher is only used as a feature extractor.
        self.teacher.eval()

        self.student = build_detector(student_cfg.model,
                                      train_cfg=student_cfg.get('train_cfg'),
                                      test_cfg=student_cfg.get('test_cfg'))
        self.init_weights_student(load_teacher_part, teacher_pretrained)

        # Treat a missing distill_cfg as "no distillation losses" instead of
        # failing on ``for ... in None``.
        self.distill_cfg = distill_cfg if distill_cfg is not None else []
        self.distill_losses = nn.ModuleDict()
        for item_loc in self.distill_cfg:
            for item_loss in item_loc.methods:
                self.distill_losses[item_loss.name] = build_distill_loss(item_loss)

    def base_parameters(self):
        """Return the student and the distillation losses as one ModuleList."""
        return nn.ModuleList([self.student, self.distill_losses])

    def discriminator_parameters(self):
        """Return ``self.discriminator``.

        NOTE(review): nothing in this class assigns ``self.discriminator``.
        This therefore raises AttributeError unless a subclass or caller sets
        it. It looks like a leftover from an adversarial distiller; confirm
        before relying on it.
        """
        return self.discriminator

    @property
    def with_neck(self):
        """bool: whether the student detector has a neck"""
        return hasattr(self.student, 'neck') and self.student.neck is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self):
        """bool: whether the student detector has a shared head in the RoI Head"""
        return hasattr(self.student, 'roi_head') and self.student.roi_head.with_shared_head

    @property
    def with_bbox(self):
        """bool: whether the student detector has a bbox head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
                or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the student detector has a mask head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
                or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))

    def init_weights_teacher(self, path=None):
        """Load pretrained weights into the teacher detector.

        Args:
            path (str, optional): Path to the teacher checkpoint. When
                ``None``, the teacher keeps the weights it was built with.
                Defaults to None.
        """
        if path is None:
            return
        load_checkpoint(self.teacher, path, map_location='cpu')

    def init_weights_student(self, load_teacher_part, teacher_pretrained):
        """Initialize the student, optionally copying parts of the teacher.

        Args:
            load_teacher_part (str | None): ``'neck'``, ``'head'`` or
                ``'neck_head'`` selects which teacher parameters are copied.
                A falsy value copies nothing.
            teacher_pretrained (str | None): Teacher checkpoint path. Only
                read when ``load_teacher_part`` is set.

        Raises:
            ValueError: If ``load_teacher_part`` is set but
                ``teacher_pretrained`` is None.
        """
        self.student.init_weights()
        if not load_teacher_part:
            return
        assert load_teacher_part in ['neck', 'head', 'neck_head']
        if teacher_pretrained is None:
            raise ValueError(
                'teacher_pretrained must be given when load_teacher_part is set')

        # A checkpoint key is copied when it contains one of the requested
        # part names as a substring ('head' also matches e.g. 'rpn_head' and
        # 'roi_head').
        parts = [part for part in ('neck', 'head') if part in load_teacher_part]
        t_checkpoint = _load_checkpoint(teacher_pretrained, map_location='cpu')
        state_dict = OrderedDict(
            (name, value) for name, value in t_checkpoint['state_dict'].items()
            if any(part in name for part in parts))
        load_state_dict(self.student, state_dict)

    def _map_roi_levels(self, rois, num_levels, finest_scale=56):
        """Assign each box to a feature level according to its scale.

        The level is ``floor(log2(sqrt(w * h) / finest_scale))``, clamped to
        ``[0, num_levels - 1]``. Widths and heights use the inclusive ``+ 1``
        convention.

        Args:
            rois (Tensor): Boxes whose columns 0-3 are read as x1, y1, x2, y2.
            num_levels (int): Number of available feature levels.
            finest_scale (int): Box scale that maps to level 0. Defaults to 56.

        Returns:
            Tensor: Long tensor with one level index per box.
        """
        scale = torch.sqrt(
            (rois[:, 2] - rois[:, 0] + 1) * (rois[:, 3] - rois[:, 1] + 1))
        # The epsilon keeps log2 finite for degenerate boxes.
        target_lvls = torch.floor(torch.log2(scale / finest_scale + 1e-6))
        return target_lvls.clamp(min=0, max=num_levels - 1).long()

    def get_roi_mask(self, cls_scores, img_metas, gt_bboxes, phi=0.5):
        """Build per-level binary imitation masks from anchors near GT boxes.

        Each GT box is assigned to one level by ``_map_roi_levels``. On that
        level, a location is marked when any of its anchors has an IoU with
        the box greater than ``phi`` times that box's best anchor IoU.

        Args:
            cls_scores (list[Tensor]): Per-level RPN classification outputs.
                Only their spatial sizes (last two dims) and device are used.
            img_metas (list[dict]): Image meta information.
            gt_bboxes (list[Tensor]): Ground-truth boxes, one tensor per image.
            phi (float): Relative IoU threshold factor. Defaults to 0.5.

        Returns:
            list[Tensor]: One float mask per level, stacked over images, with
            values in {0, 1}.
        """
        from mmdet.core import bbox_overlaps

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        device = cls_scores[0].device
        with torch.no_grad():
            anchor_list, _ = self.student.rpn_head.get_anchors(
                featmap_sizes, img_metas, device=device)
            mask_batch = []
            for img_id, gt_per_img in enumerate(gt_bboxes):
                anchors_per_img = anchor_list[img_id]
                target_lvls = self._map_roi_levels(gt_per_img, len(anchors_per_img))
                mask_level = []
                for level, anchors in enumerate(anchors_per_img):
                    h, w = featmap_sizes[level][0], featmap_sizes[level][1]
                    gt_level = gt_per_img[target_lvls == level]
                    if gt_level.shape[0] > 0:
                        iou = bbox_overlaps(anchors, gt_level)  # (h*w*A, G)
                        max_iou, _ = iou.max(dim=0)  # best anchor IoU per GT
                        hit = iou > max_iou * phi
                        # Anchors are ordered (h, w, A), so fold anchors and
                        # GTs together and keep locations with any hit.
                        mask_per_img = hit.view(h, w, -1).any(dim=2).float()
                    else:
                        mask_per_img = torch.zeros(
                            (h, w), dtype=torch.float, device=device)
                    mask_level.append(mask_per_img)
                mask_batch.append(mask_level)
            # Regroup from per-image lists into one stacked tensor per level.
            mask_batch_level = [
                torch.stack(level_masks, dim=0) for level_masks in zip(*mask_batch)
            ]
        return mask_batch_level

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """Compute student detection losses plus feature-distillation losses.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground-truth boxes per image.
            gt_labels (list[Tensor]): Ground-truth labels per image.
            gt_bboxes_ignore (list[Tensor], optional): Boxes ignored by the
                RPN and RoI losses.
            gt_masks (optional): Passed through to the student's RoI head.
            proposals (optional): Accepted for interface compatibility but
                unused. Proposals come from the student's RPN.

        Returns:
            dict[str, Tensor]: Student RPN and RoI losses, one entry per
            distillation loss name, and a ``'<level>_ratio'`` entry giving
            the fraction of masked locations per distilled level.
        """
        losses = dict()
        with torch.no_grad():
            # Re-enter eval mode every step: calling ``train()`` on the
            # distiller also puts its teacher submodule back in train mode.
            self.teacher.eval()
            f_t = self.teacher.neck(self.teacher.backbone(img))

        f_s = self.student.neck(self.student.backbone(img))

        # Student RPN: its losses, plus proposals for the RoI head.
        rpn_outs = self.student.rpn_head(f_s)
        rpn_losses = self.student.rpn_head.loss(
            *rpn_outs, gt_bboxes, img_metas, gt_bboxes_ignore=gt_bboxes_ignore)
        proposal_cfg = self.student.train_cfg.get('rpn_proposal', self.student.test_cfg.rpn)
        proposal_list = self.student.rpn_head.get_bboxes(*rpn_outs, img_metas, cfg=proposal_cfg)
        losses.update(rpn_losses)

        with torch.no_grad():
            neck_mask_batch = self.get_roi_mask(rpn_outs[0], img_metas, gt_bboxes, phi=0.5)

        for item_loc in self.distill_cfg:
            feature_level = item_loc.feature_level
            f_s_l = f_s[feature_level]
            f_t_l = f_t[feature_level]
            level_mask = neck_mask_batch[feature_level]
            # Fraction of marked locations. NOTE(review): the key contains no
            # 'loss', so presumably mmdet only logs it; confirm.
            losses['{}_ratio'.format(feature_level)] = level_mask.sum() / level_mask.numel()
            # Repeat the (N, H, W) mask over the feature channels.
            mask = level_mask.unsqueeze(1).repeat(1, f_s_l.size(1), 1, 1)
            for item_loss in item_loc.methods:
                loss_name = item_loss.name
                # Losses whose name contains 'n' use the complementary mask.
                # Build it per loss so the shared mask is never re-inverted.
                # NOTE(review): this substring test matches any name with an
                # 'n'; confirm the naming convention.
                loss_mask = 1 - mask if 'n' in loss_name else mask
                losses[loss_name] = self.distill_losses[loss_name](f_s_l, f_t_l, loss_mask)

        roi_losses = self.student.roi_head.forward_train(f_s, img_metas, proposal_list,
                                                         gt_bboxes, gt_labels,
                                                         gt_bboxes_ignore, gt_masks,
                                                         **kwargs)
        losses.update(roi_losses)
        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Run single-scale inference with the student."""
        return self.student.simple_test(img, img_metas, **kwargs)

    def aug_test(self, imgs, img_metas, **kwargs):
        """Run test-time-augmented inference with the student."""
        return self.student.aug_test(imgs, img_metas, **kwargs)

    def extract_feat(self, imgs):
        """Extract features from images with the student."""
        return self.student.extract_feat(imgs)