# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import torch
import torch.distributed as dist
from fvcore.nn.distributed import differentiable_all_reduce
from torch import nn
from torch.nn import functional as F

from detectron2.utils import comm, env

from .wrappers import BatchNorm2d


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight", "bias", "running_mean", and "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
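
    A minimal sketch of this equivalence (illustrative values only; ``torch``
    is assumed to be imported)::

        x = torch.randn(2, 8, 4, 4)
        frozen = FrozenBatchNorm2d(8)
        frozen.running_mean.fill_(0.5)
        frozen.running_var.fill_(2.0)
        # With the default weight=1 and bias=0, the frozen forward reduces to
        # (x - running_mean) / sqrt(running_var + eps).
        expected = (x - 0.5) / (2.0 + frozen.eps) ** 0.5
        assert torch.allclose(frozen(x), expected, atol=1e-5)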
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silence the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        # NOTE: if a checkpoint is trained with BatchNorm and loaded (together with
        # version number) to FrozenBatchNorm, running_var will be wrong. One solution
        # is to remove the version number from the checkpoint.
        
        # Temporarily commented out: training does not work unless this is
        # commented out... which is bizarre; no idea why this extra step is needed here.
        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var are used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, converts module in place and returns it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
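
        Example (a hedged sketch; ``torchvision`` and the choice of backbone are
        assumptions used purely for illustration)::

            import torchvision

            backbone = torchvision.models.resnet18(pretrained=True)
            backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(backbone)
            # Every BatchNorm2d / SyncBatchNorm in the backbone is now a
            # FrozenBatchNorm2d carrying the pretrained statistics, and stays
            # frozen in both train() and eval() mode.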
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res


def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
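
    Example (illustrative only; the channel counts are made up)::

        gn = get_norm("GN", 64)           # GroupNorm with 32 groups over 64 channels
        frozen = get_norm("FrozenBN", 256)
        assert get_norm("", 64) is None   # an empty string disables normalization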
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
        }[norm]
    return norm(out_channels)


class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to the mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight.  The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
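
        A small single-process illustration of the difference between the two modes
        (the two "workers" below are simulated locally; this only sketches the
        reduction arithmetic, not the distributed code path)::

            n = torch.tensor([2.0, 6.0])                             # per-worker batch sizes
            worker_means = torch.tensor([[1.0, 0.0], [3.0, 4.0]])    # (worker, channel) means
            # stats_mode == "":  every worker contributes equally.
            equal = worker_means.mean(dim=0)                         # tensor([2., 2.])
            # stats_mode == "N": workers are weighted by their batch size N.
            by_n = (worker_means * n[:, None]).sum(dim=0) / n.sum()  # tensor([2.5000, 3.0000])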
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            vec = torch.cat([mean, meansqr], dim=0)
            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                vec = torch.cat(
                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
                )
            vec = differentiable_all_reduce(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
            mean, meansqr, _ = torch.split(vec / total_batch, C)

        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        return input * scale + bias