Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks
Soham De
DeepMind, London
Samuel L. Smith
DeepMind, London
Abstract
Batch normalization dramatically increases the largest trainable depth of residual networks, and this benefit has been crucial to the empirical success of deep residual networks on a wide range of benchmarks. We show that this key benefit arises because, at initialization, batch normalization downscales the residual branch relative to the skip connection, by a normalizing factor on the order of the square root of the network depth. This ensures that, early in training, the function computed by normalized residual blocks in deep networks is close to the identity function (on average). We use this insight to develop a simple initialization scheme that can train deep residual networks without normalization. We also provide a detailed empirical study of residual networks, which clarifies that, although batch normalized networks can be trained with larger learning rates, this effect is only beneficial in specific compute regimes, and has minimal benefits when the batch size is small.
The combination of skip connections [1–3] and batch normalization [4] dramatically increases the largest trainable depth of neural networks. Although the origin of this effect is poorly understood, it has led to a rapid improvement in the performance of deep networks on popular benchmarks [5, 6]. Following the introduction of layer normalization [7] and the transformer architecture [8, 9], almost all state-of-the-art networks currently contain both skip connections and normalization layers.
Our contributions.
This paper provides a simple explanation for why batch normalized deep residual networks are easily trainable. We prove that batch normalization downscales the hidden activations on the residual branch by a factor on the order of the square root of the network depth (at initialization). Therefore, as the depth of a residual network is increased, the residual blocks are increasingly dominated by the skip connection, which drives the functions computed by residual blocks closer to the identity, preserving signal propagation and ensuring well-behaved gradients [10–15].
If our theory is correct, it should be possible to train deep residual networks without normalization, simply by downscaling the residual branch. Therefore, to verify our analysis, we introduce a one-line code change (“SkipInit”) which imposes this property at initialization, and we confirm that this alternative scheme can train one thousand layer deep residual networks without normalization.
In addition, we provide a detailed empirical study of residual networks at a wide range of batch sizes. This study demonstrates that, although batch normalization does enable us to train residual networks with larger learning rates, we only benefit from using large learning rates in practice if the batch size is also large. When the batch size is small, both normalized and unnormalized networks have similar optimal learning rates (which are typically much smaller than the largest stable learning rates), and yet normalized networks still achieve significantly higher test accuracies and lower training losses. These experiments demonstrate that, in residual networks, increasing the largest stable learning rate is not the primary benefit of batch normalization, contrary to the claims made in prior work [16, 17].
Paper layout.
In section 2, we prove that residual blocks containing identity skip connections and normalization layers are biased towards the identity function in deep networks (at initialization). To confirm that this property explains why deep normalized residual networks are trainable, we propose a simple alternative to normalization (“SkipInit”) that shares the same property at initialization, and we provide an empirical study of normalized residual networks and SkipInit at a range of network depths. In section 3, we study the performance of residual networks at a range of batch sizes, in order to clarify when normalized networks benefit from large learning rates. We study the regularization benefits of batch normalization in section 4, and we compare the performance of batch normalization, SkipInit and Fixup [18] on ImageNet in section 5. We discuss related work in section 6.
Figure 1: A) A residual block with batch normalization. It is common practice to include two convolutions on the residual branch; we show one convolution for simplicity. B) SkipInit replaces batch normalization by a single learnable scalar $\alpha$. We set $\alpha = 0$ (or a small constant) at initialization.
Residual networks (ResNets) [2, 3] contain a sequence of residual blocks, which are composed of a “residual branch” comprising a number of convolutions, normalization layers and non-linearities, as well as a “skip connection”, which is usually just the identity (see figure 1). While introducing skip connections shortens the effective depth of the network, on their own they only increase the trainable depth by roughly a factor of two [15]. Normalized residual networks, on the other hand, can be trained for depths significantly deeper than twice the depth of their non-residual counterparts [3, 18].
To understand this effect, we analyze the variance of hidden activations at initialization. For clarity, we focus here on the variance of a single training example, but we discuss the variance across batches of training examples (which share the same random weights) in appendix C. Let $x^\ell_i$ denote the $i$-th component of the input to the $\ell$-th residual block, where $x^1$ denotes the input to the model, with $\mathbb{E}(x^1_i) = 0$ and $\mathrm{Var}(x^1_i) = 1$ for each independent component $i$. Let $f^\ell$ denote the function computed by the residual branch of the $\ell$-th residual block, $x^+_i = \max(x_i, 0)$ denote the output of the ReLU, and $\mathcal{B}$ denote the batch normalization operation (for completeness, we define batch normalization formally in appendix A). For simplicity, we assume that there is a single linear layer on each residual branch, such that for normalized networks $f^\ell(x^\ell) = W^\ell \mathcal{B}(x^\ell)^+$, and for unnormalized networks $f^\ell(x^\ell) = W^\ell (x^\ell)^+$.
We also assume that each component of $W^\ell$ is independently sampled from $\mathcal{N}(0, 2/\text{fan-in})$ (He initialization) [19]. Thus, given $x^\ell$, the mean of the $i$-th coordinate of the output of a residual branch is $\mathbb{E}(f^\ell_i(x^\ell) \mid x^\ell) = 0$. Since $x^{\ell+1} = x^\ell + f^\ell(x^\ell)$, this implies $\mathbb{E}(x^\ell_i) = 0$ for all $\ell$. The covariance between the residual branch and the skip connection is $\mathrm{Cov}(f^\ell_i(x^\ell), x^\ell_i) = 0$, and thus the variance of the hidden activations is $\mathrm{Var}(x^{\ell+1}_i) = \mathrm{Var}(x^\ell_i) + \mathrm{Var}(f^\ell_i(x^\ell))$. We conclude:
Unnormalized networks:
If the residual branch is unnormalized, the variance of the residual branch is $\mathrm{Var}(f^\ell_i(x^\ell)) = \sum_{j=1}^{\text{fan-in}} \mathrm{Var}(W^\ell_{ij}) \cdot \mathbb{E}\big(((x^\ell_j)^+)^2\big) = 2 \cdot \mathbb{E}\big(((x^\ell_i)^+)^2\big) = \mathrm{Var}(x^\ell_i)$. This has two implications. First, the variance of the hidden activations explodes exponentially with depth, $\mathrm{Var}(x^{\ell+1}_i) = 2 \cdot \mathrm{Var}(x^\ell_i) = 2^\ell$. One can prevent this explosion by introducing a factor of $(1/\sqrt{2})$ at the end of each residual block, such that $x^{\ell+1} = (x^\ell + f^\ell(x^\ell))/\sqrt{2}$. Second, since $\mathrm{Var}(f^\ell_i(x^\ell)) = \mathrm{Var}(x^\ell_i)$, the residual branch and the skip connection contribute equally to the output of the residual block. This ensures that the function computed by the residual block is far from the identity function.
Normalized networks:
If the residual branch is normalized, the variance of the output of the residual branch is $\mathrm{Var}(f^\ell_i(x^\ell)) = \sum_{j=1}^{\text{fan-in}} \mathrm{Var}(W^\ell_{ij}) \cdot \mathbb{E}\big((\mathcal{B}(x^\ell)^+_j)^2\big) = 2 \cdot \mathbb{E}\big((\mathcal{B}(x^\ell)^+_i)^2\big) = \mathrm{Var}(\mathcal{B}(x^\ell)_i) \approx 1$, where fan-in denotes the number of incoming network connections to the layer (the approximation is tight when the batch size used to compute the batch statistics is large). Thus, the variance of the input to the $\ell$-th residual block is $\mathrm{Var}(x^\ell_i) \approx \mathrm{Var}(x^{\ell-1}_i) + 1$, which implies $\mathrm{Var}(x^\ell_i) \approx \ell$.
Figure 2: We empirically evaluate the dependence of the variance of the hidden activations on the depth of the residual block at initialization (see appendix B for details). In (a), we consider a fully connected ResNet with linear activations without any normalization, evaluated on random Gaussian inputs. In (b), we consider the same ResNet but with one normalization layer on each residual branch. The squared BatchNorm moving mean is close to zero (not shown). In (c), we consider a batch normalized convolutional residual network with ReLU activations, evaluated on CIFAR-10.
Surprisingly, the growth in the variance of the hidden activations is beneficial, because if $\mathrm{Var}(x^\ell_i) \approx \ell$, then the batch normalization operation $\mathcal{B}$ must suppress the variance of the $\ell$-th residual branch by a factor of $\ell$ (the hidden activations are suppressed by $\sqrt{\ell}$). Consequently, the residual branch contributes only a $1/(\ell + 1)$ fraction of the variance in the output of the $\ell$-th residual block. This ensures that, at initialization, the outputs of most residual blocks in a deep normalized ResNet are dominated by the skip connection, which biases the function computed by the residual block towards the identity.
The depth of a typical residual block is proportional to the total number of residual blocks $d$, which implies that batch normalization downscales residual branches by a factor on the order of $\sqrt{d}$. Although this is weaker than the factor of $d$ proposed in [12], we find empirically in section 2.3 that it is sufficiently strong to train deep residual networks with 1000 layers. We emphasize that while our analysis only explicitly considers the propagation of the signal on the forward pass, residual blocks dominated by the skip path on the forward pass will also preserve signal propagation on the backward pass. This is because, when the forward signal on the $\ell$-th residual branch is downscaled by a factor $\alpha$, the backward propagated signal through that branch will also be downscaled by a factor $\alpha$ [20].
To verify our analysis, we evaluate the variance of the hidden activations, as well as the batch normalization statistics, of three residual networks at initialization in figure 2. We define the networks in appendix B. In figure 2(a), we consider a fully connected linear unnormalized residual network, where we find that the variance on the skip path of the $\ell$-th residual block matches the variance of the residual branch and is equal to $2^{\ell-1}$, as predicted by our analysis.
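The variance recursions above, $\mathrm{Var}(x^{\ell+1}_i) = 2 \cdot \mathrm{Var}(x^\ell_i)$ without normalization and $\mathrm{Var}(x^\ell_i) \approx \ell$ with it, can be checked with a small Monte Carlo experiment. The sketch below is a hypothetical pure-Python illustration, not the code behind figure 2: it assumes the simplified model of the analysis (one dense layer per branch, He-initialized weights, ReLU, and batch normalization without learnable scale or offset), and the helper names (`simulate`, `batch_norm`, `mean_sq`) are introduced only for this example. Sizes are kept small so that it runs quickly, so the estimates are noisy.

```python
import math
import random


def relu(v):
    return [max(a, 0.0) for a in v]


def dense(w, v):
    # Dense layer without bias: returns W v for W stored as a list of rows.
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def batch_norm(batch, eps=1e-5):
    # Normalize each feature to zero mean and unit variance over the batch
    # (no learnable scale or offset, as at initialization).
    n, width = len(batch), len(batch[0])
    out = [[0.0] * width for _ in range(n)]
    for j in range(width):
        col = [x[j] for x in batch]
        mean = sum(col) / n
        var = sum((c - mean) ** 2 for c in col) / n
        for i in range(n):
            out[i][j] = (col[i] - mean) / math.sqrt(var + eps)
    return out


def mean_sq(batch):
    # Second moment averaged over examples and components; the activations have
    # zero mean in expectation, so this estimates the variance.
    return sum(a * a for x in batch for a in x) / (len(batch) * len(batch[0]))


def simulate(depth, width, batch_size, normalized, seed=0):
    """Return (skip_var, branch_var): estimates of Var(x^l) and Var(f^l(x^l))
    for residual blocks l = 1..depth at initialization."""
    rng = random.Random(seed)
    # Model input x^1 with zero mean and unit variance.
    x = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch_size)]
    std = math.sqrt(2.0 / width)  # He initialization, N(0, 2 / fan-in)
    skip_var, branch_var = [], []
    for _ in range(depth):
        w = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        h = batch_norm(x) if normalized else x
        f = [dense(w, relu(v)) for v in h]
        skip_var.append(mean_sq(x))
        branch_var.append(mean_sq(f))
        x = [[a + b for a, b in zip(u, v)] for u, v in zip(x, f)]  # skip connection
    return skip_var, branch_var


skip_u, branch_u = simulate(depth=10, width=64, batch_size=64, normalized=False)
skip_n, branch_n = simulate(depth=10, width=64, batch_size=64, normalized=True)
print(f"unnormalized Var(x^10) ~ {skip_u[-1]:.0f}  (theory: 2**9 = 512)")
print(f"normalized   Var(x^10) ~ {skip_n[-1]:.1f}  (theory: about 10)")
share = branch_n[-1] / (skip_n[-1] + branch_n[-1])
print(f"branch share of block 10 output: {share:.3f}  (theory: 1/11)")
```

In the normalized network the branch variance stays near 1 at every depth, so the skip path carries a growing share of the signal, which is the mechanism described above.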
In figure 2(b), we consider a fully connected linear normalized residual network, where we find that the variance on the skip path of the $\ell$-th residual block is approximately equal to $\ell$, while the variance at the end of each residual branch is approximately 1. The batch normalization moving variance on the $\ell$-th residual block is also approximately equal to $\ell$, confirming that batch normalization downscales the residual branch by a factor of $\sqrt{\ell}$ as predicted. In figure 2(c), we consider a normalized convolutional residual network with ReLU activations evaluated on CIFAR-10. The variance on the skip path remains proportional to the depth of the residual block, with a coefficient slightly below 1 (likely due to zero padding at the image boundary). The batch normalization moving variance is also proportional to depth, but slightly smaller than the variance across channels on the skip path. We show in appendix C that this occurs because ReLU activations introduce correlations between different examples in the mini-batch. These correlations also cause the square of the batch normalization moving mean to grow with depth.
We claim above that batch normalization enables us to train deep residual networks because (in expectation) it downscales the residual branch at initialization by a normalizing factor on the order of the square root of the network depth. To provide further evidence for this claim, we now propose a simple initialization scheme that can train deep residual networks without normalization, “SkipInit”:
SkipInit: Include a learnable scalar multiplier at the end of each residual branch, initialized to $\alpha$.
After normalization is removed, it should be possible to implement SkipInit as a one-line code change.
Table 1: Batch normalization enables us to train deep residual networks. We can recover this benefit without normalization if we introduce a scalar multiplier $\alpha$ on the end of the residual branch and initialize $\alpha = (1/\sqrt{d})$ or smaller (where $d$ is the number of residual blocks). In practice, we advocate initializing $\alpha = 0$. We provide optimal test accuracies and optimal learning rates with error bars. Note that we do not provide results in cases where the test accuracy was frozen at random initialization throughout training for all learning rates in the range considered (i.e., in cases where training failed). Each panel reports, for depths 16, 100 and 1000, the optimal test accuracy (mean ± standard deviation) and the optimal learning rate (with its error range). Panels: Batch Normalization; SkipInit ($\alpha = 1/\sqrt{d}$); SkipInit ($\alpha = 0$); SkipInit ($\alpha = 1$), with no result at depths 100 and 1000; Divide residual block by $\sqrt{2}$, with no result at depth 1000; SkipInit without L2 ($\alpha = 0$).
In section 2.3, we show that we can train deep residual networks, so long as $\alpha$ is initialized at a value of $(1/\sqrt{d})$ or smaller, where $d$ denotes the total number of residual blocks (see table 1). Notice that this observation agrees exactly with our analysis of deep normalized residual networks in section 2.1. In practice, we recommend setting $\alpha = 0$, so that the residual block represents the identity function at initialization.
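To make the one-line change concrete, here is a minimal sketch of an unnormalized residual block with SkipInit, written in plain Python. It is an illustrative assumption rather than the authors' implementation (their TensorFlow code is in appendix G): the residual branch is reduced to a single ReLU and dense layer, as in the analysis of section 2.1, and the class and method names are hypothetical. With $\alpha = 0$ the block is exactly the identity, while the gradient with respect to $\alpha$ is generally nonzero, so $\alpha$ can still be learned.

```python
import math
import random


class SkipInitBlock:
    """Hypothetical unnormalized residual block with SkipInit.

    The branch is a single ReLU followed by a dense layer, f(x) = W x^+, as in the
    simplified analysis; a real ResNet branch would contain convolutions.
    """

    def __init__(self, width, alpha=0.0, seed=0):
        rng = random.Random(seed)
        std = math.sqrt(2.0 / width)  # He initialization, N(0, 2 / fan-in)
        self.w = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        self.alpha = alpha  # learnable scalar, initialized to 0 by default

    def branch(self, x):
        h = [max(a, 0.0) for a in x]
        return [sum(wij * hj for wij, hj in zip(row, h)) for row in self.w]

    def __call__(self, x):
        # The SkipInit change: scale the branch output by alpha before the skip-add.
        return [xi + self.alpha * fi for xi, fi in zip(x, self.branch(x))]

    def grad_alpha(self, x, grad_out):
        # dL/d(alpha) = sum_i (dL/dy_i) * f_i(x); nonzero in general even when
        # alpha = 0, so gradient descent can move alpha away from zero.
        return sum(g * fi for g, fi in zip(grad_out, self.branch(x)))


block = SkipInitBlock(width=8)
x = [0.5 * i - 2.0 for i in range(8)]
print(block(x) == x)  # True: with alpha = 0 the block is the identity
print(block.grad_alpha(x, [1.0] * 8))
```

Because the branch output is multiplied by $\alpha$, any gradient flowing back into the branch weights is also scaled by $\alpha$, mirroring the forward-pass downscaling.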
Setting $\alpha = 0$ is also simpler to apply, since it ensures the initialization scheme is independent of network depth. We note that SkipInit is designed for residual networks that contain an identity skip connection, such as the ResNet-V2 [3] or Wide-ResNet architectures [21]. We discuss how to extend SkipInit to the original ResNet-V1 [2] formulation of residual networks in appendix F.
We empirically verify the claims made above by studying the minimal components required to train deep residual networks. In table 1, we report the mean test accuracy of a Wide-ResNet [21] of depth $n$, trained on CIFAR-10 for 200 epochs at batch size 64, for a range of depths $n$ between 16 and 1000 layers. At each depth, we train the network 7 times for each learning rate on a logarithmic grid, and we measure the mean and standard deviation of the test accuracy for the best 5 runs (this procedure ensures that our results are not corrupted by outliers or failed runs). The optimal test accuracy is the mean performance at the learning rate whose mean test accuracy was highest, and we always verify that the optimal learning rates are not at the boundary of our grid search. Here and throughout this paper, we use SGD with heavy ball momentum, and fix the momentum coefficient $m = 0.…$. Although we tune the learning rate on the test set, we emphasize that our goal is not to achieve state-of-the-art results. Our goal is to compare the performance of different training procedures, and we apply the same experimental protocol in each case. We hold the learning rate constant for 100 epochs, before dropping the learning rate by a factor of 2 every 10 epochs. This simple schedule achieves higher test accuracies than the original three-drop schedule proposed in [2]. We apply data augmentation including per-image standardization, padding, random crops and left-right flips. We use L2 regularization with a coefficient of …, and we initialize convolutional layers using He initialization [19].
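The CIFAR-10 learning-rate schedule described above can be written as a function of the (possibly fractional) epoch. This is a sketch under one assumption not fixed by the text: the first halving is taken to occur at epoch 100, so the final 10-epoch segment uses `base_lr / 2**10`. The function name is hypothetical.

```python
def cifar_learning_rate(epoch, base_lr):
    """Hypothetical helper for the 200-epoch CIFAR-10 schedule.

    Constant for the first 100 epochs, then halved every 10 epochs.
    Assumption: the first halving happens at epoch 100.
    """
    if epoch < 100:
        return base_lr
    num_halvings = int((epoch - 100) // 10) + 1
    return base_lr / 2 ** num_halvings


for epoch in (0, 99, 100, 110, 199):
    print(epoch, cifar_learning_rate(epoch, base_lr=0.1))
```

If the first drop is instead placed at epoch 110, replace `+ 1` with `+ 0` after shifting the threshold; the structure of the schedule is otherwise unchanged.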
We provide the corresponding optimal training losses in appendix D.
As expected, batch normalized Wide-ResNets are trainable for a wide range of depths, and the optimal learning rate is only weakly dependent on the depth. We can recover this effect without normalization by incorporating SkipInit and initializing $\alpha = (1/\sqrt{d})$ or smaller, where $d$ denotes the number of residual blocks. This provides strong evidence to support our claim that batch normalization enables us to train deep residual networks by biasing residual blocks towards the skip path at initialization.
Figure 3: (a) test accuracy, (b) training loss, (c) optimal learning rate (for test accuracy), (d) optimal learning rate (for training loss). In (a), we achieve higher test accuracies with batch normalization than without batch normalization, and we are also able to train efficiently at much larger batch sizes. SkipInit substantially reduces the gap in performance for small and moderate batch sizes, but it still under-performs batch normalization when the batch size is large. In (b), SkipInit achieves smaller training losses than batch normalization for batch sizes $b \lesssim …$. We provide the test accuracy at the learning rate for which the test accuracy was maximized, and the training loss at the learning rate for which the training loss was minimized. To help interpret these results, we also provide the optimal learning rates in figures (c) and (d). When the batch size is small, all three methods have similar optimal learning rates (which are much smaller than the maximum stable learning rate for each method), but batch normalization and SkipInit are able to scale to larger learning rates when the batch size is large.
Just like normalized networks, the optimal learning rate with SkipInit is almost independent of the network depth.
SkipInit slightly under-performs batch normalization on the test set at all depths, although we show in appendix D that it achieves similar training losses to normalized networks.
For completeness, we verify in table 1 that one cannot train deep residual networks with SkipInit if $\alpha = 1$. We also show that for unnormalized residual networks, it is not sufficient merely to ensure the activations do not explode on the forward pass (which can be achieved by multiplying the output of each residual block by $(1/\sqrt{2})$). This confirms that ensuring stable forward propagation of the signal is not sufficient for trainability. Additionally, we noticed that, at initialization, the loss in deep networks is dominated by the L2 regularization term, causing the weights to shrink rapidly early in training. To clarify whether this effect is necessary, we evaluated SkipInit ($\alpha = 0$) without L2 regularization, and found that L2 regularization is not necessary for trainability. This demonstrates that we can train deep residual networks without normalization and without reducing the scale of the weights at initialization, solely by downscaling the hidden activations on the residual branch. To further test the theory that downscaling the residual branch is the key benefit of batch normalization in deep ResNets, we tried several other variations of batch-normalized ResNets, which we present in appendix D. We find that variants of batch-normalized ResNets which do not downscale the residual branch relative to the skip path (e.g., networks that place normalization layers on the skip path) are not trainable for large depths. We provide additional results on CIFAR-100 in appendix E.
In two widely read papers, Santurkar et al. [16] and Bjorck et al. [17] argued that the primary benefit of batch normalization is that it improves the conditioning of the loss landscape, which allows us to train stably with larger learning rates.
However, this claim seems incompatible with a number of recent papers studying optimization in deep learning [22–29]. These papers argue that if we train for a fixed number of epochs (as is common in practice), then when the batch size is small, the optimal learning rate is significantly smaller than the largest stable learning rate, since it is constrained by the noise in the gradient estimate. In this small batch regime, the optimal learning rate is usually proportional to the batch size [29–31]. Meanwhile, the conditioning of the loss sets the maximum stable learning rate [26–29], and this controls how large we can make the batch size before the performance of the model begins to degrade under a fixed epoch budget. If this perspective is correct, we would expect large stable learning rates to be beneficial only when the batch size is also large. In this section, we clarify the role of large learning rates in normalized networks by studying residual networks with and without batch normalization at a wide range of batch sizes.
In figure 3, we provide results for a 16-4 Wide-ResNet, trained on CIFAR-10 for 200 epochs at a wide range of batch sizes and learning rates. We follow the same experimental protocol described in section 2.3; however, we average over the best 12 out of 15 runs. To enable us to consider extremely large batch sizes on a single GPU, we evaluate the batch statistics over a “ghost batch size” of 64, before accumulating gradients to form larger batches, as is standard practice [32]. We are therefore unable to consider batch sizes below 64 with batch normalization. Note that we repeat this experiment in the small batch limit in section 4, where we evaluate the batch statistics over the full training batch.
Unsurprisingly, the performance with batch normalization is better than the performance without batch normalization on both the test set and the training set at all batch sizes.
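The “ghost batch” procedure computes batch normalization statistics separately on each chunk of 64 examples, even when many such chunks are accumulated into one larger training batch. A minimal pure-Python sketch of the forward normalization is shown below. It is an illustration under simplifying assumptions (no learnable scale or offset, no moving averages, no gradient accumulation), and `ghost_batch_norm` is a hypothetical name.

```python
import math


def ghost_batch_norm(batch, ghost_size=64, eps=1e-5):
    """Normalize each feature with statistics computed per ghost batch.

    `batch` is a list of examples, each a list of floats. Its length must be a
    multiple of `ghost_size`; statistics are never shared between ghost batches.
    """
    if len(batch) % ghost_size != 0:
        raise ValueError("batch size must be a multiple of the ghost batch size")
    out = []
    for start in range(0, len(batch), ghost_size):
        ghost = batch[start:start + ghost_size]
        width = len(ghost[0])
        means = [sum(x[j] for x in ghost) / ghost_size for j in range(width)]
        variances = [
            sum((x[j] - means[j]) ** 2 for x in ghost) / ghost_size for j in range(width)
        ]
        for x in ghost:
            out.append(
                [(x[j] - means[j]) / math.sqrt(variances[j] + eps) for j in range(width)]
            )
    return out


# Two ghost batches with very different feature statistics.
batch = [[float(i), 2.0 * i] for i in range(64)] + [[100.0 + i, -3.0 * i] for i in range(64)]
normed = ghost_batch_norm(batch)
print(sum(x[0] for x in normed[64:]) / 64)  # close to 0: each ghost batch is centred on its own
```

Because each chunk uses its own statistics, the noise in the batch statistics stays that of a batch of 64, regardless of how large the accumulated training batch becomes.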
However, both with and without batch normalization, the optimal test accuracy is independent of batch size in the small batch limit, before beginning to decrease when the batch size exceeds some critical threshold. Crucially, this threshold is significantly larger when batch normalization is used, which demonstrates that one can efficiently scale training to larger batch sizes in normalized networks. SkipInit reduces the gap in test accuracy between normalized and unnormalized networks, and it achieves smaller training losses than batch normalization when the batch size is small ($b \lesssim …$). However, similar to unnormalized networks, it still performs worse than normalized networks when the batch size is very large.
To explain why normalized networks can scale training to larger batch sizes, we provide the optimal learning rates that maximize the test accuracy and minimize the training loss in figures 3(c) and 3(d). When the batch size is small, the optimal learning rates for all three methods are proportional to the batch size and are similar to each other. Crucially, the optimal learning rates are much smaller than the largest stable learning rate for each method. On the other hand, when the batch size is large, the optimal learning rates are independent of batch size [26, 27], and normalized networks use larger learning rates. Intuitively, this transition occurs when we reach the maximum stable learning rate, above which training diverges [28]. Our results confirm that batch normalized networks have a larger maximum stable learning rate than SkipInit networks, which have a larger maximum stable learning rate than unnormalized networks. This explains why batch normalized networks were able to efficiently scale training to larger batch sizes.
Crucially, however, our experiments confirm that batch normalized networks do not benefit from the use of large learning rates when the batch size is small. Furthermore, under a fixed epoch budget, the highest test accuracies for all three methods are always achieved in the small batch limit with small learning rates, and the test accuracy never increases when the batch size rises. We therefore conclude that large learning rates are not the primary benefit of batch normalization in residual networks, contradicting the claims of earlier work [16, 17]. The primary benefit of batch normalization is that it biases the residual blocks in deep residual networks towards the identity function, thus enabling us to train significantly deeper networks. To emphasize this claim, we show in the next section that the gap in test accuracy between batch normalization and SkipInit in the small batch limit can be further reduced with additional regularization. We provide additional results sweeping the batch size on a 28-10 Wide-ResNet on CIFAR-100 in appendix E.
Note that we plot the training loss excluding the L2 regularization term in figure 3. Normalized networks often achieve smaller L2 losses because the network function is independent of the scale of the weights.
As the batch size grows, the number of parameter updates decreases, since the number of training epochs is fixed. We note that the performance might not degrade with batch size under a constant step budget [25].
It is widely known that batch normalization can have a regularizing effect [32]. Most authors believe that this benefit arises from the noise introduced when the batch statistics are estimated on a subset of the full training set [33]. In this section, we study this regularization benefit at a range of batch sizes. Unlike the previous section (which used a “ghost batch size” of 64 [32]), in this section we evaluate the batch statistics of normalized networks over the entire mini-batch. We introduced SkipInit in section 2.2, which ensures that very deep unnormalized ResNets are trainable. To attempt to recover the additional regularization benefits of batch normalization, we now introduce “Regularized SkipInit”. This scheme includes SkipInit ($\alpha = 0$), but also introduces biases to all convolutions and applies a single Dropout layer [34] before the softmax (we use drop probability 0.6 in this section).
Figure 4: … The training loss falls as the batch size increases, but the test accuracy is maximized for an intermediate batch size ($b \approx …$). Regularized SkipInit outperforms batch normalization on the test set for small batch sizes.
In figure 4, we provide the performance of our 16-4 Wide-ResNet at a range of batch sizes in the small batch limit (note that batch normalization reduces to instance normalization when the batch size $b = 1$). We provide the corresponding optimal learning rates in appendix D. The test accuracy of batch normalized networks initially improves as the batch size rises, before decaying for batch sizes $b \gtrsim …$. Meanwhile, the training loss increases as the batch size rises from 1 to 2, but then decreases consistently as the batch size rises further. This confirms that the uncertainty in the estimate of the batch statistics does have a generalization benefit if properly tuned (this is also why we chose a ghost batch size of 64 in section 3). The performance of SkipInit and Regularized SkipInit is independent of batch size in the small batch limit, and Regularized SkipInit achieves higher test accuracies than batch normalization when the batch size is very small.
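Regularized SkipInit adds biases and a single Dropout layer before the softmax. The sketch below shows only the classifier head, assuming the standard “inverted” dropout formulation (surviving units rescaled by $1/(1-p)$) and a dense output layer with biases. The function names and the toy weights in the usage are hypothetical, and the convolutional body of the network is omitted.

```python
import math
import random


def dropout(features, drop_prob, rng, training=True):
    # Inverted dropout: zero each unit with probability drop_prob and scale the
    # survivors by 1 / (1 - drop_prob), so no rescaling is needed at test time.
    if not training or drop_prob == 0.0:
        return list(features)
    keep = 1.0 - drop_prob
    return [f / keep if rng.random() < keep else 0.0 for f in features]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def regularized_head(features, weights, biases, drop_prob=0.6, rng=None, training=True):
    """Hypothetical classifier head: one dropout layer, then a dense layer with
    biases, then the softmax. `features` are the pooled penultimate activations."""
    if rng is None:
        rng = random.Random(0)
    h = dropout(features, drop_prob, rng, training)
    logits = [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(weights, biases)]
    return softmax(logits)


weights = [[0.1, 0.2, 0.3], [0.0, -0.1, 0.2]]
biases = [0.0, 0.5]
print(regularized_head([1.0, 2.0, 3.0], weights, biases, training=True))
print(regularized_head([1.0, 2.0, 3.0], weights, biases, training=False))
```

Applying dropout only once, immediately before the final layer, keeps the change to the network minimal while still injecting noise during training.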
Note that we introduced Dropout [34] to show that extra regularization may be necessary to close the performance gap between normalized and SkipInit networks, but more sophisticated regularizers would likely achieve higher test accuracies. We provide additional results studying this regularization effect on CIFAR-100 in appendix E.
In this section, we compare the performance of batch normalization and SkipInit on ImageNet. For completeness, we also compare to the recently proposed Fixup initialization [18]. Since SkipInit is designed for residual networks with an identity skip connection, we consider the ResNet50-V2 architecture [3]. We provide additional experiments on ResNet50-V1 [2] in appendix F. We use the original architectures and match the performance reported by [35] (we do not apply the popular modifications to these architectures described in [22]). We train for 90 epochs, and when batch normalization is used we set the ghost batch size to 256. The learning rate is linearly increased from 0 to the specified value over the first 5 epochs of training [22], held constant for 40 epochs, and then decayed by a factor of 2 every 5 epochs. As before, we tune the learning rate at each batch size on a logarithmic grid. We provide the optimal validation accuracies in table 2. We found that adding biases to the convolutional layers led to a small boost in accuracy for SkipInit, and we therefore included biases in all SkipInit runs. SkipInit and Fixup match the performance of batch normalization at the standard batch size of 256; however, both SkipInit and Fixup perform worse than batch normalization when the batch size is very large. With extra regularization (Dropout), both SkipInit and Fixup achieve higher validation accuracies than batch normalization for small batch sizes. We include code for our TensorFlow [36] implementation of ResNet50-V2 with SkipInit in appendix G.
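The ImageNet schedule (5 epochs of linear warm-up, 40 epochs at the peak rate, then halving every 5 epochs, for 90 epochs in total) can be sketched as follows. The function name is hypothetical, and the sketch assumes that the first halving happens at epoch 45, immediately after the constant phase.

```python
def imagenet_learning_rate(epoch, peak_lr):
    """Hypothetical helper for the 90-epoch ImageNet schedule.

    Linear warm-up from 0 to peak_lr over 5 epochs, constant for 40 epochs,
    then halved every 5 epochs. Assumption: the first halving is at epoch 45.
    """
    warmup_epochs, constant_epochs, halving_period = 5.0, 40.0, 5.0
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    decay_start = warmup_epochs + constant_epochs
    if epoch < decay_start:
        return peak_lr
    num_halvings = int((epoch - decay_start) // halving_period) + 1
    return peak_lr / 2 ** num_halvings


for epoch in (0, 2.5, 5, 44, 45, 89):
    print(epoch, imagenet_learning_rate(epoch, peak_lr=1.0))
```

Under this reading the learning rate is halved nine times, so the last 5 epochs use `peak_lr / 512`.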
In recent years, almost all state-of-the-art models have involved applying some kind of normalizationscheme [4, 7, 37–39] in combination with skip connections [1–3, 8, 9]. Although some authors havesucceeded in training very deep networks without normalization layers or skip connections [14, 40],7able 2: When training ResNet50-V2 on ImageNet, SkipInit and Fixup are competitive with batchnormalization for small batch sizes, while batch normalization performs best when the batch size islarge. SkipInit and Fixup both achieve higher validation accuracies than batch normalization withextra regularization. We train for 90 epochs and perform a grid search to identify the optimal learningrate which maximizes the top-1 validation accuracy. We perform a single run at each learning rateand report top-1 and top-5 accuracy scores. We use a drop probability of . when Dropout is used. Batch sizeTest accuracy:
256 1024 4096Batch normalization 75.0 / 92.2 74.9 / 92.1 74.9 / 91.9Fixup 74.8 / 91.8 74.6 / 91.7 73.0 / 90.6SkipInit + Biases 74.9 / 91.9 74.6 / 91.8 70.8 / 89.2Fixup + Dropout 75.8 / 92.5 75.6 / 92.5 74.8 / 91.8Regularized SkipInit 75.6 / 92.4 75.5 / 92.5 72.7 / 90.7these papers required careful orthogonal initialization schemes that are not compatible with ReLUactivation functions. Balduzzi et al. [11] and Yang et al. [13] argued that ResNets with identity skipconnections and batch normalization layers on the residual branch preserve correlations betweendifferent minibatches in deep networks, and Balduzzi et al. [11] suggested that this effect can bemimicked by initializing deep networks close to linear functions. However, even deep linear networksare difficult to train with Gaussian weights [12, 15, 40], which suggests that imposing linearity atinitialization is not sufficient. Veit et al. [10] observed empirically that normalized residual networksare typically dominated by short paths, however they did not identify the cause of this effect. Someauthors have studied initialization schemes which multiply the output of the residual branch by afixed scalar (smaller than 1), without establishing a link to normalization methods [11, 12, 41–44].Santurkar et al. [16] and Bjorck et al. [17] argued that batch normalization improves the conditioningof the loss landscape, which enables us to train with larger learning rates and converge in fewerparameter updates. Arora et al. [45] argued that batch normalization reduces the importance of tuningthe learning rate, while Li and Arora [46] showed that models trained using batch normalization canconverge even if the learning rate increases exponentially during training. A similar analysis alsoappears in [47], while Luo et al. [33] analyzed the regularization benefits of batch normalization.Zhang et al. 
[18] proposed Fixup initialization, and confirmed that it can train both deep residual networks and deep transformers without normalization layers. Fixup contains four components:
1. The classification layer and final convolution of each residual branch are initialized to zero.
2. The initial weights of the remaining convolutions are scaled down by $d^{-1/(2m-2)}$, where $d$ denotes the number of residual branches and $m$ is the number of convolutions per branch.
3. A scalar multiplier is introduced at the end of each residual branch, initialized to one.
4. Scalar biases are introduced before every layer in the network, initialized to zero.
The authors do not relate these components to the influence of the batch normalization layers on the residual branch, or seek to explain why deep normalized ResNets are trainable. They argue that the second component of Fixup is essential; however, our experiments in section 2.3 demonstrate that this component is not necessary to train deep residual networks at typical batch sizes. In practice, we have found that either component 1 or component 2 of Fixup on its own is sufficient in ResNet-V2 networks, since both components downscale the hidden activations on the residual branch (fulfilling the same role as SkipInit). We found in section 5 that SkipInit and Fixup have similar performance for small batch sizes, but that Fixup slightly outperforms SkipInit when the batch size is large.

Our work demonstrates that batch normalization has three main benefits. In order of importance:
1. Batch normalization can train deep residual networks (section 2).
2. Batch normalization increases the maximum stable learning rate (section 3).
3. Batch normalization has a regularizing effect (section 4).
This work explains benefit 1, by observing that batch normalization biases residual blocks towards the identity function at initialization. This ensures that deep residual networks have well-behaved gradients, enabling efficient training [10–15].
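The four Fixup components listed earlier can be sketched for a residual branch made of dense layers instead of convolutions. This NumPy sketch is our own illustration, not the implementation of [18]: the helper names `fixup_scale`, `init_fixup_branch` and `fixup_branch` are hypothetical, He-normal weights are assumed for the layers that are not zeroed, and the zero-initialized classification layer (part of component 1) is omitted. Because the last layer of the branch starts at zero, the residual block computes the identity at initialization, which is the property it shares with SkipInit.

```python
import numpy as np

def fixup_scale(num_branches, layers_per_branch):
    """Component 2: rescaling factor d^(-1/(2m-2)) for the non-final layers."""
    if layers_per_branch < 2:
        raise ValueError("the d^(-1/(2m-2)) rescaling needs m >= 2 layers per branch")
    return num_branches ** (-1.0 / (2 * layers_per_branch - 2))

def init_fixup_branch(width, num_branches, layers_per_branch, rng):
    scale = fixup_scale(num_branches, layers_per_branch)
    weights = []
    for i in range(layers_per_branch):
        if i == layers_per_branch - 1:
            # Component 1: the final layer of each branch starts at zero.
            weights.append(np.zeros((width, width)))
        else:
            # Component 2: He-normal weights (assumed) scaled by d^(-1/(2m-2)).
            he = rng.normal(0.0, np.sqrt(2.0 / width), size=(width, width))
            weights.append(scale * he)
    return {
        "weights": weights,
        "biases": np.zeros(layers_per_branch),  # Component 4: scalar biases at zero
        "multiplier": 1.0,                      # Component 3: scalar multiplier at one
    }

def fixup_branch(x, params):
    """Branch forward pass: scalar bias before each layer, ReLU between layers."""
    h = x
    n = len(params["weights"])
    for i, (w, b) in enumerate(zip(params["weights"], params["biases"])):
        h = (h + b) @ w
        if i < n - 1:
            h = np.maximum(h, 0.0)
    return params["multiplier"] * h

rng = np.random.default_rng(0)
params = init_fixup_branch(width=64, num_branches=16, layers_per_branch=2, rng=rng)
x = rng.standard_normal((8, 64))
block_out = x + fixup_branch(x, params)  # identity skip path + Fixup branch
```

With $d = 16$ branches of $m = 2$ layers, for example, the non-final layers are scaled by $16^{-1/2} = 0.25$, and the block output equals its input until training moves the zero-initialized layer.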
Furthermore, our argument naturally extends to other normalization variants and model architectures, including layer normalization [7] and "pre-norm" transformers [9] (where the normalization layers are on the residual branch). A single normalization layer per residual branch is sufficient, and normalization layers should not be placed on the skip path (as in the original transformer [8]). We can recover benefit 1 without normalization by introducing a learnable scalar multiplier on the residual branch initialized to zero. This simple change can train deep ResNets without normalization, and often enhances the performance of shallow ResNets.

The conditioning benefit (benefit 2) is not necessary when one trains with small batch sizes, but it remains beneficial when one wishes to train with large batch sizes. Since large batch sizes can be computed in parallel across multiple devices [22], this could make normalization necessary in time-critical situations, for instance if a production model is retrained frequently in response to changing user preferences. Also, since batch normalization has a regularizing effect (benefit 3), it may be necessary in some architectures if one wishes to achieve the highest possible test accuracy. Note, however, that one can sometimes exceed the test accuracy of normalized networks by introducing alternate regularizers (see section 5 or [18]). We therefore believe future work should focus on identifying an alternative to batch normalization that recovers its conditioning benefits.

We would like to comment briefly on the similarity between SkipInit for residual networks and orthogonal initialization of vanilla fully connected tanh networks [40]. Orthogonal initialization is currently the only initialization scheme capable of training deep networks without skip connections. It initializes the weights of each layer as an orthogonal matrix, such that the activations after a linear layer are a rotation (or reflection) of the activations before the layer.
Meanwhile, the tanh non-linearity is approximately equal to the identity for small activations over a region of scale 1 around the origin. Intuitively, if the incoming activations are mean centered with scale 1, they will pass through the non-linearity almost unchanged. Since rotations compose, the approximate action of the entire network at initialization is to rotate (or reflect) the input. Like residual blocks with SkipInit, the influence of a fully connected layer with orthogonal weights will therefore be close to the identity in function space. However, ReLUs are not compatible with orthogonal initialization, since they are not linear about the origin, which has limited the use of orthogonal initialization in practice.
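The comparison between SkipInit and orthogonal initialization can be made concrete with a small NumPy experiment. This is a sketch under our own assumptions, not the authors' code: the helper names are hypothetical, each residual branch is a single linear layer with LeCun-normal weights and no normalization, and the scalar `alpha` stands in for the SkipInit multiplier, which would be a trainable parameter in a real network. With α = 0 every block computes the identity exactly, and with α = 1/√d the activation variance grows only by about $(1 + 1/d)^d \approx e$ over $d$ blocks. The second half builds random orthogonal matrices from a QR decomposition and checks that a deep tanh network with orthogonal weights nearly preserves the norm of small inputs.

```python
import numpy as np

def skipinit_linear_resnet(x, depth, alpha, rng):
    """Unnormalized residual network whose blocks compute h + alpha * (h @ W)."""
    width = x.shape[1]
    h = x
    for _ in range(depth):
        w = rng.normal(0.0, np.sqrt(1.0 / width), size=(width, width))  # LeCun normal
        h = h + alpha * (h @ w)
    return h

def random_orthogonal(n, rng):
    """Orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    return q * np.sign(np.diag(r))  # flip column signs so q is Haar distributed

def orthogonal_tanh_net(x, depth, rng):
    """Fully connected tanh network with a fresh orthogonal weight matrix per layer."""
    h = x
    for _ in range(depth):
        h = np.tanh(h @ random_orthogonal(x.shape[1], rng))
    return h

rng = np.random.default_rng(0)
depth, width = 100, 256
x = rng.standard_normal((512, width))

identity_out = skipinit_linear_resnet(x, depth, alpha=0.0, rng=rng)
scaled_out = skipinit_linear_resnet(x, depth, alpha=1.0 / np.sqrt(depth), rng=rng)
variance_growth = scaled_out.var() / x.var()  # close to (1 + 1/depth) ** depth

small = 0.01 * rng.standard_normal((4, width))
norm_ratio = np.linalg.norm(orthogonal_tanh_net(small, 50, rng)) / np.linalg.norm(small)
```

Since $|\tanh(z)| \le |z|$ and orthogonal matrices preserve norms, `norm_ratio` can never exceed 1; for small inputs it stays close to 1. Replacing `np.tanh` by a ReLU would zero roughly half of the entries at every layer, so the norm would instead shrink geometrically with depth.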
To conclude.
Batch normalization biases the residual blocks of deep residual networks towards the identity function (at initialization). This ensures that the network has well-behaved gradients, and it is therefore a major factor behind the excellent empirical performance of normalized residual networks in practice. We show that one can recover this benefit in unnormalized residual networks with a one-line code change to the architecture ("SkipInit"). In addition, we clarify that, although batch normalized networks can be trained with larger learning rates than unnormalized networks, this is only useful for large batch sizes and does not have practical benefits when the batch size is small.
Broader impact
This work seeks to develop fundamental understanding by identifying the benefits that batch normalization brings when training residual networks. We do not foresee any specific negative consequences of this work, although we hope that fundamental understanding may help drive future progress in the field.
Funding disclosure
All authors are employees of DeepMind, which was the sole source of funding for this work. No authors have any competing interests.
Acknowledgements
We thank Jeff Donahue, Chris Maddison, Erich Elsen, James Martens, Razvan Pascanu, Chongli Qin, Karen Simonyan, Yann Dauphin, Esme Sutherland and Yee Whye Teh for various discussions that have helped improve the paper.

References
[1] Rupesh Kumar Srivastava, Klaus Greff, and Jürgen Schmidhuber. Highway networks. arXiv preprint arXiv:1505.00387, 2015.
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
[3] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residual networks. In European conference on computer vision, pages 630–645. Springer, 2016.
[4] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167, 2015.
[5] Mingxing Tan and Quoc V Le. Efficientnet: Rethinking model scaling for convolutional neural networks. arXiv preprint arXiv:1905.11946, 2019.
[6] Qizhe Xie, Eduard Hovy, Minh-Thang Luong, and Quoc V Le. Self-training with noisy student improves imagenet classification. arXiv preprint arXiv:1911.04252, 2019.
[7] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
[8] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in neural information processing systems, pages 5998–6008, 2017.
[9] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. OpenAI Blog, 1(8):9, 2019.
[10] Andreas Veit, Michael J Wilber, and Serge Belongie. Residual networks behave like ensembles of relatively shallow networks. In Advances in neural information processing systems, pages 550–558, 2016.
[11] David Balduzzi, Marcus Frean, Lennox Leary, JP Lewis, Kurt Wan-Duo Ma, and Brian McWilliams. The shattered gradients problem: If resnets are the answer, then what is the question? In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 342–350. JMLR.org, 2017.
[12] Boris Hanin and David Rolnick. How to start training: The effect of initialization and architecture. In Advances in Neural Information Processing Systems, pages 571–581, 2018.
[13] Greg Yang, Jeffrey Pennington, Vinay Rao, Jascha Sohl-Dickstein, and Samuel S Schoenholz. A mean field theory of batch normalization. arXiv preprint arXiv:1902.08129, 2019.
[14] Lechao Xiao, Yasaman Bahri, Jascha Sohl-Dickstein, Samuel S Schoenholz, and Jeffrey Pennington. Dynamical isometry and a mean field theory of cnns: How to train 10,000-layer vanilla convolutional neural networks. arXiv preprint arXiv:1806.05393, 2018.
[15] Karthik A Sankararaman, Soham De, Zheng Xu, W Ronny Huang, and Tom Goldstein. The impact of neural network overparameterization on gradient confusion and stochastic gradient descent. arXiv preprint arXiv:1904.06963, 2019.
[16] Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and Aleksander Madry. How does batch normalization help optimization? In Advances in Neural Information Processing Systems, pages 2483–2493, 2018.
[17] Nils Bjorck, Carla P Gomes, Bart Selman, and Kilian Q Weinberger. Understanding batch normalization. In Advances in Neural Information Processing Systems, pages 7694–7705, 2018.
[18] Hongyi Zhang, Yann N Dauphin, and Tengyu Ma. Fixup initialization: Residual learning without normalization. arXiv preprint arXiv:1901.09321, 2019.
[19] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision, pages 1026–1034, 2015.
[20] Masato Taki. Deep residual networks and weight initialization. arXiv preprint arXiv:1709.02956, 2017.
[21] Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
[22] Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: Training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.
[23] Samuel L Smith, Pieter-Jan Kindermans, Chris Ying, and Quoc V Le. Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489, 2017.
[24] Stanisław Jastrzębski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos Storkey. Three factors influencing minima in sgd. arXiv preprint arXiv:1711.04623, 2017.
[25] Christopher J Shallue, Jaehoon Lee, Joe Antognini, Jascha Sohl-Dickstein, Roy Frostig, and George E Dahl. Measuring the effects of data parallelism on neural network training. arXiv preprint arXiv:1811.03600, 2018.
[26] Sam McCandlish, Jared Kaplan, Dario Amodei, and OpenAI Dota Team. An empirical model of large-batch training. arXiv preprint arXiv:1812.06162, 2018.
[27] Guodong Zhang, Lala Li, Zachary Nado, James Martens, Sushant Sachdeva, George E Dahl, Christopher J Shallue, and Roger Grosse. Which algorithmic choices matter at which batch sizes? insights from a noisy quadratic model. arXiv preprint arXiv:1907.04164, 2019.
[28] Siyuan Ma, Raef Bassily, and Mikhail Belkin. The power of interpolation: Understanding the effectiveness of sgd in modern over-parametrized learning. arXiv preprint arXiv:1712.06559, 2017.
[29] Samuel L Smith, Erich Elsen, and Soham De. On the generalization benefit of noise in stochastic gradient descent. arXiv preprint arXiv:2006.15081, 2020.
[30] Stephan Mandt, Matthew D Hoffman, and David M Blei. Stochastic gradient descent as approximate bayesian inference. The Journal of Machine Learning Research, 18(1):4873–4907, 2017.
[31] Samuel L Smith and Quoc V Le. A bayesian perspective on generalization and stochastic gradient descent. arXiv preprint arXiv:1710.06451, 2017.
[32] Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. In Advances in Neural Information Processing Systems, pages 1731–1741, 2017.
[33] Ping Luo, Xinjiang Wang, Wenqi Shao, and Zhanglin Peng. Towards understanding regularization in batch normalization. In International Conference on Learning Representations, 2019.
[34] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research, 15(1):1929–1958, 2014.
[35] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual networks github page: https://github.com/kaiminghe/deep-residual-networks, 2016.
[36] Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. Tensorflow: A system for large-scale machine learning. In USENIX symposium on operating systems design and implementation (OSDI), pages 265–283, 2016.
[37] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022, 2016.
[38] Tim Salimans and Durk P Kingma. Weight normalization: A simple reparameterization to accelerate training of deep neural networks. In Advances in Neural Information Processing Systems, pages 901–909, 2016.
[39] Yuxin Wu and Kaiming He. Group normalization. In Proceedings of the European Conference on Computer Vision (ECCV), pages 3–19, 2018.
[40] Andrew M Saxe, James L McClelland, and Surya Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013.
[41] Devansh Arpit, Víctor Campos, and Yoshua Bengio. How to initialize your network? robust initialization for weightnorm & resnets. In Advances in Neural Information Processing Systems, pages 10900–10909, 2019.
[42] Huishuai Zhang, Da Yu, Mingyang Yi, Wei Chen, and Tie-Yan Liu. Convergence theory of learning over-parameterized resnet: A full characterization, 2019.
[43] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization, 2018.
[44] Simon S. Du, Jason D. Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks, 2018.
[45] Sanjeev Arora, Zhiyuan Li, and Kaifeng Lyu. Theoretical analysis of auto rate-tuning by batch normalization. arXiv preprint arXiv:1812.03981, 2018.
[46] Zhiyuan Li and Sanjeev Arora. An exponential learning rate schedule for deep learning. arXiv preprint arXiv:1910.07454, 2019.
[47] Yongqiang Cai, Qianxiao Li, and Zuowei Shen. A quantitative analysis of the effect of batch normalization on gradient descent. In International Conference on Machine Learning, pages 882–890. PMLR, 2019.
[48] Yann A LeCun, Léon Bottou, Genevieve B Orr, and Klaus-Robert Müller. Efficient backprop. In Neural networks: Tricks of the trade, pages 9–48. Springer, 2012.
[49] Malcolm Reynolds, Gabriel Barth-Maron, Frederic Besse, Diego de Las Casas, Andreas Fidjeland, Tim Green, Adrià Puigdomènech, Sébastien Racanière, Jack Rae, and Fabio Viola. Open sourcing Sonnet - a new library for constructing neural networks. https://deepmind.com/blog/open-sourcing-sonnet/, 2017.
A Definition of a batch normalization layer
When applying batch normalization to convolutional layers, the inputs and outputs of normalization layers are 4-dimensional tensors, which we denote by $I_{b,x,y,c}$ and $O_{b,x,y,c}$. Here $b$ denotes the batch dimension, $c$ denotes the channels, and $x$ and $y$ are the two spatial dimensions. Batch normalization [4] applies the same normalization to every input in the same channel, such that
$$O_{b,x,y,c} = \gamma_c \, \frac{I_{b,x,y,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c.$$
Here, $\mu_c = \frac{1}{Z}\sum_{b,x,y} I_{b,x,y,c}$ denotes the per-channel mean, $\sigma_c^2 = \frac{1}{Z}\sum_{b,x,y} \left(I_{b,x,y,c} - \mu_c\right)^2$ denotes the per-channel variance, and $Z$ is the number of terms in each sum, i.e., the product of the minibatch size and the two spatial dimensions $x$ and $y$. A small constant $\epsilon$ is included in the denominator for numerical stability. The "scale" and "shift" parameters, $\gamma_c$ and $\beta_c$, are learned during training. Typically, $\gamma_c$ is initialized to 1 and $\beta_c$ is initialized to 0, which is also what we consider in our analysis. Running averages of $\mu_c$ and $\sigma_c^2$ are also maintained during training, and these averages are used at test time to ensure the predictions are independent of other examples in the batch. For distributed training, the batch statistics are usually estimated locally on a subset of the training minibatch ("ghost batch normalization" [32]).

B Details of the residual networks used for figure 2
In figure 2 of the main text, we studied the variance of hidden activations and the batch statistics of residual blocks at a range of depths in three different architectures: a deep linear fully connected unnormalized residual network, a deep linear fully connected normalized residual network, and a deep convolutional normalized residual network with ReLUs. We now define the three models in full.
Deep fully connected linear residual network without normalization:
The inputs are 100-dimensional vectors composed of independent random samples from the unit normal distribution, and the batch size is 1000. These inputs first pass through a single fully connected linear layer of width 1000. We then apply a series of residual blocks. Each block contains an identity skip path, and a residual branch composed of a fully connected linear layer of width 1000. All linear layers are initialized with LeCun normal initialization [48] to preserve the variance in the absence of non-linearities.
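The variance growth in this unnormalized network can be reproduced with a reduced-size NumPy sketch (width 256 and batch size 512 instead of 1000 and 1000, chosen by us to keep it fast; the function name is hypothetical). With LeCun-normal weights the residual branch output has roughly the same variance as its input and is approximately uncorrelated with it, so each residual block roughly doubles the variance on the skip path.

```python
import numpy as np

def unnormalized_linear_resnet(depth, in_dim=100, width=256, batch=512, seed=0):
    """Skip-path variance at the input of each block of a linear ResNet without BN."""
    rng = np.random.default_rng(seed)

    def lecun(fan_in, fan_out):
        # LeCun normal: weights ~ N(0, 1/fan_in), variance preserving for linear layers.
        return rng.normal(0.0, np.sqrt(1.0 / fan_in), size=(fan_in, fan_out))

    h = rng.standard_normal((batch, in_dim)) @ lecun(in_dim, width)  # first linear layer
    skip_vars = []
    for _ in range(depth):
        skip_vars.append(h.var())
        h = h + h @ lecun(width, width)  # identity skip path + linear residual branch
    return np.array(skip_vars)

skip_vars = unnormalized_linear_resnet(depth=10)
growth_per_block = skip_vars[1:] / skip_vars[:-1]
```

Every entry of `growth_per_block` should be close to 2, so the variance after $\ell$ blocks is of order $2^\ell$.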
Deep fully connected linear residual network with batch normalization:
The inputs are 100-dimensional vectors composed of independent random samples from the unit normal distribution, and the batch size is 1000. These inputs first pass through a batch normalization layer and a single fully connected linear layer of width 1000. We then apply a series of residual blocks. Each block contains an identity skip path, and a residual branch composed of a batch normalization layer and a fully connected linear layer of width 1000. All linear layers are initialized with LeCun normal initialization [48] to preserve the variance in the absence of non-linearities.
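Adding the batch normalization layers of this network changes the behavior substantially. In the reduced-size NumPy sketch below (width 256 and batch size 512, our own choice; γ = 1 and β = 0; hypothetical helper names), the residual branch output has variance close to 1 at every depth, so the skip-path variance at the $\ell$-th block grows only linearly, approximately as $\ell$, and the branch is downscaled relative to the skip path by a factor of order $\sqrt{\ell}$. Because every batch normalization layer removes the batch mean of each feature and all later layers are linear, the squared batch means remain zero up to rounding error.

```python
import numpy as np

def batch_norm(h, eps=1e-5):
    """h has shape (batch, features); normalize each feature over the batch."""
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)

def normalized_linear_resnet(depth, in_dim=100, width=256, batch=512, seed=0):
    """Per-block statistics of a linear ResNet with BN on the input and on each branch."""
    rng = np.random.default_rng(seed)

    def lecun(fan_in, fan_out):
        # LeCun normal: weights ~ N(0, 1/fan_in).
        return rng.normal(0.0, np.sqrt(1.0 / fan_in), size=(fan_in, fan_out))

    h = batch_norm(rng.standard_normal((batch, in_dim))) @ lecun(in_dim, width)
    skip_vars, branch_vars, sq_means = [], [], []
    for _ in range(depth):
        branch = batch_norm(h) @ lecun(width, width)   # BN, then linear layer
        skip_vars.append(h.var())                      # variance on the skip path
        branch_vars.append(branch.var())               # variance at end of branch
        sq_means.append(np.mean(h.mean(axis=0) ** 2))  # mean squared batch mean
        h = h + branch
    return np.array(skip_vars), np.array(branch_vars), np.array(sq_means)

skip_vars, branch_vars, sq_means = normalized_linear_resnet(depth=50)
```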
Deep convolutional ReLU residual network:
The inputs are batches of 100 images from the CIFAR-10 training set. We first apply a convolution of width 100 and stride 2, followed by a batch normalization layer, a ReLU non-linearity, and an additional convolution of width 100 and stride 2. We then apply a series of residual blocks. Each block contains an identity skip path, and a residual branch composed of a batch normalization layer, a ReLU non-linearity, and a convolution of width 100 and stride 1. All convolutions are initialized with He initialization [19].

In all three networks, we evaluate the variance at initialization on the skip path and at the end of the residual branch (we measure the empirical variance across multiple channels and multiple examples, but for a single set of weights). For the two normalized networks, we also evaluate the mean moving variance and mean squared moving mean of the batch normalization layer (i.e., the mean value of the moving variance parameter and the mean value of the square of the moving mean, averaged over channels for a single set of weights). To obtain the batch normalization statistics, we set the momentum parameter of the batch normalization layers to 0, and then update the batch statistics once.
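The batch normalization layer defined in appendix A and the measurement procedure described above can be sketched in NumPy as follows. The function names are hypothetical, and the moving-average convention running = momentum · running + (1 − momentum) · batch is an assumption (it matches Keras-style layers; some libraries define momentum the other way round). With momentum 0 and a single update, the moving statistics equal the current batch statistics, so the mean moving variance and the mean squared moving mean are channel averages of the per-channel batch variance and squared batch mean. The ghost variant mentioned in appendix A normalizes consecutive sub-batches independently.

```python
import numpy as np

def batch_norm_train(inputs, gamma, beta, running_mean, running_var,
                     momentum=0.9, eps=1e-5):
    """Training-mode batch norm for inputs of shape (b, x, y, c)."""
    mu = inputs.mean(axis=(0, 1, 2))   # per-channel mean mu_c
    var = inputs.var(axis=(0, 1, 2))   # per-channel biased variance sigma_c^2
    outputs = gamma * (inputs - mu) / np.sqrt(var + eps) + beta
    # Assumed convention: running = momentum * running + (1 - momentum) * batch.
    new_mean = momentum * running_mean + (1.0 - momentum) * mu
    new_var = momentum * running_var + (1.0 - momentum) * var
    return outputs, new_mean, new_var

def batch_norm_eval(inputs, gamma, beta, running_mean, running_var, eps=1e-5):
    """Test-mode batch norm: each output depends only on its own example."""
    return gamma * (inputs - running_mean) / np.sqrt(running_var + eps) + beta

def ghost_batch_norm(inputs, gamma, beta, ghost_size, eps=1e-5):
    """Normalize consecutive sub-batches of ghost_size examples independently."""
    if inputs.shape[0] % ghost_size != 0:
        raise ValueError("batch size must be a multiple of ghost_size")
    chunks = np.split(inputs, inputs.shape[0] // ghost_size, axis=0)
    outs = [batch_norm_train(ch, gamma, beta, 0.0, 1.0, eps=eps)[0] for ch in chunks]
    return np.concatenate(outs, axis=0)

def bn_statistics_summary(h, momentum=0.0):
    """Channel-averaged moving statistics after one update, plus the overall variance."""
    c = h.shape[-1]
    _, moving_mean, moving_var = batch_norm_train(
        h, np.ones(c), np.zeros(c), np.zeros(c), np.ones(c), momentum=momentum)
    return {
        "variance": float(h.var()),  # across examples, positions and channels
        "mean_moving_variance": float(moving_var.mean()),
        "mean_squared_moving_mean": float((moving_mean ** 2).mean()),
        "mean_moving_mean": float(moving_mean.mean()),
    }

# Toy input: two examples, a 1x1 spatial grid, channel 0 holds {0, 2}, channel 1 holds {2, 4}.
h = np.array([[0.0, 2.0], [2.0, 4.0]]).reshape(2, 1, 1, 2)
stats = bn_statistics_summary(h)

rng = np.random.default_rng(0)
x = 3.0 + 2.0 * rng.standard_normal((64, 8, 8, 4))
ghost_out = ghost_batch_norm(x, np.ones(4), np.zeros(4), ghost_size=32)
```

For the toy input, the variance across all entries (2.0) equals the mean moving variance (1.0) plus the mean squared moving mean (5.0) minus the squared average moving mean (4.0). Whenever the channel means vary across channels, the mean moving variance is therefore smaller than the variance measured across channels, which is the effect studied in appendix C.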
C The influence of ReLU non-linearities on batch normalization statistics
In the main text, we found that for the deep linear normalized residual network (figure 2(b)), the variance on the skip path is equal to the mean moving variance of the batch normalization layer, while the mean squared moving mean of the batch normalization layer is close to zero. However, when we introduce ReLU non-linearities in the deep normalized convolutional residual network (figure 2(c)), the mean moving variance of the batch normalization layer is smaller than the variance across channels on the skip path, and the mean squared moving mean of the normalization layer grows proportionally to the depth. To clarify the origin of this effect, we consider an additional fully connected deep normalized residual network with ReLU non-linearities. We form this network from the fully connected normalized linear residual network in appendix B by inserting a ReLU non-linearity after each normalization layer, and we replace LeCun initialization with He initialization. This network is easier to analyze than the convolutional network, but similar conclusions hold in both cases.

Figure 5: The batch statistics at initialization of a normalized deep fully connected network with ReLU non-linearities, evaluated on random inputs drawn from a Gaussian distribution.

We provide the variance of the hidden activations and the batch statistics of this network in figure 5. The variance on the skip path in this network is approximately equal to the depth of the residual block $d$, while the variance at the end of the residual branch is approximately 1. This matches our theoretical predictions in section 2 of the main text. Notice, however, that the mean moving variance of the batch normalization layer is approximately equal to $d(1 - 1/\pi)$, while the mean squared moving mean of the normalization layer is approximately equal to $d/\pi$.
To understand these observations, we note that the outputs of a ReLU non-linearity have non-zero mean, and therefore the ReLU layer will cause the hidden activations of different examples on the same channel to become correlated (if the weights are fixed). Because of this, the variance across multiple examples and multiple channels becomes different from the variance across multiple examples for a single fixed channel.

To better understand this, we analyze this fully connected normalized ReLU residual network below. The input $X^0 \in \mathbb{R}^{w \times b}$ is a batch of $b$ samples of dimension $w$, sampled from a Gaussian distribution with mean $\mathbb{E}(X^0_{ij}) = 0$ and covariance $\mathrm{Cov}(X^0_{ik}, X^0_{jl}) = \delta_{ij}\delta_{kl}$, where $\delta_{ij}$ is the Kronecker delta. The first dimension corresponds to the features and the second dimension corresponds to the batch. Let $W^0 \in \mathbb{R}^{w \times w}$ denote the linear layer before the first residual block, and for $\ell > 0$, let $W^\ell$ denote the linear layer on the residual branch of the $\ell$-th residual block (we assume that all layers have the same width $w$ for clarity of presentation). For each weight matrix $W^\ell$, we assume that the elements of $W^\ell$ are independently sampled from $\mathcal{N}(0, 2/w)$ (He initialization). Let $X^\ell \in \mathbb{R}^{w \times b}$ denote the input to the $\ell$-th residual block, let $X^+ = \max(X, 0)$ denote the ReLU non-linearity applied component-wise, and let $\mathcal{B}$ denote the batch normalization operation. Thus, the input to the first residual block is given by $X^1 = W^0 \mathcal{B}(X^0)^+$, and the output of the $\ell$-th residual block is given by $X^{\ell+1} = X^\ell + W^\ell \mathcal{B}(X^\ell)^+$ for $\ell > 0$. We want to analyze the batch normalization statistics for each layer. To this end, we begin by considering the input to the first residual block, $X^1$. Note that $X^1_{ij} = \sum_k W^0_{ik}\,\mathcal{B}(X^0)^+_{kj}$.
The mean activation is $\mathbb{E}(X^1_{ij}) = 0$, while the covariance is
$$\mathrm{Cov}(X^1_{ij}, X^1_{lm}) = \mathbb{E}\Big(\sum_{kn} W^0_{ik}\,\mathcal{B}(X^0)^+_{kj}\,W^0_{ln}\,\mathcal{B}(X^0)^+_{nm}\Big) = \frac{2}{w}\,\delta_{il}\sum_k \mathbb{E}\Big(\mathcal{B}(X^0)^+_{kj}\,\mathcal{B}(X^0)^+_{km}\Big) \approx \delta_{il}\Big(\frac{1}{\pi} + \delta_{jm}\,\frac{\pi - 1}{\pi}\Big). \tag{1}$$
Since the components of $X^0$ are independent and Gaussian distributed, we have assumed that the components of $\mathcal{B}(X^0)$ are also independent and Gaussian distributed with mean $\mathbb{E}(\mathcal{B}(X^0)_{ij}) = 0$ and variance $\mathrm{Var}(\mathcal{B}(X^0)_{ij}) = 1$. This approximation is tight when the batch size is large ($b \gg 1$). It implies that $\mathbb{E}\big((\mathcal{B}(X^0)^+_{kj})^2\big) \approx 1/2$ and $\mathbb{E}\big(\mathcal{B}(X^0)^+_{kj}\big) \approx \sqrt{1/(2\pi)}$. Hence the sum over $k$ in equation 1 is approximately $w/2$ when $j = m$ and $w/(2\pi)$ when $j \neq m$, from which we arrive at equation 1.

We now consider the input to the second residual block, $X^2 = X^1 + W^1 \mathcal{B}(X^1)^+$. To considerably simplify the analysis, we assume that the width $w$ is large ($w \gg 1$). This implies that $X^1$ is Gaussian distributed with the covariance derived in equation 1 (see [13] for details). Once again, if the batch size $b$ is also large, then the components of $\mathcal{B}(X^1)$ are independent and Gaussian distributed with mean $\mathbb{E}(\mathcal{B}(X^1)_{ij}) = 0$ and variance $\mathrm{Var}(\mathcal{B}(X^1)_{ij}) = 1$ (note that batch normalization removes the correlations between different examples in the batch in equation 1). This implies
$$\mathrm{Cov}\big((W^1\mathcal{B}(X^1)^+)_{ij},\, (W^1\mathcal{B}(X^1)^+)_{lm}\big) \approx \delta_{il}\Big(\frac{1}{\pi} + \delta_{jm}\,\frac{\pi - 1}{\pi}\Big).$$
Furthermore, note that the covariance between the output of the residual branch and the skip connection,
$\mathrm{Cov}\big((W^1\mathcal{B}(X^1)^+)_{ij},\, X^1_{lm}\big) = 0$, since $W^1$ has zero mean and is independent of $X^1$. We therefore conclude that
$$\mathrm{Cov}(X^2_{ij}, X^2_{lm}) = \mathrm{Cov}(X^1_{ij}, X^1_{lm}) + \mathrm{Cov}\big((W^1\mathcal{B}(X^1)^+)_{ij},\, (W^1\mathcal{B}(X^1)^+)_{lm}\big) \approx 2\,\delta_{il}\Big(\frac{1}{\pi} + \delta_{jm}\,\frac{\pi - 1}{\pi}\Big).$$
By induction, we can now see that the components of $\mathcal{B}(X^\ell)$ are independent and Gaussian distributed for all $\ell$, and $\mathrm{Cov}\big((W^\ell\mathcal{B}(X^\ell)^+)_{ij},\, X^\ell_{lm}\big) = 0$ for all $\ell$. Thus, we get
$$\mathrm{Cov}(X^\ell_{ij}, X^\ell_{lm}) \approx \ell\,\delta_{il}\Big(\frac{1}{\pi} + \delta_{jm}\,\frac{\pi - 1}{\pi}\Big).$$
We are now ready to compute the expected values of the batch statistics, which we denote by $\mu^\ell$ and $\sigma^\ell$ (see appendix A). The expected mean squared activation for a batch of examples on a single channel (the expected squared BatchNorm moving mean) is
$$\mathbb{E}\big((\mu^\ell_c)^2\big) = \mathbb{E}\bigg(\Big(\frac{1}{b}\sum_j X^\ell_{cj}\Big)^2\bigg) = \frac{1}{b^2}\sum_{jk}\mathbb{E}\big(X^\ell_{cj}X^\ell_{ck}\big) \approx \ell\Big(\frac{1}{\pi} + \frac{\pi - 1}{\pi b}\Big) \approx \ell/\pi.$$
Meanwhile, the expected variance across a batch of examples on a single channel (the expected BatchNorm moving variance) is
$$\mathbb{E}\big((\sigma^\ell_c)^2\big) = \mathbb{E}\bigg(\frac{1}{b}\sum_j (X^\ell_{cj})^2 - \Big(\frac{1}{b}\sum_j X^\ell_{cj}\Big)^2\bigg) = \mathbb{E}\Big(\frac{1}{b}\sum_j X^\ell_{cj}X^\ell_{cj}\Big) - \mathbb{E}\big((\mu^\ell_c)^2\big) \approx \ell\,(1 - 1/\pi).$$
These predictions match our observations in figure 5. Our analysis shows how ReLU non-linearities introduce correlations in the hidden activations between training examples (for shared random weights). These correlations cause the moving variance of the batch normalization layer, which is evaluated on a single channel for a single set of weights, to differ from the variance of the hidden activations over multiple random initializations (which we derived in section 2.1).
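The predictions above can be checked with a short Monte Carlo sketch of this fully connected network in NumPy. It follows the notation of this appendix (features × batch matrices, He-normal weights with variance $2/w$, batch normalization with γ = 1 and β = 0 applied to each feature over the batch), but the width, batch size and depth are our own choices and the helper names are hypothetical. For each block input $X^\ell$ it reports the variance across all entries, the channel-averaged batch variance and the channel-averaged squared batch mean, which should be close to $\ell$, $\ell(1 - 1/\pi)$ and $\ell/\pi$ respectively.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """x has shape (features, batch); normalize each feature over the batch."""
    mu = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def relu_resnet_batch_stats(width=512, batch=512, depth=20, seed=0):
    """Batch statistics at the input X^l of each residual block (l = 1..depth)."""
    rng = np.random.default_rng(seed)

    def he():
        # He initialization: entries ~ N(0, 2/w).
        return rng.normal(0.0, np.sqrt(2.0 / width), size=(width, width))

    x = rng.standard_normal((width, batch))     # X^0 with i.i.d. N(0, 1) entries
    x = he() @ np.maximum(batch_norm(x), 0.0)   # X^1 = W^0 B(X^0)^+
    stats = []
    for ell in range(1, depth + 1):
        stats.append((ell,
                      x.var(),                        # variance on the skip path
                      np.mean(x.var(axis=1)),         # mean moving variance
                      np.mean(x.mean(axis=1) ** 2)))  # mean squared moving mean
        x = x + he() @ np.maximum(batch_norm(x), 0.0)  # X^{l+1} = X^l + W^l B(X^l)^+
    return stats

for ell, skip_var, moving_var, sq_mean in relu_resnet_batch_stats()[4::5]:
    print(f"l={ell:2d}  skip var {skip_var:6.2f} (theory {ell:5.2f})  "
          f"moving var {moving_var:6.2f} (theory {ell * (1 - 1 / np.pi):5.2f})  "
          f"sq. moving mean {sq_mean:5.2f} (theory {ell / np.pi:5.2f})")
```

The agreement should improve as the width and batch size grow, since the derivation assumes $w \gg 1$ and $b \gg 1$.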
D Additional results on CIFAR-10
D.1 Optimal training losses corresponding to table 1
In table 3, we provide the minimum training losses, as well as the optimal learning rates at which the training loss is minimized, when training an n-2 Wide-ResNet for a range of depths n on CIFAR-10. At each depth, we train for 200 epochs following the training procedure described in section 2.3 of the main text. These results correspond to the same architectures considered in table 1, where we provided the associated test set accuracies. We provide the training loss excluding the L2 regularization term (i.e., the training set cross entropy), since one cannot meaningfully compare the L2 regularization penalty of normalized and unnormalized networks. These results confirm that batch normalization and SkipInit achieve similar training losses after the same number of training epochs.

Table 3: The training losses, and associated optimal learning rates, of an n-2 Wide-ResNet at a range of depths n. We train on CIFAR-10 for 200 epochs with either batch normalization or SkipInit.
[Table 3 body: sub-tables for Batch Normalization, SkipInit (α = 1/√d), SkipInit (α = 0), SkipInit (α = 1), Divide residual block by √2, and SkipInit without L2 (α = 0), each reporting Depth, Training loss and Learning rate at three depths. SkipInit (α = 1) has entries only at the smallest depth, and Divide residual block by √2 has no entry at the largest depth.]

Table 4: The optimal test accuracies, and associated learning rates, of n-2 Wide-ResNets at a range of depths n. We train on CIFAR-10 for 200 epochs with different batch-normalized network variants.
[Table 4 body: sub-tables for Divide batch-normalized residual block by √2, Adding BatchNorm at end of residual block, and Including only the final BatchNorm layer, each reporting Depth, Test accuracy and Learning rate at three depths. None of the three variants has an entry at the largest depth.]

D.2 Variations of batch-normalized residual networks
To further test the theory that the key benefit of batch normalization in deep residual networks is that it downscales the residual branch at initialization, we now experiment on different variants of batch-normalized residual networks. Using the same training setup as in section 2.3, we show in table 4 that none of the following schemes are able to train a 1000-layer deep Wide-ResNet on CIFAR-10:
• Including batch normalization layers on the residual branch as in a standard Wide-ResNet, but multiplying the residual block by 1/√2 (after the skip path and residual branch merge).
• Including batch normalization layers on the residual branch, as well as including an additional batch normalization layer on the skip path.
• First removing all batch normalization layers, and then placing a single batch normalization layer before the final softmax layer.

Figure 6: The optimal learning rates of SkipInit, Regularized SkipInit and Batch Normalization, for a 16-4 Wide-ResNet trained for 200 epochs on CIFAR-10. We evaluate the batch statistics over the full training minibatch. All three methods have similar optimal learning rates in the small batch limit. Panels: (a) optimal learning rates for test accuracy; (b) optimal learning rates for training loss.

In all of these experiments, we expect the skip path and the residual branch to contribute equally to the output of the residual block at initialization. Therefore, as predicted by our theory and confirmed by our experiments in table 4, the network becomes harder to train as the depth increases.
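The expectation that the two paths contribute equally can be made concrete with idealized variance bookkeeping in the spirit of section 2. The short Python sketch below assumes that every batch normalization layer outputs unit variance, that every linear layer preserves variance, that the two paths are uncorrelated, and it ignores the mean shift introduced by ReLUs; the variant names are our own labels. It tracks the fraction of each block's output variance that comes from the residual branch: in a standard normalized block this fraction is $1/(\ell + 1)$ at the $\ell$-th block, whereas in all three variants listed above it stays at 1/2 at every depth.

```python
def residual_variance_fractions(depth, variant):
    """Fraction of each block's output variance contributed by the residual branch.

    Idealized bookkeeping: BN outputs have unit variance, linear layers preserve
    variance, the two paths are uncorrelated, and ReLU mean shifts are ignored.
    """
    skip_var = 1.0  # variance entering the first residual block
    fractions = []
    for _ in range(depth):
        # With BN on the skip path, the skip contribution is renormalized to 1.
        skip = 1.0 if variant == "bn_on_skip_path" else skip_var
        # Without BN on the branch, the branch preserves its input variance.
        branch = skip_var if variant == "final_bn_only" else 1.0
        fractions.append(branch / (skip + branch))
        merged = skip + branch
        # Multiplying the merged output by 1/sqrt(2) halves its variance.
        skip_var = merged / 2.0 if variant == "divide_by_sqrt2" else merged
    return fractions

variants = ["standard", "divide_by_sqrt2", "bn_on_skip_path", "final_bn_only"]
fractions = {v: residual_variance_fractions(1000, v) for v in variants}
```

At the 1000th block the standard network's residual branch contributes only 1/1001 of the variance, while each variant keeps the branch on an equal footing with the skip path.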
D.3 Optimal learning rates corresponding to figure 4
Finally, in figure 6 we provide the optimal learning rates of SkipInit, Regularized SkipInit and batch normalization, when training a 16-4 Wide-ResNet on CIFAR-10. These optimal learning rates correspond to the training losses and test accuracies provided in figure 4 of the main text. The batch statistics for batch normalization layers are evaluated over the full training minibatch.
E Additional results on CIFAR-100
In tables 5 and 6, we provide the optimal test accuracies and optimal training losses, and the corresponding optimal learning rates, when training n-2 Wide-ResNets on CIFAR-100 for different depths n for 200 epochs. We follow the training protocol described in section 2.3 of the main text. Both batch normalization and SkipInit are able to train very deep Wide-ResNets on CIFAR-100.

In figure 7, we compare the performance of SkipInit, Regularized SkipInit (drop probability 0.6), and batch normalization across a wide range of batch sizes, when training a 28-10 Wide-ResNet on CIFAR-100 for 200 epochs. We follow the training protocol described in section 3 of the main text, but we use a ghost batch size of 32. We were not able to train the 28-10 Wide-ResNet to competitive performance when using neither batch normalization nor SkipInit. Batch normalized networks achieve higher test accuracies at all batch sizes. However, in the small batch limit, the optimal learning rate is proportional to the batch size, and the optimal learning rates of all three methods are approximately equal. As we observed in the main text, batch normalization has a larger maximum stable learning rate, and this allows us to scale training to larger batch sizes.

Finally, in figure 8, we repeat this comparison of SkipInit, Regularized SkipInit and batch normalization at a range of batch sizes, but instead of selecting a fixed ghost batch size, we evaluate the batch statistics of batch normalization layers across the full minibatch (as in section 4). We observe a clear regularization effect, whereby the test accuracy achieved with batch normalization peaks for a batch size of 16 and decays rapidly if the batch size is increased or decreased. Regularized SkipInit achieves higher test accuracies than normalized networks when the batch size is small, and it is also competitive with batch normalized networks when the batch size is moderately large.
These results emphasize the importance of tuning the ghost batch size in batch normalized networks.

Table 5: The optimal test accuracies and corresponding learning rates (with error bars), when training width 2 Wide-ResNets on CIFAR-100 for a wide range of depths. Both batch normalization and SkipInit are able to train very deep residual networks. However, it is not possible to train depth 1000 networks if we do not downscale the hidden activations on the residual branch at initialization.
[Table 5 body (numerical values not recoverable): test accuracy (mean ± error bar) and learning rate (optimum, with range) at three depths, including 100 and 1000, for Batch Normalization; SkipInit (α = 1/√d); SkipInit (α = 0); SkipInit (α = 1); Divide residual block by √2; and SkipInit without L2 (α = 0). The entries are "-" for SkipInit (α = 1) at depths 100 and 1000, and for Divide residual block by √2 at depth 1000.]

Table 6: The optimal training losses and corresponding learning rates (with error bars), when training width 2 Wide-ResNets on CIFAR-100 for a wide range of depths. Both batch normalization and SkipInit are able to train very deep residual networks. We show that it is not possible to train depth 1000 networks if we do not downscale the hidden activations on the residual branch at initialization.
[Table 6 body (numerical values not recoverable): training loss (mean ± error bar) and learning rate (optimum, with range) for the same six methods and depths as table 5, with "-" entries in the same positions.]

Figure 7: The optimal test accuracy, and the corresponding optimal learning rates, of a 28-10 Wide-ResNet trained on CIFAR-100 for 200 epochs. We were unable to train this network reliably without batch normalization or SkipInit (not shown). Batch normalized networks achieve higher test accuracies and are also stable at larger learning rates, which enables large batch training.
Figure 8: The optimal test accuracy, and the corresponding optimal learning rates, of a 28-10 Wide-ResNet trained on CIFAR-100 for 200 epochs. We do not use ghost batch normalization here. Instead, we evaluate the batch statistics over the full minibatch. The test accuracy achieved with batch normalization depends strongly on the batch size and is maximized at a batch size of 16. Regularized SkipInit achieves higher test accuracies than batch normalized networks when the batch size is very small, and it is competitive with batch normalized networks when the batch size is moderately large.
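The rows of tables 5 and 6 differ in how strongly the residual branch is scaled relative to the skip connection at initialization. The sketch below is a back-of-the-envelope calculation of why this matters at depth. Its simplifying assumptions are ours and are not part of the experiments: the input x_0 has unit variance, and each residual branch f preserves variance and is uncorrelated with its input. For x_{l+1} = x_l + α f(x_l), this gives Var(x_{l+1}) = (1 + α²) Var(x_l). The helper names are hypothetical.

```python
import math


def output_variance(depth: int, alpha: float, divide_by_sqrt2: bool = False) -> float:
    """Var(x_depth) for x_{l+1} = x_l + alpha * f(x_l), optionally divided by sqrt(2).

    Assumes Var(x_0) = 1 and that every residual branch f preserves variance
    and is uncorrelated with its input.
    """
    per_block = 1.0 + alpha ** 2
    if divide_by_sqrt2:
        per_block /= 2.0
    return per_block ** depth


def skip_path_fraction(depth: int, alpha: float, divide_by_sqrt2: bool = False) -> float:
    """Fraction of Var(x_depth) contributed by x_0 through the skip connections alone."""
    skip_gain_sq = 0.5 if divide_by_sqrt2 else 1.0  # squared skip-path weight per block
    return skip_gain_sq ** depth / output_variance(depth, alpha, divide_by_sqrt2)


depth = 1000
for label, alpha, div in [("alpha = 1", 1.0, False),
                          ("alpha = 1/sqrt(d)", 1.0 / math.sqrt(depth), False),
                          ("alpha = 0", 0.0, False),
                          ("alpha = 1, divide by sqrt(2)", 1.0, True)]:
    print(f"{label:30s} Var(x_d) = {output_variance(depth, alpha, div):.3g}  "
          f"skip fraction = {skip_path_fraction(depth, alpha, div):.3g}")
```

Under these assumptions, the scalings behave as follows:

- α = 1 makes the activation variance grow as 2^d.
- Dividing each block by √2 keeps the variance fixed, but leaves almost none of the signal on the identity path.
- α = 1/√d (or α = 0) keeps the variance at most e and the skip-path fraction at least 1/e, so each block stays close to the identity.

The last case matches the property we argue batch normalization provides at initialization.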
F Additional results on ImageNet
In table 7, we present the performance of batch normalization, Fixup and Regularized SkipInit when training ResNet-50-V1 [2] on ImageNet for 90 epochs. Unlike ResNet-V2 and Wide-ResNets, this network does not have an identity skip path, because it introduces a ReLU at the end of the residual block, after the skip connection and residual branch merge. We find that Fixup performs slightly worse than batch normalization when the batch size is small, but considerably worse than batch normalization when the batch size is large (similar to the results on ResNet-50-V2). However, Regularized SkipInit is significantly worse than both batch normalization and Fixup at all batch sizes. This is not surprising, since we designed SkipInit for models which contain an identity skip connection through the residual block. We also consider a modified version of Regularized SkipInit, which contains a single additional scalar bias in each residual block, just before the final ReLU (after the skip connection and residual branch merge). This scalar bias eliminates the gap in validation accuracy between Fixup and Regularized SkipInit when the batch size is small. We conclude that only two components of Fixup are essential to train the original ResNet-V1: initializing the residual branch at zero, and introducing a scalar bias after the skip connection and residual branch merge.

Table 7: We train ResNet-50-V1 on ImageNet for 90 epochs. Fixup performs well when the batch size is small, but performs poorly when the batch size is large. Regularized SkipInit performs poorly at all batch sizes, but its performance improves considerably if we add a scalar bias before the final ReLU in each residual block (after the skip connection and residual branch merge). We perform a grid search to identify the optimal learning rate which maximizes the top-1 validation accuracy. We perform a single run at each learning rate and report both top-1 and top-5 accuracy scores. We use a drop probability of … for Regularized SkipInit. We note that ResNet-V1 does not have an identity skip connection, which explains why Regularized SkipInit performs poorly without scalar biases.

Test accuracy (top-1 / top-5)
Method                               Batch size 256   Batch size 1024   Batch size 4096
Batch normalization                  75.6 / 92.5      75.3 / 92.4       75.4 / 92.4
Fixup                                74.4 / 91.6      74.4 / 91.7       72.4 / 90.3
Regularized SkipInit                 70.0 / 89.2      68.4 / 87.8       68.2 / 87.9
Regularized SkipInit + Scalar Bias   75.2 / 92.4      74.9 / 92.0       70.8 / 89.6
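Scalar toy blocks make the difference between the two skip-path designs discussed above concrete. This sketch is illustrative only:

- `v1_block` and `v2_block` are hypothetical helpers.
- The residual branch is an arbitrary scalar function standing in for the convolutions.
- α is the SkipInit multiplier, which is initialized at zero.

```python
from typing import Callable, Optional


def relu(x: float) -> float:
    return max(0.0, x)


def v2_block(x: float, branch: Callable[[float], float], alpha: float) -> float:
    """ResNet-V2 style: identity skip path, nothing applied after the merge."""
    return x + alpha * branch(x)


def v1_block(x: float, branch: Callable[[float], float], alpha: float,
             scalar_bias: Optional[float] = None) -> float:
    """ResNet-V1 style: a ReLU follows the merge, so the skip path is not an identity.

    `scalar_bias` is the optional extra scalar added after the merge and
    before the final ReLU.
    """
    merged = x + alpha * branch(x)
    if scalar_bias is not None:
        merged += scalar_bias
    return relu(merged)


branch = lambda x: 3.0 * x - 1.0  # hypothetical stand-in for the residual branch
print(v2_block(-2.0, branch, alpha=0.0))                   # identity at initialization
print(v1_block(-2.0, branch, alpha=0.0))                   # final ReLU clips the skip path
print(v1_block(1.5, branch, alpha=0.0, scalar_bias=0.25))  # bias shifts the merged value
```

Even with α = 0, the V1 block is not an identity map on all inputs, because the final ReLU follows the merge. The optional scalar bias shifts the merged value before that ReLU, which is where the extra scalar bias sits in the modified Regularized SkipInit.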
G TensorFlow code for ResNet50-V2 with SkipInit
In this section, we provide reference code for our ResNet50-V2 model with Regularized SkipInit, using Sonnet [49] in TensorFlow [36].

# Reference code for TensorFlow 1.x (uses tf.contrib) and Sonnet 1.x (snt.AbstractModule).
import collections

import sonnet as snt
import tensorflow as tf

ResNetBlockParams = collections.namedtuple(
    "ResNetBlockParams", ["output_channels", "bottleneck_channels", "stride"])

# ResNet-50 layout: the stride-2 block is the last block of each of the first
# three groups.
BLOCKS_50 = (
    (ResNetBlockParams(256, 64, 1),) * 2 + (ResNetBlockParams(256, 64, 2),),
    (ResNetBlockParams(512, 128, 1),) * 3 + (ResNetBlockParams(512, 128, 2),),
    (ResNetBlockParams(1024, 256, 1),) * 5 + (ResNetBlockParams(1024, 256, 2),),
    (ResNetBlockParams(2048, 512, 1),) * 3)


def fixed_padding(inputs, kernel_size, rate=1):
  """Pads the input along spatial dimensions independently of input size."""
  kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
  pad_total = kernel_size_effective - 1
  pad_begin = pad_total // 2
  pad_end = pad_total - pad_begin
  padded_inputs = tf.pad(inputs, [[0, 0], [pad_begin, pad_end],
                                  [pad_begin, pad_end], [0, 0]])
  return padded_inputs


def _max_pool2d_fixed_padding(inputs, kernel_size, stride, padding,
                              scope=None, **kwargs):
  """Strided 2-D max-pooling with fixed padding (independent of input size)."""
  if padding == "SAME" and stride > 1:
    # Pad explicitly, then pool without further padding.
    padding = "VALID"
    inputs = fixed_padding(inputs, kernel_size)
  return tf.contrib.layers.max_pool2d(inputs, kernel_size, stride=stride,
                                      padding=padding, scope=scope, **kwargs)


def _conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1,
                 use_bias=None, initializers=None, regularizers=None,
                 partitioners=None):
  """Strided 2-D convolution with 'SAME' padding."""
  if stride == 1:
    padding = "SAME"
  else:
    # Pad explicitly, then convolve without further padding.
    padding = "VALID"
    inputs = fixed_padding(inputs, kernel_size, rate)
  return snt.Conv2D(num_outputs, kernel_size, stride=stride, rate=rate,
                    padding=padding, use_bias=use_bias,
                    initializers=initializers, regularizers=regularizers,
                    partitioners=partitioners)(inputs)


class ResNetBlock(snt.AbstractModule):
  """The ResNet subblock, see https://arxiv.org/abs/1512.03385 for details."""

  def __init__(self, output_channels, bottleneck_channels, stride, rate=1,
               initializers=None, regularizers=None, partitioners=None,
               name="resnet_block"):
    """Create a ResNetBlock object for use with the ResNet modules."""
    super(ResNetBlock, self).__init__(name=name)
    self._output_channels = output_channels
    self._bottleneck_channels = bottleneck_channels
    self._stride = stride
    self._rate = rate
    self._initializers = initializers
    self._regularizers = regularizers
    self._partitioners = partitioners

  def _build(self, inputs, is_training=True, test_local_stats=True):
    """Connects the ResNetBlock module into the graph."""
    # is_training and test_local_stats are unused: there are no normalization layers.
    num_input_channels = inputs.get_shape()[-1]
    with tf.variable_scope("preact"):
      preact = inputs
      preact = tf.nn.relu(preact)
    if self._output_channels == num_input_channels:
      if self._stride == 1:
        shortcut = inputs
      else:
        shortcut = _max_pool2d_fixed_padding(
            inputs, 1, stride=self._stride, padding="SAME")
    else:
      with tf.variable_scope("shortcut"):
        shortcut = preact
        shortcut = snt.Conv2D(
            self._output_channels, [1, 1], stride=self._stride, use_bias=True,
            initializers=self._initializers, regularizers=self._regularizers,
            partitioners=self._partitioners)(shortcut)
    with tf.variable_scope("r1"):
      residual = snt.Conv2D(
          self._bottleneck_channels, [1, 1], stride=1, use_bias=True,
          initializers=self._initializers, regularizers=self._regularizers,
          partitioners=self._partitioners)(preact)
      residual = tf.nn.relu(residual)
    with tf.variable_scope("r2"):
      residual = _conv2d_same(
          residual, self._bottleneck_channels, 3, self._stride,
          rate=self._rate, use_bias=True, initializers=self._initializers,
          regularizers=self._regularizers, partitioners=self._partitioners)
      residual = tf.nn.relu(residual)
    with tf.variable_scope("r3"):
      residual = snt.Conv2D(
          self._output_channels, [1, 1], stride=1, use_bias=True,
          initializers=self._initializers, regularizers=self._regularizers,
          partitioners=self._partitioners)(residual)
    # SkipInit: a learnable scalar multiplier on the residual branch,
    # initialized to zero, so the block outputs its shortcut at initialization.
    res_multiplier = tf.Variable(0.0, dtype=tf.float32)
    residual = res_multiplier * residual
    output = shortcut + residual
    return output


def _build_resnet_blocks(inputs, blocks, initializers=None, regularizers=None,
                         partitioners=None):
  """Connects the resnet block into the graph."""
  outputs = []
  for num, subblocks in enumerate(blocks):
    with tf.variable_scope("block_{}".format(num)):
      for i, block in enumerate(subblocks):
        args = {
            "name": "resnet_block_{}".format(i),
            "initializers": initializers,
            "regularizers": regularizers,
            "partitioners": partitioners,
        }
        args.update(block._asdict())
        inputs = ResNetBlock(**args)(inputs)
      outputs += [inputs]
  return outputs


class ResNetV2(snt.AbstractModule):
  """ResNet V2 as described in https://arxiv.org/abs/1512.03385."""

  def __init__(self, blocks=BLOCKS_50, num_classes=1000, use_global_pool=True,
               initializers=None, regularizers=None, partitioners=None,
               custom_getter=None, name="resnet_v2"):
    """Creates ResNetV2 Sonnet module."""
    super(ResNetV2, self).__init__(custom_getter=custom_getter, name=name)
    self.blocks = tuple(blocks)
    self._num_classes = num_classes
    self._use_global_pool = use_global_pool
    self._initializers = initializers
    self._regularizers = regularizers
    self._partitioners = partitioners

  def _build(self, inputs, is_training=True,
             get_intermediate_activations=False):
    """Connects the ResNetV2 module into the graph."""
    outputs = []
    with tf.variable_scope("root"):
      inputs = _conv2d_same(
          inputs, 64, 7, stride=2, use_bias=True,
          initializers=self._initializers, regularizers=self._regularizers,
          partitioners=self._partitioners)
      inputs = _max_pool2d_fixed_padding(inputs, 3, stride=2, padding="SAME")
      outputs += [inputs]
    resnet_outputs = _build_resnet_blocks(
        inputs, self.blocks, initializers=self._initializers,
        regularizers=self._regularizers, partitioners=self._partitioners)
    outputs += resnet_outputs
    inputs = resnet_outputs[-1]
    with tf.variable_scope("postnorm"):
      inputs = tf.nn.relu(inputs)
      outputs += [inputs]
    if self._use_global_pool:
      inputs = tf.reduce_mean(inputs, [1, 2], name="use_global_pool",
                              keepdims=True)
      outputs += [inputs]
    inputs = tf.contrib.layers.flatten(inputs)
    # Regularized SkipInit: dropout (keep probability 0.8) before the final layer.
    inputs = tf.contrib.layers.dropout(inputs, is_training=is_training,
                                       keep_prob=0.8)
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    inputs = tf.layers.dense(inputs, self._num_classes,
                             kernel_regularizer=self._regularizers["w"],
                             kernel_initializer=kernel_initializer,
                             name="logits")
    outputs += [inputs]
    return outputs if get_intermediate_activations else outputs[-1]
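As a sanity check on the padding helpers and the BLOCKS_50 layout above, the sketch below replays only their spatial-size arithmetic in plain Python. It assumes a 224 × 224 input (the usual ImageNet crop, which is an assumption here). The helper names and the STRIDES_50 tuple are hypothetical and do not appear in the reference code; STRIDES_50 is transcribed from the strides in BLOCKS_50.

```python
from typing import List, Tuple

# Strides of each block, transcribed from BLOCKS_50 (hypothetical helper constant).
STRIDES_50 = ((1, 1, 2), (1, 1, 1, 2), (1, 1, 1, 1, 1, 2), (1, 1, 1))


def fixed_padding_amounts(kernel_size: int, rate: int = 1) -> Tuple[int, int]:
    """Return (pad_begin, pad_end) exactly as computed in fixed_padding."""
    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = kernel_size_effective - 1
    pad_begin = pad_total // 2
    return pad_begin, pad_total - pad_begin


def strided_output_size(size: int, kernel_size: int, stride: int, rate: int = 1) -> int:
    """Spatial output size of _conv2d_same / _max_pool2d_fixed_padding."""
    if stride == 1:
        return size  # 'SAME' padding with stride 1 preserves the spatial size.
    pad_begin, pad_end = fixed_padding_amounts(kernel_size, rate)
    effective = kernel_size + (kernel_size - 1) * (rate - 1)
    return (size + pad_begin + pad_end - effective) // stride + 1  # 'VALID' op


def trace_spatial_sizes(image_size: int = 224) -> List[int]:
    """Sizes after the root conv, the root max-pool, and each block group."""
    size = strided_output_size(image_size, 7, 2)  # root 7x7 conv, stride 2
    sizes = [size]
    size = strided_output_size(size, 3, 2)  # root 3x3 max-pool, stride 2
    sizes.append(size)
    for group in STRIDES_50:
        for stride in group:
            # Only the 3x3 conv (r2) and the shortcut are strided in a block.
            size = strided_output_size(size, 3, stride)
        sizes.append(size)
    return sizes


print(trace_spatial_sizes())
```

Under this assumption, the root stage reduces 224 to 56. The stride-2 block at the end of each of the first three groups then halves the resolution again, giving a 7 × 7 feature map before global pooling.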