Commit 8a0826db, authored by suilin0432

,

Parent 9f5f790d
......@@ -633,6 +633,7 @@ _C.MODEL.VGG = CN()
_C.MODEL.VGG.DEPTH = 16
_C.MODEL.VGG.OUT_FEATURES = ["plain5"]
_C.MODEL.VGG.CONV5_DILATION = 1
_C.MODEL.VGG.FPN = False
# VGG box head settings
_C.MODEL.ROI_BOX_HEAD.DAN_DIM = [4096, 4096]
\ No newline at end of file
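For orientation, a minimal sketch of how the new options could be set from Python. The backbone name refers to the build_vgg_fpn_backbone builder registered further down in this commit; the OUT_FEATURES/IN_FEATURES lists are placeholder choices, not values taken from this commit.

# Sketch only; assumes this repository's config defaults (which define the
# MODEL.VGG node) are the ones returned by get_cfg().
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = "build_vgg_fpn_backbone"
cfg.MODEL.VGG.DEPTH = 16
cfg.MODEL.VGG.FPN = True  # new flag: pool after plain5 so FPN gets a stride-32 level
cfg.MODEL.VGG.OUT_FEATURES = ["plain2", "plain3", "plain4", "plain5"]  # placeholder
cfg.MODEL.FPN.IN_FEATURES = ["plain2", "plain3", "plain4", "plain5"]   # placeholder
cfg.MODEL.ROI_BOX_HEAD.DAN_DIM = [4096, 4096]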
......@@ -12,6 +12,7 @@ from .resnet import (
BottleneckBlock,
)
from .vgg16 import build_vgg_backbone
from .vgg_torch import vgg16_bn, build_vgg16
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# TODO can expose more resnet blocks after careful consideration
......@@ -10,6 +10,7 @@ from detectron2.layers import Conv2d, ShapeSpec, get_norm
from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .resnet import build_resnet_backbone
from .vgg16 import build_vgg_backbone
__all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN"]
......@@ -56,7 +57,8 @@ class FPN(Backbone):
input_shapes = bottom_up.output_shape()
strides = [input_shapes[f].stride for f in in_features]
in_channels_per_feature = [input_shapes[f].channels for f in in_features]
# print(strides)
# exit()
_assert_strides_are_log2_contiguous(strides)
lateral_convs = []
output_convs = []
......@@ -230,6 +232,19 @@ def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
)
return backbone
@BACKBONE_REGISTRY.register()
def build_vgg_fpn_backbone(cfg, input_shape: ShapeSpec):
    bottom_up = build_vgg_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(bottom_up=bottom_up,
                   in_features=in_features,
                   out_channels=out_channels,
                   norm=cfg.MODEL.FPN.NORM,
                   top_block=LastLevelMaxPool(),
                   fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
                   )
    return backbone
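A rough usage sketch of the new builder through the backbone registry (assuming a cfg prepared as in the config sketch above; not part of this commit):

# Sketch only: build_backbone resolves cfg.MODEL.BACKBONE.NAME ("build_vgg_fpn_backbone")
# through BACKBONE_REGISTRY and returns the VGG trunk wrapped in FPN.
from detectron2.modeling import build_backbone

backbone = build_backbone(cfg)
print(backbone.output_shape())  # FPN levels (e.g. p2..p6), each with cfg.MODEL.FPN.OUT_CHANNELS channels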
@BACKBONE_REGISTRY.register()
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
......
......@@ -123,7 +123,7 @@ class PlainBlock(PlainBlockBase):
class VGG16(Backbone):
def __init__(self, conv5_dilation, freeze_at, num_classes=None, out_features=None):
def __init__(self, conv5_dilation, freeze_at, num_classes=None, out_features=None, fpn=False):
"""
Args:
stem (nn.Module): a stem module
......@@ -142,6 +142,7 @@ class VGG16(Backbone):
self._out_feature_channels = {}
self.stages_and_names = []
self.has_fpn = fpn
name = "plain1"
block = PlainBlock(3, 64, num_conv=2, stride=2, has_pool=True)
......@@ -194,12 +195,15 @@ class VGG16(Backbone):
block.freeze()
name = "plain5"
block = PlainBlock(512, 512, num_conv=3, stride=1, dilation=conv5_dilation, has_pool=False)
has_pool = self.has_fpn
block = PlainBlock(512, 512, num_conv=3, stride=1 if (not has_pool) else 2, dilation=conv5_dilation, has_pool=has_pool)
blocks = [block]
stage = nn.Sequential(*blocks)
self.add_module(name, stage)
self.stages_and_names.append((stage, name))
self._out_feature_strides[name] = 8 if conv5_dilation == 2 else 16
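# When the FPN pooling is enabled below, the plain5 stride is doubled
# (16 -> 32 with the default conv5_dilation == 1), so the backbone exposes a
# genuine stride-32 level for the feature pyramid.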
if has_pool:
self._out_feature_strides[name] = self._out_feature_strides[name] * 2
self._out_feature_channels[name] = blocks[-1].out_channels
if freeze_at >= 5:
for block in blocks:
......@@ -239,7 +243,8 @@ def build_vgg_backbone(cfg, input_shape):
conv5_dilation = cfg.MODEL.VGG.CONV5_DILATION
freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
out_features = cfg.MODEL.VGG.OUT_FEATURES
fpn = cfg.MODEL.VGG.FPN
# fmt: on
if depth == 16:
return VGG16(conv5_dilation, freeze_at, out_features = out_features)
return VGG16(conv5_dilation, freeze_at, out_features=out_features, fpn=fpn)
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
from torchvision.models.utils import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast
from detectron2.layers import Conv2d, FrozenBatchNorm2d, ShapeSpec
from detectron2.modeling.backbone.backbone import Backbone
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',
]
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
......@@ -22,35 +32,76 @@ model_urls = {
}
class VGG(nn.Module):
class VGG(Backbone):
def __init__(
self,
features: nn.Module,
num_classes: int = 1000,
init_weights: bool = True
init_weights: bool = True,
freeze_at: int = 2,
out_features = ["plain5"],
out_features_strides = None,
out_features_channels = None,
) -> None:
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
self.out_features_strides = out_features_strides
self.out_features_channels = out_features_channels
# print(self.out_features_strides, self.out_features_channels)
# self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
# self.classifier = nn.Sequential(
# nn.Linear(512 * 7 * 7, 4096),
# nn.ReLU(True),
# nn.Dropout(),
# nn.Linear(4096, 4096),
# nn.ReLU(True),
# nn.Dropout(),
# nn.Linear(4096, num_classes),
# )
self.out_features = out_features
if init_weights:
self._initialize_weights()
self._freeze_backbone(freeze_at)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
outputs = {}
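# More than 40 modules in self.features means the batch-norm variant was built
# (13 convs x 3 modules + 3 pools + 1 Identity = 43, vs. 30 without BN); the
# hard-coded indices below mark the last module of each plain1..plain5 stage
# in the two layouts.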
if len(self.features) > 40:
for layer in range(len(self.features)):
x = self.features[layer](x)
if layer == 6 and "plain1" in self.out_features:
outputs["plain1"] = x
elif layer == 13 and "plain2" in self.out_features:
outputs["plain2"] = x
elif layer == 23 and "plain3" in self.out_features:
outputs["plain3"] = x
elif layer == 33 and "plain4" in self.out_features:
outputs["plain4"] = x
elif layer == 42 and "plain5" in self.out_features:
outputs["plain5"] = x
else:
for layer in range(len(self.features)):
x = self.features[layer](x)
if layer == 4 and "plain1" in self.out_features:
outputs["plain1"] = x
elif layer == 9 and "plain2" in self.out_features:
outputs["plain2"] = x
elif layer == 16 and "plain3" in self.out_features:
outputs["plain3"] = x
elif layer == 23 and "plain4" in self.out_features:
outputs["plain4"] = x
elif layer == 29 and "plain5" in self.out_features:
outputs["plain5"] = x
return outputs
    def _freeze_backbone(self, freeze_at):
        if freeze_at < 0:
            return
        assert freeze_at in [1, 2, 3, 4, 5]
        layer_index = [7, 14, 24, 34, 42]
        for layer in range(layer_index[freeze_at - 1]):
            for p in self.features[layer].parameters():
                p.requires_grad = False

    def _initialize_weights(self) -> None:
        for m in self.modules():
......@@ -64,14 +115,48 @@ class VGG(nn.Module):
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self.out_features_channels[name], stride=self.out_features_strides[name]
            )
            for name in self.out_features
        }
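With the VGG16-OICR layout produced by make_layers below, the recorded shapes are plain1..plain5 with strides 2, 4, 8, 8, 8 (the fourth pool is replaced by an Identity and conv5 is dilated, so the resolution stops shrinking after plain3) and channels 64, 128, 256, 512, 512. With the default out_features, output_shape() therefore returns roughly {"plain5": ShapeSpec(channels=512, stride=8)}.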
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    layers: List[nn.Module] = []
    out_features_strides = {}
    out_features_channels = {}
    name_template = "plain{}"
    stage = 0
    strides_num = {
        1: 2,
        2: 4,
        3: 8,
        4: 8,
        5: 8
    }
    in_channels = 3
    for v in cfg:
        if v == 'M':
            stage += 1
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            out_features_strides[name_template.format(stage)] = strides_num[stage]
            out_features_channels[name_template.format(stage)] = in_channels
        elif v == "I":
            layers += [Identity()]
            out_features_strides[name_template.format(stage + 1)] = strides_num[stage + 1]
            out_features_channels[name_template.format(stage + 1)] = in_channels
        elif isinstance(v, str) and "-D" in v:
            _v = int(v.split('-')[0])
            conv2d = nn.Conv2d(in_channels, _v, kernel_size=3, padding=2, dilation=2)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(_v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = _v
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
......@@ -80,7 +165,10 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
    out_features_strides["plain5"] = strides_num[5]
    out_features_channels["plain5"] = in_channels
    return nn.Sequential(*layers), out_features_strides, out_features_channels
    # return nn.Sequential(*layers[:-1])
cfgs: Dict[str, List[Union[str, int]]] = {
......@@ -88,13 +176,25 @@ cfgs: Dict[str, List[Union[str, int]]] = {
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
"VGG16-OICR": [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'I', '512-D', '512-D', '512-D']
}
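In the "VGG16-OICR" entry, 'I' replaces the fourth max-pool with an Identity, and the '512-D' tokens are parsed by the "-D" branch of make_layers into 3x3 convolutions with dilation=2 and padding=2 (which keeps the spatial size), so the conv5 stage runs at stride 8 instead of 16. Because the Identity occupies the slot the pool would have taken, the module numbering of features matches torchvision's vgg16/vgg16_bn layouts, which presumably is what lets ImageNet checkpoints load without key remapping. A sketch of the tail of the non-BN features list:

# ..., Conv2d(512, 512, 3, padding=1), ReLU(),          # indices 21, 22: last plain4 conv
# Identity(),                                           # index 23: torchvision has MaxPool2d here
# Conv2d(512, 512, 3, padding=2, dilation=2), ReLU(),   # indices 24, 25: first dilated conv of plain5
# ...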
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm)[0], **kwargs)  # make_layers now returns (layers, strides, channels); keep only the layers
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
def _vgg_wsl(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, freeze_at=2, out_features=None, **kwargs: Any) -> VGG:
    if pretrained:
        kwargs['init_weights'] = False
    features, out_features_strides, out_features_channels = make_layers(cfgs[cfg], batch_norm=batch_norm)
    # Pass freeze_at/out_features by keyword so out_features does not land in the
    # num_classes slot of VGG.__init__.
    model = VGG(features, freeze_at=freeze_at, out_features=out_features,
                out_features_strides=out_features_strides, out_features_channels=out_features_channels, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
......@@ -102,81 +202,29 @@ def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool
return model
def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration "E")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration 'E') with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
\ No newline at end of file
@BACKBONE_REGISTRY.register()
def vgg16_bn(cfg, input_shape):
    depth = cfg.MODEL.VGG.DEPTH
    assert depth == 16, "only support vgg16_bn and vgg16"
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = cfg.MODEL.VGG.OUT_FEATURES
    return _vgg_wsl("vgg16_bn", "VGG16-OICR", True, False, True, freeze_at, out_features)


@BACKBONE_REGISTRY.register()
def build_vgg16(cfg, input_shape):
    depth = cfg.MODEL.VGG.DEPTH
    assert depth == 16, "only support vgg16_bn and vgg16"
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = cfg.MODEL.VGG.OUT_FEATURES
    return _vgg_wsl("vgg16", "VGG16-OICR", False, False, True, freeze_at, out_features)
# def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
# r"""VGG 16-layer model (configuration "D") with batch normalization
# `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
# Args:
# pretrained (bool): If True, returns a model pre-trained on ImageNet
# progress (bool): If True, displays a progress bar of the download to stderr
# """
# return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
\ No newline at end of file