From 5a0e55af5a030ad6db85243460d187383be87a36 Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU <xvjiarui0826@gmail.com> Date: Sun, 3 May 2020 00:50:33 +0800 Subject: [PATCH] Migrate to mmcv.cnn.bricks (#2572) * migrate to mmcv.cnn.bricks * remove conv_module, scale, upsample * update mmdet version and mmcv requirement * fixed fsaf warning --- .../bbox/assigners/center_region_assigner.py | 2 +- mmdet/core/bbox/coder/tblr_bbox_coder.py | 2 +- mmdet/models/anchor_heads/atss_head.py | 3 +- mmdet/models/anchor_heads/fcos_head.py | 3 +- mmdet/models/anchor_heads/fovea_head.py | 4 +- mmdet/models/anchor_heads/fsaf_head.py | 2 +- mmdet/models/anchor_heads/ga_retina_head.py | 4 +- mmdet/models/anchor_heads/reppoints_head.py | 4 +- mmdet/models/anchor_heads/retina_head.py | 3 +- .../models/anchor_heads/retina_sepbn_head.py | 3 +- mmdet/models/backbones/hrnet.py | 4 +- mmdet/models/backbones/resnet.py | 5 +- mmdet/models/backbones/resnext.py | 3 +- mmdet/models/bbox_heads/convfc_bbox_head.py | 2 +- mmdet/models/bbox_heads/double_bbox_head.py | 3 +- mmdet/models/detectors/fsaf.py | 2 +- mmdet/models/mask_heads/fcn_mask_head.py | 3 +- .../models/mask_heads/fused_semantic_head.py | 3 +- mmdet/models/mask_heads/grid_head.py | 3 +- mmdet/models/mask_heads/htc_mask_head.py | 3 +- mmdet/models/necks/bfp.py | 4 +- mmdet/models/necks/fpn.py | 3 +- mmdet/models/necks/fpn_carafe.py | 3 +- mmdet/models/necks/hrfpn.py | 3 +- mmdet/models/necks/nas_fpn.py | 3 +- mmdet/models/necks/pafpn.py | 2 +- mmdet/models/utils/res_layer.py | 3 +- mmdet/ops/__init__.py | 10 +- mmdet/ops/activation.py | 38 ----- mmdet/ops/carafe/carafe.py | 3 +- mmdet/ops/conv.py | 39 ------ mmdet/ops/conv_module.py | 132 ------------------ mmdet/ops/conv_ws.py | 2 + mmdet/ops/dcn/deform_conv.py | 3 + mmdet/ops/non_local.py | 4 +- mmdet/ops/norm.py | 55 -------- mmdet/ops/scale.py | 15 -- mmdet/ops/upsample.py | 79 ----------- mmdet/ops/wrappers.py | 2 + requirements/runtime.txt | 2 +- setup.py | 4 +- 41 files changed, 52 insertions(+), 418 deletions(-) delete mode 100644 mmdet/ops/activation.py delete mode 100644 mmdet/ops/conv.py delete mode 100644 mmdet/ops/conv_module.py delete mode 100644 mmdet/ops/norm.py delete mode 100644 mmdet/ops/scale.py delete mode 100644 mmdet/ops/upsample.py diff --git a/mmdet/core/bbox/assigners/center_region_assigner.py b/mmdet/core/bbox/assigners/center_region_assigner.py index 1dc205e0..91f8c014 100644 --- a/mmdet/core/bbox/assigners/center_region_assigner.py +++ b/mmdet/core/bbox/assigners/center_region_assigner.py @@ -68,7 +68,7 @@ def bboxes_area(bboxes): return areas -@BBOX_ASSIGNERS.register_module +@BBOX_ASSIGNERS.register_module() class CenterRegionAssigner(BaseAssigner): """Assign pixels at the center region of a bbox as positive. diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py index 330a3170..16ee7f77 100644 --- a/mmdet/core/bbox/coder/tblr_bbox_coder.py +++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py @@ -4,7 +4,7 @@ from ..builder import BBOX_CODERS from .base_bbox_coder import BaseBBoxCoder -@BBOX_CODERS.register_module +@BBOX_CODERS.register_module() class TBLRBBoxCoder(BaseBBoxCoder): """TBLR BBox coder diff --git a/mmdet/models/anchor_heads/atss_head.py b/mmdet/models/anchor_heads/atss_head.py index ba0f1f2f..8af8ce04 100644 --- a/mmdet/models/anchor_heads/atss_head.py +++ b/mmdet/models/anchor_heads/atss_head.py @@ -1,12 +1,11 @@ import torch import torch.distributed as dist import torch.nn as nn -from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler, force_fp32, images_to_levels, multi_apply, multiclass_nms, unmap) -from mmdet.ops import ConvModule, Scale from ..builder import HEADS, build_loss from .anchor_head import AnchorHead diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py index ffa0fd5c..5445aab7 100644 --- a/mmdet/models/anchor_heads/fcos_head.py +++ b/mmdet/models/anchor_heads/fcos_head.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn -from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init from mmdet.core import distance2bbox, force_fp32, multi_apply, multiclass_nms -from mmdet.ops import ConvModule, Scale from ..builder import HEADS, build_loss INF = 1e8 diff --git a/mmdet/models/anchor_heads/fovea_head.py b/mmdet/models/anchor_heads/fovea_head.py index 73a88fed..2f80cb8e 100644 --- a/mmdet/models/anchor_heads/fovea_head.py +++ b/mmdet/models/anchor_heads/fovea_head.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn -from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init from mmdet.core import multi_apply, multiclass_nms -from mmdet.ops import ConvModule, DeformConv +from mmdet.ops import DeformConv from ..builder import HEADS, build_loss INF = 1e8 diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py index f3577fb7..7da2d8ec 100644 --- a/mmdet/models/anchor_heads/fsaf_head.py +++ b/mmdet/models/anchor_heads/fsaf_head.py @@ -9,7 +9,7 @@ from ..losses.utils import weight_reduce_loss from .retina_head import RetinaHead -@HEADS.register_module +@HEADS.register_module() class FSAFHead(RetinaHead): """FSAF anchor-free head used in [1]. diff --git a/mmdet/models/anchor_heads/ga_retina_head.py b/mmdet/models/anchor_heads/ga_retina_head.py index db2639f9..3545172e 100644 --- a/mmdet/models/anchor_heads/ga_retina_head.py +++ b/mmdet/models/anchor_heads/ga_retina_head.py @@ -1,7 +1,7 @@ import torch.nn as nn -from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init -from mmdet.ops import ConvModule, MaskedConv2d +from mmdet.ops import MaskedConv2d from ..builder import HEADS from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead diff --git a/mmdet/models/anchor_heads/reppoints_head.py b/mmdet/models/anchor_heads/reppoints_head.py index 7b18f022..2a95da43 100644 --- a/mmdet/models/anchor_heads/reppoints_head.py +++ b/mmdet/models/anchor_heads/reppoints_head.py @@ -1,11 +1,11 @@ import numpy as np import torch import torch.nn as nn -from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init from mmdet.core import (PointGenerator, build_assigner, build_sampler, images_to_levels, multi_apply, multiclass_nms, unmap) -from mmdet.ops import ConvModule, DeformConv +from mmdet.ops import DeformConv from ..builder import HEADS, build_loss diff --git a/mmdet/models/anchor_heads/retina_head.py b/mmdet/models/anchor_heads/retina_head.py index 55ec4dbb..871a92d2 100644 --- a/mmdet/models/anchor_heads/retina_head.py +++ b/mmdet/models/anchor_heads/retina_head.py @@ -1,7 +1,6 @@ import torch.nn as nn -from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init -from mmdet.ops import ConvModule from ..builder import HEADS from .anchor_head import AnchorHead diff --git a/mmdet/models/anchor_heads/retina_sepbn_head.py b/mmdet/models/anchor_heads/retina_sepbn_head.py index 4e93caf1..1ec5e980 100644 --- a/mmdet/models/anchor_heads/retina_sepbn_head.py +++ b/mmdet/models/anchor_heads/retina_sepbn_head.py @@ -1,7 +1,6 @@ import torch.nn as nn -from mmcv.cnn import bias_init_with_prob, normal_init +from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init -from mmdet.ops import ConvModule from ..builder import HEADS from .anchor_head import AnchorHead diff --git a/mmdet/models/backbones/hrnet.py b/mmdet/models/backbones/hrnet.py index 06efe578..da488019 100644 --- a/mmdet/models/backbones/hrnet.py +++ b/mmdet/models/backbones/hrnet.py @@ -1,9 +1,9 @@ import torch.nn as nn -from mmcv.cnn import constant_init, kaiming_init +from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, + kaiming_init) from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm -from mmdet.ops import build_conv_layer, build_norm_layer from mmdet.utils import get_root_logger from ..builder import BACKBONES from .resnet import BasicBlock, Bottleneck diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py index ad3941ed..32cf180a 100644 --- a/mmdet/models/backbones/resnet.py +++ b/mmdet/models/backbones/resnet.py @@ -1,10 +1,11 @@ import torch.nn as nn import torch.utils.checkpoint as cp -from mmcv.cnn import constant_init, kaiming_init +from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, + kaiming_init) from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm -from mmdet.ops import build_conv_layer, build_norm_layer, build_plugin_layer +from mmdet.ops import build_plugin_layer from mmdet.utils import get_root_logger from ..builder import BACKBONES from ..utils import ResLayer diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py index a882081d..dc8c7eba 100644 --- a/mmdet/models/backbones/resnext.py +++ b/mmdet/models/backbones/resnext.py @@ -1,6 +1,7 @@ import math -from mmdet.ops import build_conv_layer, build_norm_layer +from mmcv.cnn import build_conv_layer, build_norm_layer + from ..builder import BACKBONES from ..utils import ResLayer from .resnet import Bottleneck as _Bottleneck diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py index 244092ab..d57be855 100644 --- a/mmdet/models/bbox_heads/convfc_bbox_head.py +++ b/mmdet/models/bbox_heads/convfc_bbox_head.py @@ -1,6 +1,6 @@ import torch.nn as nn +from mmcv.cnn import ConvModule -from mmdet.ops import ConvModule from ..builder import HEADS from .bbox_head import BBoxHead diff --git a/mmdet/models/bbox_heads/double_bbox_head.py b/mmdet/models/bbox_heads/double_bbox_head.py index 3289922d..9a3a1165 100644 --- a/mmdet/models/bbox_heads/double_bbox_head.py +++ b/mmdet/models/bbox_heads/double_bbox_head.py @@ -1,7 +1,6 @@ import torch.nn as nn -from mmcv.cnn.weight_init import normal_init, xavier_init +from mmcv.cnn import ConvModule, normal_init, xavier_init -from mmdet.ops import ConvModule from ..backbones.resnet import Bottleneck from ..builder import HEADS from .bbox_head import BBoxHead diff --git a/mmdet/models/detectors/fsaf.py b/mmdet/models/detectors/fsaf.py index 40b18521..b315794e 100644 --- a/mmdet/models/detectors/fsaf.py +++ b/mmdet/models/detectors/fsaf.py @@ -2,7 +2,7 @@ from ..builder import DETECTORS from .single_stage import SingleStageDetector -@DETECTORS.register_module +@DETECTORS.register_module() class FSAF(SingleStageDetector): def __init__(self, diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py index d8bb7cf1..0e808490 100644 --- a/mmdet/models/mask_heads/fcn_mask_head.py +++ b/mmdet/models/mask_heads/fcn_mask_head.py @@ -3,10 +3,11 @@ import pycocotools.mask as mask_util import torch import torch.nn as nn import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_upsample_layer from torch.nn.modules.utils import _pair from mmdet.core import auto_fp16, force_fp32, mask_target -from mmdet.ops import Conv2d, ConvModule, build_upsample_layer +from mmdet.ops import Conv2d from mmdet.ops.carafe import CARAFEPack from ..builder import HEADS, build_loss diff --git a/mmdet/models/mask_heads/fused_semantic_head.py b/mmdet/models/mask_heads/fused_semantic_head.py index dd671ad4..99e775a3 100644 --- a/mmdet/models/mask_heads/fused_semantic_head.py +++ b/mmdet/models/mask_heads/fused_semantic_head.py @@ -1,9 +1,8 @@ import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import kaiming_init +from mmcv.cnn import ConvModule, kaiming_init from mmdet.core import auto_fp16, force_fp32 -from mmdet.ops import ConvModule from ..builder import HEADS diff --git a/mmdet/models/mask_heads/grid_head.py b/mmdet/models/mask_heads/grid_head.py index 8678f105..34855597 100644 --- a/mmdet/models/mask_heads/grid_head.py +++ b/mmdet/models/mask_heads/grid_head.py @@ -2,9 +2,8 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import kaiming_init, normal_init +from mmcv.cnn import ConvModule, kaiming_init, normal_init -from mmdet.ops import ConvModule from ..builder import HEADS, build_loss diff --git a/mmdet/models/mask_heads/htc_mask_head.py b/mmdet/models/mask_heads/htc_mask_head.py index 91472f3b..a8b2812b 100644 --- a/mmdet/models/mask_heads/htc_mask_head.py +++ b/mmdet/models/mask_heads/htc_mask_head.py @@ -1,4 +1,5 @@ -from mmdet.ops import ConvModule +from mmcv.cnn import ConvModule + from ..builder import HEADS from .fcn_mask_head import FCNMaskHead diff --git a/mmdet/models/necks/bfp.py b/mmdet/models/necks/bfp.py index b128b76f..ed4a2c3e 100644 --- a/mmdet/models/necks/bfp.py +++ b/mmdet/models/necks/bfp.py @@ -1,8 +1,8 @@ import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import xavier_init +from mmcv.cnn import ConvModule, xavier_init -from mmdet.ops import ConvModule, NonLocal2D +from mmdet.ops import NonLocal2D from ..builder import NECKS diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py index f5c45d0b..da508652 100644 --- a/mmdet/models/necks/fpn.py +++ b/mmdet/models/necks/fpn.py @@ -1,9 +1,8 @@ import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import xavier_init +from mmcv.cnn import ConvModule, xavier_init from mmdet.core import auto_fp16 -from mmdet.ops import ConvModule from ..builder import NECKS diff --git a/mmdet/models/necks/fpn_carafe.py b/mmdet/models/necks/fpn_carafe.py index 488c4a82..3180bd1d 100644 --- a/mmdet/models/necks/fpn_carafe.py +++ b/mmdet/models/necks/fpn_carafe.py @@ -1,7 +1,6 @@ import torch.nn as nn -from mmcv.cnn import xavier_init +from mmcv.cnn import ConvModule, build_upsample_layer, xavier_init -from mmdet.ops import ConvModule, build_upsample_layer from mmdet.ops.carafe import CARAFEPack from ..builder import NECKS diff --git a/mmdet/models/necks/hrfpn.py b/mmdet/models/necks/hrfpn.py index c9af7c72..efc60076 100644 --- a/mmdet/models/necks/hrfpn.py +++ b/mmdet/models/necks/hrfpn.py @@ -1,10 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn.weight_init import caffe2_xavier_init +from mmcv.cnn import ConvModule, caffe2_xavier_init from torch.utils.checkpoint import checkpoint -from mmdet.ops import ConvModule from ..builder import NECKS diff --git a/mmdet/models/necks/nas_fpn.py b/mmdet/models/necks/nas_fpn.py index b4cf12d7..16bf01e9 100644 --- a/mmdet/models/necks/nas_fpn.py +++ b/mmdet/models/necks/nas_fpn.py @@ -1,8 +1,7 @@ import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import caffe2_xavier_init +from mmcv.cnn import ConvModule, caffe2_xavier_init -from mmdet.ops import ConvModule from ..builder import NECKS diff --git a/mmdet/models/necks/pafpn.py b/mmdet/models/necks/pafpn.py index 01164553..de14e365 100644 --- a/mmdet/models/necks/pafpn.py +++ b/mmdet/models/necks/pafpn.py @@ -1,8 +1,8 @@ import torch.nn as nn import torch.nn.functional as F +from mmcv.cnn import ConvModule from mmdet.core import auto_fp16 -from mmdet.ops import ConvModule from ..builder import NECKS from .fpn import FPN diff --git a/mmdet/models/utils/res_layer.py b/mmdet/models/utils/res_layer.py index 77c0419f..53d61540 100644 --- a/mmdet/models/utils/res_layer.py +++ b/mmdet/models/utils/res_layer.py @@ -1,7 +1,6 @@ +from mmcv.cnn import build_conv_layer, build_norm_layer from torch import nn as nn -from mmdet.ops import build_conv_layer, build_norm_layer - class ResLayer(nn.Sequential): """ResLayer to build ResNet style backbone. diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py index 7a248b3b..32983ccd 100644 --- a/mmdet/ops/__init__.py +++ b/mmdet/ops/__init__.py @@ -1,6 +1,4 @@ from .context_block import ContextBlock -from .conv import build_conv_layer -from .conv_module import ConvModule from .conv_ws import ConvWS2d, conv_ws_2d from .dcn import (DeformConv, DeformConvPack, DeformRoIPooling, DeformRoIPoolingPack, ModulatedDeformConv, @@ -10,13 +8,10 @@ from .generalized_attention import GeneralizedAttention from .masked_conv import MaskedConv2d from .nms import batched_nms, nms, soft_nms from .non_local import NonLocal2D -from .norm import build_norm_layer from .plugin import build_plugin_layer from .roi_align import RoIAlign, roi_align from .roi_pool import RoIPool, roi_pool -from .scale import Scale from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss -from .upsample import build_upsample_layer from .utils import get_compiler_version, get_compiling_cuda_version from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d @@ -27,8 +22,7 @@ __all__ = [ 'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv', 'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss', 'MaskedConv2d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2D', - 'get_compiler_version', 'get_compiling_cuda_version', 'build_conv_layer', - 'ConvModule', 'ConvWS2d', 'conv_ws_2d', 'build_norm_layer', 'Scale', - 'build_upsample_layer', 'build_plugin_layer', 'batched_nms', 'Conv2d', + 'get_compiler_version', 'get_compiling_cuda_version', 'ConvWS2d', + 'conv_ws_2d', 'build_plugin_layer', 'batched_nms', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'Linear' ] diff --git a/mmdet/ops/activation.py b/mmdet/ops/activation.py deleted file mode 100644 index 53b72b2a..00000000 --- a/mmdet/ops/activation.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch.nn as nn - -activation_cfg = { - # layer_abbreviation: module - 'ReLU': nn.ReLU, - 'LeakyReLU': nn.LeakyReLU, - 'PReLU': nn.PReLU, - 'RReLU': nn.RReLU, - 'ReLU6': nn.ReLU6, - 'SELU': nn.SELU, - 'CELU': nn.CELU -} - - -def build_activation_layer(cfg): - """ Build activation layer - - Args: - cfg (dict): cfg should contain: - type (str): Identify activation layer type. - layer args: args needed to instantiate a activation layer. - - Returns: - layer (nn.Module): Created activation layer - """ - assert isinstance(cfg, dict) and 'type' in cfg - cfg_ = cfg.copy() - - layer_type = cfg_.pop('type') - if layer_type not in activation_cfg: - raise KeyError(f'Unrecognized activation type {layer_type}') - else: - activation = activation_cfg[layer_type] - if activation is None: - raise NotImplementedError - - layer = activation(**cfg_) - return layer diff --git a/mmdet/ops/carafe/carafe.py b/mmdet/ops/carafe/carafe.py index d8a15c3e..73c14dfc 100644 --- a/mmdet/ops/carafe/carafe.py +++ b/mmdet/ops/carafe/carafe.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import normal_init, xavier_init +from mmcv.cnn import UPSAMPLE_LAYERS, normal_init, xavier_init from torch.autograd import Function from torch.nn.modules.module import Module @@ -157,6 +157,7 @@ class CARAFE(Module): self.group_size, self.scale_factor) +@UPSAMPLE_LAYERS.register_module(name='carafe') class CARAFEPack(nn.Module): """ A unified package of CARAFE upsampler that contains: 1) channel compressor 2) content encoder 3) CARAFE op diff --git a/mmdet/ops/conv.py b/mmdet/ops/conv.py deleted file mode 100644 index 8316001d..00000000 --- a/mmdet/ops/conv.py +++ /dev/null @@ -1,39 +0,0 @@ -from .conv_ws import ConvWS2d -from .dcn import DeformConvPack, ModulatedDeformConvPack -from .wrappers import Conv2d - -conv_cfg = { - 'Conv': Conv2d, - 'ConvWS': ConvWS2d, - 'DCN': DeformConvPack, - 'DCNv2': ModulatedDeformConvPack, - # TODO: octave conv -} - - -def build_conv_layer(cfg, *args, **kwargs): - """ Build convolution layer - - Args: - cfg (None or dict): cfg should contain: - type (str): identify conv layer type. - layer args: args needed to instantiate a conv layer. - - Returns: - layer (nn.Module): created conv layer - """ - if cfg is None: - cfg_ = dict(type='Conv') - else: - assert isinstance(cfg, dict) and 'type' in cfg - cfg_ = cfg.copy() - - layer_type = cfg_.pop('type') - if layer_type not in conv_cfg: - raise KeyError(f'Unrecognized norm type {layer_type}') - else: - conv_layer = conv_cfg[layer_type] - - layer = conv_layer(*args, **kwargs, **cfg_) - - return layer diff --git a/mmdet/ops/conv_module.py b/mmdet/ops/conv_module.py deleted file mode 100644 index 975e00f6..00000000 --- a/mmdet/ops/conv_module.py +++ /dev/null @@ -1,132 +0,0 @@ -import warnings - -import torch.nn as nn -from mmcv.cnn import constant_init, kaiming_init - -from .activation import build_activation_layer -from .conv import build_conv_layer -from .norm import build_norm_layer - - -class ConvModule(nn.Module): - """A conv block that contains conv/norm/activation layers. - - Args: - in_channels (int): Same as nn.Conv2d. - out_channels (int): Same as nn.Conv2d. - kernel_size (int or tuple[int]): Same as nn.Conv2d. - stride (int or tuple[int]): Same as nn.Conv2d. - padding (int or tuple[int]): Same as nn.Conv2d. - dilation (int or tuple[int]): Same as nn.Conv2d. - groups (int): Same as nn.Conv2d. - bias (bool or str): If specified as `auto`, it will be decided by the - norm_cfg. Bias will be set as True if norm_cfg is None, otherwise - False. - conv_cfg (dict): Config dict for convolution layer. - norm_cfg (dict): Config dict for normalization layer. - act_cfg (dict): Config dict for activation layer, "relu" by default. - inplace (bool): Whether to use inplace mode for activation. - order (tuple[str]): The order of conv/norm/activation layers. It is a - sequence of "conv", "norm" and "act". Examples are - ("conv", "norm", "act") and ("act", "conv", "norm"). - """ - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias='auto', - conv_cfg=None, - norm_cfg=None, - act_cfg=dict(type='ReLU'), - inplace=True, - order=('conv', 'norm', 'act')): - super(ConvModule, self).__init__() - assert conv_cfg is None or isinstance(conv_cfg, dict) - assert norm_cfg is None or isinstance(norm_cfg, dict) - assert act_cfg is None or isinstance(act_cfg, dict) - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg - self.inplace = inplace - self.order = order - assert isinstance(self.order, tuple) and len(self.order) == 3 - assert set(order) == set(['conv', 'norm', 'act']) - - self.with_norm = norm_cfg is not None - self.with_activation = act_cfg is not None - # if the conv layer is before a norm layer, bias is unnecessary. - if bias == 'auto': - bias = False if self.with_norm else True - self.with_bias = bias - - if self.with_norm and self.with_bias: - warnings.warn('ConvModule has norm and bias at the same time') - - # build convolution layer - self.conv = build_conv_layer( - conv_cfg, - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias) - # export the attributes of self.conv to a higher level for convenience - self.in_channels = self.conv.in_channels - self.out_channels = self.conv.out_channels - self.kernel_size = self.conv.kernel_size - self.stride = self.conv.stride - self.padding = self.conv.padding - self.dilation = self.conv.dilation - self.transposed = self.conv.transposed - self.output_padding = self.conv.output_padding - self.groups = self.conv.groups - - # build normalization layers - if self.with_norm: - # norm layer is after conv layer - if order.index('norm') > order.index('conv'): - norm_channels = out_channels - else: - norm_channels = in_channels - self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) - self.add_module(self.norm_name, norm) - - # build activation layer - if self.with_activation: - act_cfg_ = act_cfg.copy() - act_cfg_.setdefault('inplace', inplace) - self.activate = build_activation_layer(act_cfg_) - - # Use msra init by default - self.init_weights() - - @property - def norm(self): - return getattr(self, self.norm_name) - - def init_weights(self): - if self.with_activation and self.act_cfg['type'] == 'LeakyReLU': - nonlinearity = 'leaky_relu' - else: - nonlinearity = 'relu' - kaiming_init(self.conv, nonlinearity=nonlinearity) - if self.with_norm: - constant_init(self.norm, 1, bias=0) - - def forward(self, x, activate=True, norm=True): - for layer in self.order: - if layer == 'conv': - x = self.conv(x) - elif layer == 'norm' and norm and self.with_norm: - x = self.norm(x) - elif layer == 'act' and activate and self.with_activation: - x = self.activate(x) - return x diff --git a/mmdet/ops/conv_ws.py b/mmdet/ops/conv_ws.py index 5ccd735f..7704683f 100644 --- a/mmdet/ops/conv_ws.py +++ b/mmdet/ops/conv_ws.py @@ -1,5 +1,6 @@ import torch.nn as nn import torch.nn.functional as F +from mmcv.cnn import CONV_LAYERS def conv_ws_2d(input, @@ -18,6 +19,7 @@ def conv_ws_2d(input, return F.conv2d(input, weight, bias, stride, padding, dilation, groups) +@CONV_LAYERS.register_module('ConvWS') class ConvWS2d(nn.Conv2d): def __init__(self, diff --git a/mmdet/ops/dcn/deform_conv.py b/mmdet/ops/dcn/deform_conv.py index 42245d5c..7d08ceca 100644 --- a/mmdet/ops/dcn/deform_conv.py +++ b/mmdet/ops/dcn/deform_conv.py @@ -3,6 +3,7 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +from mmcv.cnn import CONV_LAYERS from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair, _single @@ -252,6 +253,7 @@ class DeformConv(nn.Module): return out +@CONV_LAYERS.register_module('DCN') class DeformConvPack(DeformConv): """A Deformable Conv Encapsulation that acts as normal Conv layers. @@ -371,6 +373,7 @@ class ModulatedDeformConv(nn.Module): self.groups, self.deformable_groups) +@CONV_LAYERS.register_module('DCNv2') class ModulatedDeformConvPack(ModulatedDeformConv): """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. diff --git a/mmdet/ops/non_local.py b/mmdet/ops/non_local.py index 3ac4bf9f..3630eb2f 100644 --- a/mmdet/ops/non_local.py +++ b/mmdet/ops/non_local.py @@ -1,8 +1,6 @@ import torch import torch.nn as nn -from mmcv.cnn import constant_init, normal_init - -from .conv_module import ConvModule +from mmcv.cnn import ConvModule, constant_init, normal_init class NonLocal2D(nn.Module): diff --git a/mmdet/ops/norm.py b/mmdet/ops/norm.py deleted file mode 100644 index 99fc9311..00000000 --- a/mmdet/ops/norm.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch.nn as nn - -norm_cfg = { - # format: layer_type: (abbreviation, module) - 'BN': ('bn', nn.BatchNorm2d), - 'SyncBN': ('bn', nn.SyncBatchNorm), - 'GN': ('gn', nn.GroupNorm), - # and potentially 'SN' -} - - -def build_norm_layer(cfg, num_features, postfix=''): - """ Build normalization layer - - Args: - cfg (dict): cfg should contain: - type (str): identify norm layer type. - layer args: args needed to instantiate a norm layer. - requires_grad (bool): [optional] whether stop gradient updates - num_features (int): number of channels from input. - postfix (int, str): appended into norm abbreviation to - create named layer. - - Returns: - name (str): abbreviation + postfix - layer (nn.Module): created norm layer - """ - assert isinstance(cfg, dict) and 'type' in cfg - cfg_ = cfg.copy() - - layer_type = cfg_.pop('type') - if layer_type not in norm_cfg: - raise KeyError(f'Unrecognized norm type {layer_type}') - else: - abbr, norm_layer = norm_cfg[layer_type] - if norm_layer is None: - raise NotImplementedError - - assert isinstance(postfix, (int, str)) - name = abbr + str(postfix) - - requires_grad = cfg_.pop('requires_grad', True) - cfg_.setdefault('eps', 1e-5) - if layer_type != 'GN': - layer = norm_layer(num_features, **cfg_) - if layer_type == 'SyncBN': - layer._specify_ddp_gpu_num(1) - else: - assert 'num_groups' in cfg_ - layer = norm_layer(num_channels=num_features, **cfg_) - - for param in layer.parameters(): - param.requires_grad = requires_grad - - return name, layer diff --git a/mmdet/ops/scale.py b/mmdet/ops/scale.py deleted file mode 100644 index 2461af8a..00000000 --- a/mmdet/ops/scale.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -import torch.nn as nn - - -class Scale(nn.Module): - """ - A learnable scale parameter - """ - - def __init__(self, scale=1.0): - super(Scale, self).__init__() - self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) - - def forward(self, x): - return x * self.scale diff --git a/mmdet/ops/upsample.py b/mmdet/ops/upsample.py deleted file mode 100644 index 2e405a03..00000000 --- a/mmdet/ops/upsample.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F -from mmcv.cnn import xavier_init - -from .carafe import CARAFEPack -from .wrappers import ConvTranspose2d - - -class PixelShufflePack(nn.Module): - """ Pixel Shuffle upsample layer - - Args: - in_channels (int): Number of input channels - out_channels (int): Number of output channels - scale_factor (int): Upsample ratio - upsample_kernel (int): Kernel size of Conv layer to expand the channels - - Returns: - upsampled feature map - """ - - def __init__(self, in_channels, out_channels, scale_factor, - upsample_kernel): - super(PixelShufflePack, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.scale_factor = scale_factor - self.upsample_kernel = upsample_kernel - self.upsample_conv = nn.Conv2d( - self.in_channels, - self.out_channels * scale_factor * scale_factor, - self.upsample_kernel, - padding=(self.upsample_kernel - 1) // 2) - self.init_weights() - - def init_weights(self): - xavier_init(self.upsample_conv, distribution='uniform') - - def forward(self, x): - x = self.upsample_conv(x) - x = F.pixel_shuffle(x, self.scale_factor) - return x - - -upsample_cfg = { - # layer_abbreviation: module - 'nearest': nn.Upsample, - 'bilinear': nn.Upsample, - 'deconv': ConvTranspose2d, - 'pixel_shuffle': PixelShufflePack, - 'carafe': CARAFEPack -} - - -def build_upsample_layer(cfg): - """ Build upsample layer - - Args: - cfg (dict): cfg should contain: - type (str): Identify upsample layer type. - upsample ratio (int): Upsample ratio - layer args: args needed to instantiate a upsample layer. - - Returns: - layer (nn.Module): Created upsample layer - """ - assert isinstance(cfg, dict) and 'type' in cfg - cfg_ = cfg.copy() - - layer_type = cfg_.pop('type') - if layer_type not in upsample_cfg: - raise KeyError(f'Unrecognized upsample type {layer_type}') - else: - upsample = upsample_cfg[layer_type] - if upsample is None: - raise NotImplementedError - - layer = upsample(**cfg_) - return layer diff --git a/mmdet/ops/wrappers.py b/mmdet/ops/wrappers.py index a5f9fb9f..8177fee8 100644 --- a/mmdet/ops/wrappers.py +++ b/mmdet/ops/wrappers.py @@ -9,6 +9,7 @@ import math import torch import torch.nn as nn +from mmcv.cnn import CONV_LAYERS from torch.nn.modules.utils import _pair @@ -25,6 +26,7 @@ class NewEmptyTensorOp(torch.autograd.Function): return NewEmptyTensorOp.apply(grad, shape), None +@CONV_LAYERS.register_module('Conv', force=True) class Conv2d(nn.Conv2d): def forward(self, x): diff --git a/requirements/runtime.txt b/requirements/runtime.txt index f6d104e3..13194bca 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,5 +1,5 @@ matplotlib -mmcv>=0.4.4 +mmcv>=0.5.0 numpy # need older pillow until torchvision is fixed Pillow<=6.2.2 diff --git a/setup.py b/setup.py index 390d08a1..29b5e1e2 100755 --- a/setup.py +++ b/setup.py @@ -15,8 +15,8 @@ def readme(): return content -MAJOR = 1 -MINOR = 1 +MAJOR = 2 +MINOR = 0 PATCH = 0 SUFFIX = '' if PATCH != '': -- GitLab