diff --git a/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py b/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1c22dc1eda62c25107fa85856a987d8903ffc8e
--- /dev/null
+++ b/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py
@@ -0,0 +1,12 @@
+_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
+
+model = dict(
+    backbone=dict(
+        type='CBResNet',
+        cb_del_stages=1,
+        cb_inplanes=[64, 256, 512, 1024, 2048],
+    ),
+    neck=dict(
+        type='CBFPN',
+    )
+)
\ No newline at end of file
diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py
index 5d35e48ab7b7c3118f2e58def8b2cfe0181f0f39..7ad56f448c308f5389edf2df20462ff4fdd4e385 100644
--- a/mmdet/models/backbones/__init__.py
+++ b/mmdet/models/backbones/__init__.py
@@ -11,9 +11,10 @@ from .resnext import ResNeXt
 from .ssd_vgg import SSDVGG
 from .trident_resnet import TridentResNet
 from .swin_transformer import SwinTransformer
+from .cbnet import CBResNet, CBRes2Net, CBSwinTransformer
 
 __all__ = [
     'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
     'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
-    'ResNeSt', 'TridentResNet', 'SwinTransformer'
+    'ResNeSt', 'TridentResNet', 'SwinTransformer', 'CBResNet', 'CBRes2Net', 'CBSwinTransformer'
 ]
diff --git a/mmdet/models/backbones/cbnet.py b/mmdet/models/backbones/cbnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3a6c7cd390ec7dd620318ac283357ca66f4652c
--- /dev/null
+++ b/mmdet/models/backbones/cbnet.py
@@ -0,0 +1,381 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmcv.cnn import constant_init
+from mmdet.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import ResNet, build_norm_layer, _BatchNorm
+from .res2net import Res2Net
+from .swin_transformer import SwinTransformer
+
+from mmcv.runner import BaseModule
+'''
+For CNN
+'''
+class _CBSubnet(BaseModule):
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            if self.deep_stem and hasattr(self, 'stem'):
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            elif hasattr(self, 'conv1'):
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            if not hasattr(self, f'layer{i}'):
+                continue
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def del_layers(self, del_stages):
+        self.del_stages = del_stages
+        if self.del_stages >= 0:
+            if self.deep_stem:
+                del self.stem
+            else:
+                del self.conv1
+
+        for i in range(1, self.del_stages + 1):
+            delattr(self, f'layer{i}')
+
+    def forward(self, x, cb_feats=None, pre_outs=None):
+        """Forward function."""
+        spatial_info = []
+        outs = []
+
+        if self.deep_stem and hasattr(self, 'stem'):
+            x = self.stem(x)
+            x = self.maxpool(x)
+        elif hasattr(self, 'conv1'):
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
+            x = self.maxpool(x)
+        else:
+            x = pre_outs[0]
+        outs.append(x)
+
+        for i, layer_name in enumerate(self.res_layers):
+            if hasattr(self, layer_name):
+                res_layer = getattr(self, layer_name)
+                spatial_info.append(x.shape[2:])
+                if cb_feats is not None:
+                    x = x + cb_feats[i]
+                x = res_layer(x)
+            else:
+                x = pre_outs[i + 1]
+            outs.append(x)
+        return tuple(outs), spatial_info
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping layers frozen."""
+        super().train(mode)
+        self._freeze_stages()
+
+class _ResNet(_CBSubnet, ResNet):
+    def __init__(self, **kwargs):
+        _CBSubnet.__init__(self)
+        ResNet.__init__(self, **kwargs)
+
+class _Res2Net(_CBSubnet, Res2Net):
+    def __init__(self, **kwargs):
+        _CBSubnet.__init__(self)
+        Res2Net.__init__(self, **kwargs)
+
+class _CBNet(BaseModule):
+    def _freeze_stages(self):
+        for m in self.cb_modules:
+            m._freeze_stages()
+
+    def init_cb_weights(self):
+        raise NotImplementedError
+
+    def init_weights(self):
+        self.init_cb_weights()
+        for m in self.cb_modules:
+            m.init_weights()
+
+    def _get_cb_feats(self, feats, spatial_info):
+        raise NotImplementedError
+
+    def forward(self, x):
+        outs_list = []
+        for i, module in enumerate(self.cb_modules):
+            if i == 0:
+                pre_outs, spatial_info = module(x)
+            else:
+                pre_outs, spatial_info = module(x, cb_feats, pre_outs)
+
+            outs = [pre_outs[idx + 1] for idx in self.out_indices]
+            outs_list.append(tuple(outs))
+
+            if i < len(self.cb_modules) - 1:
+                cb_feats = self._get_cb_feats(pre_outs, spatial_info)
+        return tuple(outs_list)
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping layers frozen."""
+        super().train(mode)
+        for m in self.cb_modules:
+            m.train(mode=mode)
+        self._freeze_stages()
+        for m in self.cb_linears.modules():
+            # trick: eval() only has an effect on BatchNorm here
+            if isinstance(m, _BatchNorm):
+                m.eval()
+
+class _CBResNet(_CBNet):
+    def __init__(self, net, cb_inplanes, cb_zero_init=True, cb_del_stages=0, **kwargs):
+        super(_CBResNet, self).__init__()
+        self.cb_zero_init = cb_zero_init
+        self.cb_del_stages = cb_del_stages
+
+        self.cb_modules = nn.ModuleList()
+        for cb_idx in range(2):
+            cb_module = net(**kwargs)
+            if cb_idx > 0:
+                cb_module.del_layers(self.cb_del_stages)
+            self.cb_modules.append(cb_module)
+        self.out_indices = self.cb_modules[0].out_indices
+
+        self.cb_linears = nn.ModuleList()
+        self.num_layers = len(self.cb_modules[0].stage_blocks)
+        norm_cfg = self.cb_modules[0].norm_cfg
+        for i in range(self.num_layers):
+            linears = nn.ModuleList()
+            if i >= self.cb_del_stages:
+                jrange = 4 - i
+                for j in range(jrange):
+                    linears.append(
+                        nn.Sequential(
+                            nn.Conv2d(cb_inplanes[i + j + 1], cb_inplanes[i], 1, bias=False),
+                            build_norm_layer(norm_cfg, cb_inplanes[i])[1]
+                        )
+                    )
+
+            self.cb_linears.append(linears)
+
+    def init_cb_weights(self):
+        if self.cb_zero_init:
+            for ls in self.cb_linears:
+                for m in ls:
+                    if isinstance(m, nn.Sequential):
+                        constant_init(m[-1], 0)
+                    else:
+                        constant_init(m, 0)
+
+    def _get_cb_feats(self, feats, spatial_info):
+        cb_feats = []
+        for i in range(self.num_layers):
+            if i >= self.cb_del_stages:
+                h, w = spatial_info[i]
+                feeds = []
+                jrange = 4 - i
+                for j in range(jrange):
+                    tmp = self.cb_linears[i][j](feats[j + i + 1])
+                    tmp = F.interpolate(tmp, size=(h, w), mode='nearest')
+                    feeds.append(tmp)
+                feed = torch.sum(torch.stack(feeds, dim=-1), dim=-1)
+            else:
+                feed = 0
+            cb_feats.append(feed)
+
+        return cb_feats
+
+
+@BACKBONES.register_module()
+class CBResNet(_CBResNet):
+    def __init__(self, **kwargs):
+        super().__init__(net=_ResNet, **kwargs)
+
+@BACKBONES.register_module()
+class CBRes2Net(_CBResNet):
+    def __init__(self, **kwargs):
+        super().__init__(net=_Res2Net, **kwargs)
+
+
+'''
+For Swin Transformer
+'''
+class _SwinTransformer(SwinTransformer):
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0 and hasattr(self, 'patch_embed'):
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                if m is None:
+                    continue
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def del_layers(self, del_stages):
+        self.del_stages = del_stages
+        if self.del_stages >= 0:
+            del self.patch_embed
+
+        if self.del_stages >= 1 and self.ape:
+            del self.absolute_pos_embed
+
+        for i in range(0, self.del_stages - 1):
+            self.layers[i] = None
+
+    def forward(self, x, cb_feats=None, pre_tmps=None):
+        """Forward function."""
+        outs = []
+        tmps = []
+        if hasattr(self, 'patch_embed'):
+            x = self.patch_embed(x)
+
+            Wh, Ww = x.size(2), x.size(3)
+            if self.ape:
+                # interpolate the position embedding to the corresponding size
+                absolute_pos_embed = F.interpolate(
+                    self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+                x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+            else:
+                x = x.flatten(2).transpose(1, 2)
+            x = self.pos_drop(x)
+
+            tmps.append((x, Wh, Ww))
+        else:
+            x, Wh, Ww = pre_tmps[0]
+
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            if layer is None:
+                x_out, H, W, x, Wh, Ww = pre_tmps[i + 1]
+            else:
+                if cb_feats is not None:
+                    x = x + cb_feats[i]
+                x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+            tmps.append((x_out, H, W, x, Wh, Ww))
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                x_out = norm_layer(x_out)
+
+                out = x_out.view(-1, H, W,
+                                 self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+        return tuple(outs), tmps
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping layers frozen."""
+        super(_SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+@BACKBONES.register_module()
+class CBSwinTransformer(BaseModule):
+    def __init__(self, embed_dim=96, cb_zero_init=True, cb_del_stages=1, **kwargs):
+        super(CBSwinTransformer, self).__init__()
+        self.cb_zero_init = cb_zero_init
+        self.cb_del_stages = cb_del_stages
+        self.cb_modules = nn.ModuleList()
+        for cb_idx in range(2):
+            cb_module = _SwinTransformer(embed_dim=embed_dim, **kwargs)
+            if cb_idx > 0:
+                cb_module.del_layers(cb_del_stages)
+            self.cb_modules.append(cb_module)
+
+        self.num_layers = self.cb_modules[0].num_layers
+
+        cb_inplanes = [embed_dim * 2 ** i for i in range(self.num_layers)]
+
+        self.cb_linears = nn.ModuleList()
+        for i in range(self.num_layers):
+            linears = nn.ModuleList()
+            if i >= self.cb_del_stages - 1:
+                jrange = 4 - i
+                for j in range(jrange):
+                    if cb_inplanes[i + j] != cb_inplanes[i]:
+                        layer = nn.Conv2d(cb_inplanes[i + j], cb_inplanes[i], 1)
+                    else:
+                        layer = nn.Identity()
+                    linears.append(layer)
+            self.cb_linears.append(linears)
+
+    def _freeze_stages(self):
+        for m in self.cb_modules:
+            m._freeze_stages()
+
+    def init_weights(self):
+        """Initialize the weights in backbone.
+
+        Zero-init the cb_linears (when ``cb_zero_init`` is set) so the
+        composite connections start as identity, then init each sub-backbone.
+ """ + # constant_init(self.cb_linears, 0) + if self.cb_zero_init: + for ls in self.cb_linears: + for m in ls: + constant_init(m, 0) + + for m in self.cb_modules: + m.init_weights() + + def spatial_interpolate(self, x, H, W): + B, C = x.shape[:2] + if H != x.shape[2] or W != x.shape[3]: + # B, C, size[0], size[1] + x = F.interpolate(x, size=(H, W), mode='nearest') + x = x.view(B, C, -1).permute(0, 2, 1).contiguous() # B, T, C + return x + + def _get_cb_feats(self, feats, tmps): + cb_feats = [] + Wh, Ww = tmps[0][-2:] + for i in range(self.num_layers): + feed = 0 + if i >= self.cb_del_stages-1: + jrange = 4 - i + for j in range(jrange): + tmp = self.cb_linears[i][j](feats[j + i]) + tmp = self.spatial_interpolate(tmp, Wh, Ww) + feed += tmp + cb_feats.append(feed) + Wh, Ww = tmps[i+1][-2:] + + return cb_feats + + def forward(self, x): + outs = [] + for i, module in enumerate(self.cb_modules): + if i == 0: + feats, tmps = module(x) + else: + feats, tmps = module(x, cb_feats, tmps) + + outs.append(feats) + + if i < len(self.cb_modules)-1: + cb_feats = self._get_cb_feats(outs[-1], tmps) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(CBSwinTransformer, self).train(mode) + for m in self.cb_modules: + m.train(mode=mode) + self._freeze_stages() + for m in self.cb_linears.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() \ No newline at end of file diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py index 02f833a8a0f538a8c06fef622d1cadc1a1b66ea2..49d1d8f79aa7ab6435fc1b8f2cf81c6452f7c763 100644 --- a/mmdet/models/necks/__init__.py +++ b/mmdet/models/necks/__init__.py @@ -9,8 +9,9 @@ from .nasfcos_fpn import NASFCOS_FPN from .pafpn import PAFPN from .rfp import RFP from .yolo_neck import YOLOV3Neck +from .cbnet_fpn import CBFPN __all__ = [ 'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', - 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG' + 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'CBFPN' ] diff --git a/mmdet/models/necks/cbnet_fpn.py b/mmdet/models/necks/cbnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391