An Exponential Learning Rate Schedule for Deep Learning
Zhiyuan Li
Princeton University
Sanjeev Arora
Princeton University and Institute for Advanced Study

ABSTRACT
Intriguing empirical evidence exists that deep learning can work well with exotic schedules for varying the learning rate. This paper suggests that the phenomenon may be due to Batch Normalization or BN (Ioffe & Szegedy, 2015), which is ubiquitous and provides benefits in optimization and generalization across all standard architectures. The following new results are shown about BN with weight decay and momentum (in other words, the typical use case, which was not considered in earlier theoretical analyses of stand-alone BN (Ioffe & Szegedy, 2015; Santurkar et al., 2018; Arora et al., 2018)):
• Training can be done using SGD with momentum and an exponentially increasing learning rate schedule, i.e., the learning rate increases by some $(1 + \alpha)$ factor in every epoch for some $\alpha > 0$. (Precise statement in the paper.) To the best of our knowledge this is the first time such a rate schedule has been successfully used, let alone for highly successful architectures. As expected, such training rapidly blows up network weights, but the network stays well-behaved due to normalization.
• Mathematical explanation of the success of the above rate schedule: a rigorous proof that it is equivalent to the standard setting of BN + SGD + Standard Rate Tuning + Weight Decay + Momentum. This equivalence holds for other normalization layers as well, e.g. Group Normalization (Wu & He, 2018), Layer Normalization (Ba et al., 2016) and Instance Norm (Ulyanov et al., 2016).
• A worked-out toy example illustrating the above linkage of hyper-parameters. Using either weight decay or BN alone reaches a global minimum, but convergence fails when both are used.
1 INTRODUCTION
Batch Normalization (BN) offers significant benefits in optimization and generalization across architectures, and has become ubiquitous. The best performance is usually attained by using weight decay and momentum in addition to BN.

Weight decay is usually thought to improve generalization by controlling the norm of the parameters. However, it is fallacious to try to think of optimization and generalization separately, because we are dealing with a nonconvex objective with multiple optima. Even slight changes to the training surely lead to a different trajectory in the loss landscape, potentially ending up at a different solution! One needs trajectory analysis to have a hope of reasoning about the effects of such changes.

In the presence of BN and other normalization schemes, including GroupNorm, LayerNorm and InstanceNorm, the optimization objective is scale invariant to the parameters: rescaling the parameters does not change the prediction. The exception is the parameters of the output layer, which is not followed by BN. However, Hoffer et al. (2018b) show that fixing the output layer randomly does not harm the performance of the network, so the trainable parameters satisfy scale invariance (see more in Appendix C). The current paper introduces new modes of analysis for such settings. This rigorous analysis yields the surprising conclusion that the original learning rate (LR) schedule and weight decay (WD) can be folded into a new exponential schedule for the learning rate: in each iteration the LR is multiplied by $(1 + \alpha)$ for some $\alpha > 0$ that depends upon the momentum and weight decay rate.

Theorem 1.1 (Main, Informal).
SGD on a scale-invariant objective with initial learning rate $\eta$, weight decay factor $\lambda$, and momentum factor $\gamma$ is equivalent to SGD with momentum factor $\gamma$ where, at iteration $t$, the learning rate $\tilde\eta_t$ in the new exponential learning rate schedule is defined as $\tilde\eta_t = \alpha^{-2t-1}\eta$, without weight decay ($\tilde\lambda = 0$), where $\alpha$ is a non-zero root of the equation
$$x^2 - (1 + \gamma - \lambda\eta)\,x + \gamma = 0. \qquad (1)$$
Specifically, when the momentum $\gamma = 0$, the above schedule simplifies to $\tilde\eta_t = (1 - \lambda\eta)^{-2t-1}\eta$.

The above theorem requires that the product of the learning rate and the weight decay factor, $\lambda\eta$, is smaller than $(1 - \sqrt{\gamma})^2$, which is almost always satisfied in practice. The rigorous and most general version of the above theorem is Theorem 2.12, which deals with multi-phase LR schedules, momentum and weight decay.

There are other recently discovered exotic LR schedules, e.g. the Triangular LR schedule (Smith, 2017) and the Cosine LR schedule (Loshchilov & Hutter, 2016), and our exponential LR schedule is an extreme example of the LR schedules that become possible in the presence of BN. Such an exponential increase in learning rate seems absurd at first sight, and to the best of our knowledge no deep learning success has been reported using such an idea before. It does highlight the above-mentioned viewpoint that in deep learning, optimization and regularization are not easily separated. Of course, the exponent trumps the effect of the initial LR very fast (see Figure 3), which explains why training with BN and WD is not sensitive to the scale of initialization: with BN, tuning the scale of initialization is equivalent to tuning the initial LR $\eta$ while fixing the product of LR and WD, $\eta\lambda$ (see Lemma 2.7).

Note that it is customary in BN to switch to a lower LR upon reaching a plateau in the validation loss.
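The informal theorem can be sanity-checked numerically. The sketch below is plain Python and rests on assumptions of our own: the toy loss $L_t(\theta) = -\langle\theta, u_t\rangle/\|\theta\|$ with random vectors $u_t$ is a hypothetical stand-in for a normalized network and its mini-batches, and we take the larger root of Equation (1), the same choice as the $\alpha^*_I$ formula in Theorem 2.12. It runs momentum SGD with WD in the form of Equation (4) next to momentum SGD without WD under $\tilde\eta_t = \alpha^{-2t-1}\eta$ (started from $\tilde\eta_{-1} = \alpha\eta$ and $\tilde\theta_{-1} = \alpha\theta_{-1}$, as in Theorem 2.9), and reports the largest relative gap between $\tilde\theta_t$ and $\alpha^{-t}\theta_t$.

```python
import math
import random


def alpha_root(gamma: float, lam: float, eta: float) -> float:
    """Larger root of x^2 - (1 + gamma - lam*eta) x + gamma = 0 (Equation 1)."""
    b = 1.0 + gamma - lam * eta
    disc = b * b - 4.0 * gamma
    if disc < 0:
        raise ValueError("complex roots: need lam*eta <= (1 - sqrt(gamma))**2")
    return (b + math.sqrt(disc)) / 2.0


def grad(theta, u):
    """Gradient of the hypothetical scale-invariant loss -<theta, u> / ||theta||."""
    norm = math.sqrt(sum(x * x for x in theta))
    dot = sum(x * y for x, y in zip(theta, u))
    return [-y / norm + dot * x / norm ** 3 for x, y in zip(theta, u)]


def compare(steps=100, gamma=0.9, lam=0.01, eta=0.1, dim=5, seed=0):
    """Run both processes on the same 'batches' and return (alpha, max relative gap)."""
    rng = random.Random(seed)
    batches = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(steps)]
    theta0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    a = alpha_root(gamma, lam, eta)

    # Fixed LR + WD + momentum (Equation 4); v_0 = 0, so theta_{-1} = theta_0.
    prev, cur = list(theta0), list(theta0)
    # Exponential LR + momentum, no WD, with tilde theta_{-1} = alpha * theta_{-1}.
    tprev, tcur = [a * x for x in theta0], list(theta0)
    lr_prev = eta * a  # tilde eta_{-1} = eta * alpha

    worst = 0.0
    for t in range(1, steps + 1):
        g = grad(cur, batches[t - 1])
        new = [c + gamma * (c - p) - eta * (gi + lam * c)
               for c, p, gi in zip(cur, prev, g)]
        lr = eta * a ** (-2 * (t - 1) - 1)  # tilde eta_{t-1}
        tg = grad(tcur, batches[t - 1])
        tnew = [c + lr * (gamma * (c - p) / lr_prev - gi)
                for c, p, gi in zip(tcur, tprev, tg)]
        prev, cur = cur, new
        tprev, tcur, lr_prev = tcur, tnew, lr
        # Theorem 2.9 predicts tilde theta_t = alpha^(-t) * theta_t.
        scale = a ** (-t)
        gap = max(abs(x - scale * y) for x, y in zip(tcur, cur))
        worst = max(worst, gap / max(abs(x) for x in tcur))
    return a, worst


if __name__ == "__main__":
    a, worst = compare()
    print(f"alpha = {a:.6f}, LR growth per step = {a ** -2:.6f}, max rel. gap = {worst:.2e}")
```

With the standard values $\gamma = 0.9$, $\eta = 0.1$, $\lambda = 0.0005$, `alpha_root` gives $\alpha \approx 0.9995$, so the LR grows by roughly a factor 1.48 over 391 iterations (one CIFAR10 epoch at batch size 128, an assumption made only for this illustration).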
According to the analysis in the above theorem, such a switch to a lower LR corresponds to an exponential growth with a smaller exponent, except for a transient effect when a correction term is needed for the two processes to be equivalent (see the discussion around Theorem 2.12).

Thus the final training algorithm is roughly as follows: start from a convenient LR like 0.1, and grow it at an exponential rate with a suitable exponent. When the validation loss plateaus, switch to an exponential growth of the LR with a lower exponent. Repeat the procedure until the training loss saturates.

In Section 3, we demonstrate on a toy example how weight decay and normalization are inseparably involved in the optimization process. With either weight decay or normalization alone, SGD will achieve zero training error. But with both turned on, SGD fails to converge to the global minimum.

In Section 5, we experimentally verify our theoretical findings on CNNs and ResNets. We also construct better exponential LR schedules by incorporating the Cosine LR schedule on CIFAR10, which opens the possibility of an even more general theory of rate schedule tuning towards better performance.

1.1 RELATED WORK
There have been other theoretical analyses of training models with scale invariance. Cho & Lee (2017) proposed to run Riemannian gradient descent on the Grassmann manifold $G(1, n)$, since the weight matrix is scale invariant with respect to the loss function. It has also been observed that the effective step size is proportional to $\frac{\eta_w}{\|w_t\|^2}$. Arora et al. (2019) show that the gradient is always perpendicular to the current parameter vector, so the norm of each scale-invariant parameter group increases monotonically, which has an auto-tuning effect. Wu et al. (2018) propose a new adaptive learning rate schedule motivated by the scale-invariance property of Weight Normalization.

Previous work on understanding Batch Normalization. Santurkar et al. (2018) suggested that the success of BN does not derive from a reduction in Internal Covariate Shift, but from making the landscape smoother. Kohler et al. (2018) essentially show that a linear model with BN can achieve an exponential convergence rate assuming Gaussian inputs, but their analysis is for a variant of GD with an inner optimization loop rather than GD itself. Bjorck et al. (2018) observe that the higher learning rates enabled by BN empirically improve generalization. Arora et al. (2019) prove that under certain mild assumptions, (S)GD with BN finds an approximate first-order stationary point with any fixed learning rate. None of the above analyses incorporated weight decay, but (Zhang et al., 2019; Hoffer et al., 2018a; van Laarhoven, 2017; Page; Wu) argued qualitatively that weight decay makes parameters have smaller norms, and thus the effective learning rate, $\frac{\eta_w}{\|w_t\|^2}$, is larger. They described experiments showing this effect but did not have a closed-form theoretical analysis like ours. None of the above analyses deals with momentum rigorously.

1.2 PRELIMINARIES AND NOTATIONS
For a batch $\mathcal{B} = \{x_i\}_{i=1}^{B}$ and network parameter $\theta$, we denote the network by $f_\theta$ and the loss function at iteration $t$ by $L_t(f_\theta) = L(f_\theta, \mathcal{B}_t)$. When there is no ambiguity, we also write $L_t(\theta)$ for convenience.

We say a loss function $L(\theta)$ is scale invariant to its parameter $\theta$ if for any $c \in \mathbb{R}^+$, $L(\theta) = L(c\theta)$. In practice, the source of scale invariance is usually one of several types of normalization layers, including Batch Normalization (Ioffe & Szegedy, 2015), Group Normalization (Wu & He, 2018), Layer Normalization (Ba et al., 2016) and Instance Norm (Ulyanov et al., 2016).

Implementations of SGD with Momentum/Nesterov come with subtle variations in the literature. We adopt the variant from Sutskever et al. (2013), which is also the default in PyTorch (Paszke et al., 2017). $L_2$ regularization (a.k.a. Weight Decay) is another common trick used in deep learning. Combining them, we get one of the most commonly used optimization algorithms, defined below.
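As a concrete, hypothetical example of a scale-invariant loss, consider the logistic loss of the normalized score $\langle\theta, x\rangle/\|\theta\|$, a stand-in for a linear layer followed by normalization. The plain-Python sketch below checks that this loss is unchanged when $\theta$ is rescaled, and that its finite-difference gradient has the two properties listed in Lemma 1.3 below. The function names and test vectors are illustrative choices.

```python
import math


def loss(theta, x):
    """Hypothetical scale-invariant loss: logistic loss of <theta, x> / ||theta||."""
    norm = math.sqrt(sum(t * t for t in theta))
    score = sum(t * xi for t, xi in zip(theta, x)) / norm
    return math.log1p(math.exp(-score))


def numerical_grad(f, theta, h=1e-6):
    """Central finite-difference estimate of the gradient of f at theta."""
    grad = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2.0 * h))
    return grad


x = [1.0, 0.4, -0.7]
theta = [0.3, -1.2, 0.5]
c = 3.0
scaled = [c * t for t in theta]

g = numerical_grad(lambda th: loss(th, x), theta)
g_scaled = numerical_grad(lambda th: loss(th, x), scaled)

# All three printed numbers should be close to 0 (up to finite-difference error).
print(abs(loss(theta, x) - loss(scaled, x)))              # L(theta) == L(c * theta)
print(sum(gi * ti for gi, ti in zip(g, theta)))          # property (1): <grad L, theta> == 0
print(max(abs(a - c * b) for a, b in zip(g, g_scaled)))  # property (2): grad at theta == c * grad at c*theta
```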
Definition 1.2 (SGD with Momentum and Weight Decay). At iteration $t$, with a randomly sampled batch $\mathcal{B}_t$, update the parameters $\theta_t$ and momentum $v_t$ as follows:
$$\theta_t = \theta_{t-1} - \eta_{t-1} v_t \qquad (2)$$
$$v_t = \gamma v_{t-1} + \nabla_\theta\Big(L_t(\theta_{t-1}) + \frac{\lambda_{t-1}}{2}\|\theta_{t-1}\|_2^2\Big), \qquad (3)$$
where $\eta_t$ is the learning rate at iteration $t$, $\gamma$ is the momentum coefficient, and $\lambda_t$ is the weight decay factor. Usually, $v_0$ is initialized to be $\mathbf{0}$.

For ease of analysis, we will use the following equivalent form of Definition 1.2:
$$\frac{\theta_t - \theta_{t-1}}{\eta_{t-1}} = \gamma\,\frac{\theta_{t-1} - \theta_{t-2}}{\eta_{t-2}} - \nabla_\theta\Big(L_t(\theta_{t-1}) + \frac{\lambda_{t-1}}{2}\|\theta_{t-1}\|_2^2\Big), \qquad (4)$$
where $\eta_{-1}$ and $\theta_{-1}$ must be chosen such that $v_0 = \frac{\theta_{-1} - \theta_0}{\eta_{-1}}$ is satisfied; e.g. when $v_0 = \mathbf{0}$, $\theta_{-1} = \theta_0$ and $\eta_{-1}$ can be arbitrary.

A key source of intuition is the following simple lemma about scale-invariant networks (Arora et al., 2019). The first property ensures that GD (with momentum) always increases the norm of the weight (see Lemma B.1 in Appendix B), and the second property says that the gradients are smaller for parameters with larger norm, thus stabilizing the trajectory from diverging to infinity.

Lemma 1.3 (Scale Invariance). If for any $c \in \mathbb{R}^+$, $L(\theta) = L(c\theta)$, then
(1) $\langle \nabla_\theta L, \theta \rangle = 0$;
(2) $\nabla_\theta L\big|_{\theta = \theta_0} = c\,\nabla_\theta L\big|_{\theta = c\theta_0}$, for any $c > 0$.

2 DERIVING EXPONENTIAL LEARNING RATE SCHEDULE
As a warm-up, in Section 2.1 we show that if momentum is turned off, then Fixed LR + Fixed WD can be translated to an equivalent Exponential LR. In Section 2.2 we give a more general analysis of the equivalence between Fixed LR + Fixed WD + Fixed Momentum Factor and Exponential LR + Fixed Momentum Factor. While interesting, this still does not completely apply to real-life deep learning, where reaching full accuracy usually requires multiple phases in training: the LR is fixed within a phase and reduced by some factor from one phase to the next. Section 2.3 shows how to interpret such a multi-phase LR schedule + WD + Momentum as a certain multi-phase exponential LR schedule with Momentum.

2.1 REPLACING WD BY EXPONENTIAL LR IN MOMENTUM-FREE SGD

We use the notation of Section 1.2 and assume the LR is fixed over iterations, i.e. $\eta_t = \eta_0$, and $\gamma$ (the momentum factor) is set to 0. We also use $\lambda$ to denote the WD factor and $\theta_0$ to denote the initial parameters.

The intuition should be clear from Lemma 1.3, which says that shrinking the parameter weights by a factor $\rho$ (where $\rho < 1$) amounts to making the gradient $\rho^{-1}$ times larger without changing its direction. The ratio between the size of the update (LR × gradient) and the size of the parameter therefore grows by $\rho^{-2}$: $\eta\|\nabla L(\rho\theta)\|/\|\rho\theta\| = \rho^{-2}\,\eta\|\nabla L(\theta)\|/\|\theta\|$. Thus, in order to restore the ratio between the original parameter and its update, the easiest way is to scale the LR by $\rho^2$. This suggests that scaling the parameter $\theta$ by $\rho$ at each step is equivalent to scaling the LR $\eta$ by $\rho^{-2}$.

To prove this formally we use the following formalism. We refer to the vector $(\theta, \eta)$ as the state of a training algorithm and study how it evolves under various combinations of parameter changes. We think of each step in training as a mapping from one state to another. Since mappings can be composed, any finite number of steps also corresponds to a mapping. The following basic mappings are used in the proof.
1. Run GD with WD for a step: $\mathrm{GD}^\rho_t(\theta, \eta) = (\rho\theta - \eta\nabla L_t(\theta), \eta)$;
2. Scale the parameter $\theta$: $\Pi_1^c(\theta, \eta) = (c\theta, \eta)$;
3. Scale the LR $\eta$: $\Pi_2^c(\theta, \eta) = (\theta, c\eta)$.

For example, when $\rho = 1$, $\mathrm{GD}^1_t$ is the vanilla GD update without WD, also abbreviated as $\mathrm{GD}_t$. When $\rho = 1 - \lambda\eta$, $\mathrm{GD}^{1-\lambda\eta}_t$ is the GD update with WD $\lambda$ and LR $\eta$. Here $L_t$ is the loss function at iteration $t$, which is determined by the batch of training samples $\mathcal{B}_t$ in the $t$-th iteration. Below is the main result of this subsection, showing our claim that GD + WD ⇔ GD + Exp LR (when momentum is zero). It will be proved after a series of lemmas.
Theorem 2.1 (WD ⇔ Exp LR). For every $\rho < 1$ and positive integer $t$, the following holds:
$$\mathrm{GD}^\rho_{t-1} \circ \cdots \circ \mathrm{GD}^\rho_0 = \Big[\Pi_1^{\rho^t} \circ \Pi_2^{\rho^{2t}}\Big] \circ \Pi_2^{\rho^{-1}} \circ \mathrm{GD}_{t-1} \circ \Pi_2^{\rho^{-2}} \circ \cdots \circ \mathrm{GD}_1 \circ \Pi_2^{\rho^{-2}} \circ \mathrm{GD}_0 \circ \Pi_2^{\rho^{-1}}.$$
With WD $\lambda$, $\rho$ is set to $1 - \lambda\eta$, and thus the scaling factor of the LR per iteration is $\rho^{-2} = (1 - \lambda\eta)^{-2}$, except for the first iteration, where it is $\rho^{-1} = (1 - \lambda\eta)^{-1}$.

We first show how to write the GD update with WD as a composition of the basic maps defined above.

Lemma 2.2. $\mathrm{GD}^\rho_t = \Pi_2^\rho \circ \Pi_1^\rho \circ \mathrm{GD}_t \circ \Pi_2^{\rho^{-1}}$.

Below we define the proper notion of equivalence such that (1) $\Pi_1^\rho \sim \Pi_2^{\rho^{-2}}$, which implies $\mathrm{GD}^\rho_t \sim \Pi_2^{\rho^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\rho^{-1}}$; and (2) the equivalence is preserved under future GD updates.

We first extend the equivalence between weights (same direction) to that between states, with the additional requirement that the ratio between the size of the GD update and that of the parameter is the same among all equivalent states, which yields the notion of Equivalent Scaling.

Definition 2.3 (Equivalent States). $(\theta', \eta')$ is equivalent to $(\theta, \eta)$ iff $\exists c > 0$ such that $(\theta', \eta') = [\Pi_1^c \circ \Pi_2^{c^2}](\theta, \eta) = (c\theta, c^2\eta)$, which is also denoted by $(\theta', \eta') \overset{c}{\sim} (\theta, \eta)$. $\Pi_1^c \circ \Pi_2^{c^2}$ is called Equivalent Scaling for all $c > 0$.

The following lemma shows that equivalent scaling commutes with the GD update with WD, implying that equivalence is preserved under GD updates. This anchors the notion of equivalence: we can insert an equivalent scaling anywhere in a sequence of basic maps (GD update, LR/parameter scaling) without changing the final network.

Lemma 2.4.
For any constants $c, \rho > 0$ and $t \ge 0$, $\mathrm{GD}^\rho_t \circ [\Pi_1^c \circ \Pi_2^{c^2}] = [\Pi_1^c \circ \Pi_2^{c^2}] \circ \mathrm{GD}^\rho_t$. In other words, $(\theta, \eta) \overset{c}{\sim} (\theta', \eta') \implies \mathrm{GD}^\rho_t(\theta, \eta) \overset{c}{\sim} \mathrm{GD}^\rho_t(\theta', \eta')$.

Now we formally define the equivalence relationship between maps using equivalent scalings.
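Lemma 2.2 and Lemma 2.4 are identities between maps on states $(\theta, \eta)$, so they can be checked numerically by writing the three basic maps as Python functions. In this sketch the scale-invariant loss $L_t(\theta) = -\langle\theta, u_t\rangle/\|\theta\|$, the helper names (`gd`, `pi1`, `pi2`, `compose`, `close`) and all numeric values are illustrative assumptions; each identity is compared on one random state.

```python
import math
import random


def grad(theta, u):
    """Gradient of the hypothetical scale-invariant loss -<theta, u> / ||theta||."""
    norm = math.sqrt(sum(x * x for x in theta))
    dot = sum(x * y for x, y in zip(theta, u))
    return [-y / norm + dot * x / norm ** 3 for x, y in zip(theta, u)]


def gd(rho, u):
    """GD^rho_t on a state (theta, eta): WD is folded into the factor rho."""
    def step(state):
        theta, eta = state
        return ([rho * x - eta * g for x, g in zip(theta, grad(theta, u))], eta)
    return step


def pi1(c):
    """Pi_1^c: scale the parameter by c."""
    return lambda s: ([c * x for x in s[0]], s[1])


def pi2(c):
    """Pi_2^c: scale the learning rate by c."""
    return lambda s: (s[0], c * s[1])


def compose(*maps):
    """compose(f, g, h)(s) == f(g(h(s))), matching the order of the ∘ notation."""
    def run(state):
        for m in reversed(maps):
            state = m(state)
        return state
    return run


def close(s1, s2, tol=1e-12):
    """Compare two states coordinate-wise (parameters and LR)."""
    return all(abs(a - b) <= tol for a, b in zip(s1[0] + [s1[1]], s2[0] + [s2[1]]))


rng = random.Random(1)
u = [rng.gauss(0, 1) for _ in range(4)]
state = ([rng.gauss(0, 1) for _ in range(4)], 0.1)
rho, c = 0.99, 2.5

# Lemma 2.2: GD^rho_t = Pi_2^rho ∘ Pi_1^rho ∘ GD_t ∘ Pi_2^(1/rho)
lemma22 = close(gd(rho, u)(state),
                compose(pi2(rho), pi1(rho), gd(1.0, u), pi2(1.0 / rho))(state))
# Lemma 2.4: GD^rho_t commutes with the equivalent scaling Pi_1^c ∘ Pi_2^(c^2)
scaling = compose(pi1(c), pi2(c * c))
lemma24 = close(compose(gd(rho, u), scaling)(state), compose(scaling, gd(rho, u))(state))
print(lemma22, lemma24)
```

Representing maps as closures keeps the code close to the $\circ$ notation, so longer compositions such as the one in Theorem 2.1 can be written the same way.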
Definition 2.5 (Equivalent Maps). Two maps $F, G$ are equivalent iff $\exists c > 0$ such that $F = \Pi_1^c \circ \Pi_2^{c^2} \circ G$, which is also denoted by $F \overset{c}{\sim} G$.

Proof of Theorem 2.1.
By Lemma 2.2, $\mathrm{GD}^\rho_t \overset{\rho}{\sim} \Pi_2^{\rho^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\rho^{-1}}$. By Lemma 2.4, the GD update preserves map equivalence, i.e. $F \overset{c}{\sim} G \Rightarrow \mathrm{GD}^\rho_t \circ F \overset{c}{\sim} \mathrm{GD}^\rho_t \circ G$, $\forall c, \rho > 0$. Thus,
$$\mathrm{GD}^\rho_{t-1} \circ \cdots \circ \mathrm{GD}^\rho_0 \overset{\rho^t}{\sim} \Pi_2^{\rho^{-1}} \circ \mathrm{GD}_{t-1} \circ \Pi_2^{\rho^{-2}} \circ \cdots \circ \mathrm{GD}_1 \circ \Pi_2^{\rho^{-2}} \circ \mathrm{GD}_0 \circ \Pi_2^{\rho^{-1}}.$$

Figure 1: Taking PreResNet32 with standard hyperparameters and replacing WD during the first phase (Fixed LR) by exponential LR according to Theorem 2.9, using the schedule $\tilde\eta_t = 0.1 \times 1.481^t$ and momentum 0.9. The plot on the right shows that the weight norm $w$ of the first convolutional layer in the second residual block grows exponentially, satisfying $\|w_t\|^2/\tilde\eta_t$ = constant. The reason is that, according to the proof, this quantity is essentially the squared norm of the weights when trained with Fixed LR + WD + Momentum, and the published hyperparameters kept this norm roughly constant during training.

2.2 REPLACING WD BY EXPONENTIAL LR: CASE OF CONSTANT LR WITH MOMENTUM
In this subsection the setting is the same as that in Subsection 2.1, except that the momentum factor is $\gamma$ instead of 0. Suppose the initial momentum is $v_0$; we set $\theta_{-1} = \theta_0 + \eta_{-1} v_0$. The presence of momentum requires representing the state of the algorithm with four coordinates, $(\theta, \eta, \theta', \eta')$, which stand for the current parameters/LR and the buffered parameters/LR (from the last iteration), respectively. Similarly, we define the following basic maps and equivalence relationships.
1. Run GD with WD for a step: $\mathrm{GD}^\rho_t(\theta, \eta, \theta', \eta') = \Big(\rho\theta + \eta\Big(\gamma\,\frac{\theta - \theta'}{\eta'} - \nabla L_t(\theta)\Big), \eta, \theta, \eta\Big)$;
2. Scale the current parameter $\theta$: $\Pi_1^c(\theta, \eta, \theta', \eta') = (c\theta, \eta, \theta', \eta')$;
3. Scale the current LR $\eta$: $\Pi_2^c(\theta, \eta, \theta', \eta') = (\theta, c\eta, \theta', \eta')$;
4. Scale the buffered parameter $\theta'$: $\Pi_3^c(\theta, \eta, \theta', \eta') = (\theta, \eta, c\theta', \eta')$;
5. Scale the buffered LR $\eta'$: $\Pi_4^c(\theta, \eta, \theta', \eta') = (\theta, \eta, \theta', c\eta')$.

Definition 2.6 (Equivalent States). $(\theta, \eta, \theta', \eta')$ is equivalent to $(\tilde\theta, \tilde\eta, \tilde\theta', \tilde\eta')$ iff $\exists c > 0$ such that $(\theta, \eta, \theta', \eta') = \big[\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}\big](\tilde\theta, \tilde\eta, \tilde\theta', \tilde\eta') = (c\tilde\theta, c^2\tilde\eta, c\tilde\theta', c^2\tilde\eta')$, which is also denoted by $(\theta, \eta, \theta', \eta') \overset{c}{\sim} (\tilde\theta, \tilde\eta, \tilde\theta', \tilde\eta')$. We call $\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}$ Equivalent Scalings for all $c > 0$.

Again by expanding the definition, we show that equivalent scalings commute with the GD update.

Lemma 2.7. $\forall c, \rho > 0$ and $t \ge 0$, $\mathrm{GD}^\rho_t \circ \big[\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}\big] = \big[\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}\big] \circ \mathrm{GD}^\rho_t$.
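The four-coordinate maps can be checked the same way as in the momentum-free case. The sketch below (plain Python; the scale-invariant loss $-\langle\theta, u\rangle/\|\theta\|$, the helper names and all numbers are illustrative assumptions) verifies Lemma 2.7 on a random state, and also the identity $\mathrm{GD}^\rho_t = \Pi_1^\alpha \circ \Pi_2^\alpha \circ \Pi_4^\alpha \circ \mathrm{GD}_t \circ \Pi_2^{\alpha^{-1}} \circ \Pi_3^\alpha \circ \Pi_4^\alpha$ on inputs whose current and buffered LR coincide, with $\alpha + \gamma\alpha^{-1} = \rho + \gamma$; this is the decomposition that Lemma 2.8 below relies on.

```python
import math
import random


def grad(theta, u):
    """Gradient of the hypothetical scale-invariant loss -<theta, u> / ||theta||."""
    norm = math.sqrt(sum(x * x for x in theta))
    dot = sum(x * y for x, y in zip(theta, u))
    return [-y / norm + dot * x / norm ** 3 for x, y in zip(theta, u)]


def gd(rho, gamma, u):
    """GD^rho_t on a state (theta, eta, theta_buf, eta_buf), as in item 1 above."""
    def step(s):
        theta, eta, theta_buf, eta_buf = s
        g = grad(theta, u)
        new = [rho * x + eta * (gamma * (x - xb) / eta_buf - gi)
               for x, xb, gi in zip(theta, theta_buf, g)]
        return (new, eta, list(theta), eta)
    return step


def pi(k, c):
    """Pi_k^c for k = 1..4: scale theta, eta, theta_buf or eta_buf by c."""
    def apply(s):
        s = list(s)
        if k in (1, 3):
            s[k - 1] = [c * x for x in s[k - 1]]
        else:
            s[k - 1] = c * s[k - 1]
        return tuple(s)
    return apply


def compose(*maps):
    """compose(f, g, h)(s) == f(g(h(s))), matching the order of the ∘ notation."""
    def run(s):
        for m in reversed(maps):
            s = m(s)
        return s
    return run


def close(s1, s2, tol=1e-12):
    flat = lambda s: s[0] + [s[1]] + s[2] + [s[3]]
    return all(abs(a - b) <= tol for a, b in zip(flat(s1), flat(s2)))


rng = random.Random(2)
gamma, lam, eta, c = 0.9, 0.01, 0.1, 1.7
u = [rng.gauss(0, 1) for _ in range(4)]
theta = [rng.gauss(0, 1) for _ in range(4)]
theta_buf = [rng.gauss(0, 1) for _ in range(4)]
state = (theta, eta, theta_buf, eta)  # current LR equals buffered LR
rho = 1.0 - lam * eta

# Lemma 2.7: equivalent scaling commutes with GD^rho_t.
scaling = compose(pi(1, c), pi(2, c * c), pi(3, c), pi(4, c * c))
lemma27 = close(compose(gd(rho, gamma, u), scaling)(state),
                compose(scaling, gd(rho, gamma, u))(state))

# Decomposition behind Lemma 2.8, with alpha + gamma / alpha == rho + gamma.
b = rho + gamma
alpha = (b + math.sqrt(b * b - 4.0 * gamma)) / 2.0
rhs = compose(pi(1, alpha), pi(2, alpha), pi(4, alpha), gd(1.0, gamma, u),
              pi(2, 1.0 / alpha), pi(3, alpha), pi(4, alpha))
lemma28 = close(gd(rho, gamma, u)(state), rhs(state))
print(lemma27, lemma28)
```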
Similarly, by expanding the definition we can rewrite $\mathrm{GD}^\rho_t$ as a composition of the vanilla GD update and other scalings, when the current and buffered LR are the same in the input of $\mathrm{GD}^\rho_t$.

Lemma 2.8.
For any input $(\theta, \eta, \theta', \eta)$, if $\alpha > 0$ is a root of $\alpha + \gamma\alpha^{-1} = \rho + \gamma$, then
$$\mathrm{GD}^\rho_t(\theta, \eta, \theta', \eta) = \Big[\Pi_1^\alpha \circ \Pi_2^\alpha \circ \Pi_4^\alpha \circ \mathrm{GD}_t \circ \Pi_2^{\alpha^{-1}} \circ \Pi_3^\alpha \circ \Pi_4^\alpha\Big](\theta, \eta, \theta', \eta).$$
In other words,
$$\mathrm{GD}^\rho_t(\theta, \eta, \theta', \eta) \overset{\alpha}{\sim} \Big[\Pi_2^{\alpha^{-1}} \circ \Pi_3^{\alpha^{-1}} \circ \Pi_4^{\alpha^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\alpha^{-1}} \circ \Pi_3^\alpha \circ \Pi_4^\alpha\Big](\theta, \eta, \theta', \eta). \qquad (5)$$
Though it looks complicated, the RHS of Equation 5 is actually the desired $\Pi_2^{\alpha^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\alpha^{-1}}$ conjugated with some scaling on the momentum part, $\Pi_3^\alpha \circ \Pi_4^\alpha$; the $\Pi_3^{\alpha^{-1}} \circ \Pi_4^{\alpha^{-1}}$ in the current update cancels with the $\Pi_3^\alpha \circ \Pi_4^\alpha$ in the next update. Now we are ready to show the equivalence between WD and the Exp LR schedule when momentum is turned on for both.

Theorem 2.9 (GD + WD ⇔ GD + Exp LR; With Momentum). The two sequences of parameters $\{\theta_t\}_{t=0}^\infty$ and $\{\tilde\theta_t\}_{t=0}^\infty$ defined below satisfy $\tilde\theta_t = \alpha^{-t}\theta_t$, and thus they correspond to the same networks in function space, i.e. $f_{\theta_t} = f_{\tilde\theta_t}$, $\forall t \in \mathbb{N}$, given $\tilde\theta_0 = \theta_0$, $\tilde\theta_{-1} = \alpha\theta_{-1}$, and $\tilde\eta_t = \eta\,\alpha^{-2t-1}$.
1. $\frac{\theta_t - \theta_{t-1}}{\eta} = \gamma\,\frac{\theta_{t-1} - \theta_{t-2}}{\eta} - \nabla_\theta\Big(L(\theta_{t-1}) + \frac{\lambda}{2}\|\theta_{t-1}\|_2^2\Big)$;
2. $\frac{\tilde\theta_t - \tilde\theta_{t-1}}{\tilde\eta_{t-1}} = \gamma\,\frac{\tilde\theta_{t-1} - \tilde\theta_{t-2}}{\tilde\eta_{t-2}} - \nabla_\theta L(\tilde\theta_{t-1})$,
where $\alpha$ is a positive root of the equation $x^2 - (1 + \gamma - \lambda\eta)x + \gamma = 0$, which is always smaller than 1 (see Appendix A.1). When $\gamma = 0$, $\alpha = 1 - \lambda\eta$ is the unique non-zero solution.

Figure 2: PreResNet32 trained with standard Step Decay and its corresponding Tapered-Exponential LR schedule. As predicted by Theorem 2.12, they have similar trajectories and performances.

Remark 2.10.
Above we implicitly assume that $\lambda\eta \le (1 - \sqrt{\gamma})^2$, so that the roots are real; this is always true in practice. For instance, with the standard hyper-parameters $\gamma = 0.9$, $\eta = 0.1$, $\lambda = 0.0005$, we have $\frac{\lambda\eta}{(1 - \sqrt{\gamma})^2} \approx 0.019 \ll 1$.

Proof.
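Note that $(\tilde\theta_0, \tilde\eta_0, \tilde\theta_{-1}, \tilde\eta_{-1}) = \big[\Pi_2^{\alpha^{-1}} \circ \Pi_3^\alpha \circ \Pi_4^\alpha\big](\theta_0, \eta, \theta_{-1}, \eta)$, so it suffices to show that
$$\mathrm{GD}^{1-\lambda\eta}_{t-1} \circ \cdots \circ \mathrm{GD}^{1-\lambda\eta}_0(\theta_0, \eta, \theta_{-1}, \eta) \overset{\alpha^t}{\sim} \Big[\Pi_2^{\alpha^{-1}} \circ \Pi_3^{\alpha^{-1}} \circ \Pi_4^{\alpha^{-1}} \circ \mathrm{GD}_{t-1} \circ \Pi_2^{\alpha^{-2}} \circ \cdots \circ \mathrm{GD}_1 \circ \Pi_2^{\alpha^{-2}} \circ \mathrm{GD}_0 \circ \Pi_2^{\alpha^{-1}} \circ \Pi_3^\alpha \circ \Pi_4^\alpha\Big](\theta_0, \eta, \theta_{-1}, \eta), \quad \forall t \ge 1,$$
which follows immediately from Lemma 2.7 and Lemma 2.8 by induction.

2.3 REPLACING WD BY EXPONENTIAL LR: CASE OF MULTIPLE LR PHASES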
Usual practice in deep learning shows that reaching full training accuracy requires reducing thelearning rate a few times.
Definition 2.11.
Step Decay is the (standard) learning rate schedule where training has $K$ phases $I = 0, 1, \ldots, K-1$, phase $I$ starts at iteration $T_I$ ($T_0 = 0$), and all iterations in phase $I$ use a fixed learning rate of $\eta^*_I$.

The algorithm state in Section 2.2 consists of 4 components, including the buffered and current LR. When the LR changes, the buffered and current LR are not equal, and thus Lemma 2.8 cannot be applied any more. In this section we show how to fix this issue by adding an extra momentum correction. In detail, we show that the Exp LR schedule defined below leads to the same trajectory of networks in function space, with a one-time momentum correction at the start of each phase. We empirically find on CIFAR10 that ignoring the correction term does not change performance much.

Theorem 2.12 (Tapered-Exponential LR Schedule). There exists a way to correct the momentum only at the first iteration of each phase, such that the following Tapered-Exponential LR schedule (TEXP) $\{\tilde\eta_t\}$ with momentum factor $\gamma$ and no WD leads to the same sequence of networks in function space as the Step Decay LR schedule (Definition 2.11) with momentum factor $\gamma$ and WD $\lambda$:
$$\tilde\eta_t = \begin{cases} \tilde\eta_{t-1} \times (\alpha^*_{I-1})^{-2} & \text{if } T_{I-1} + 1 \le t \le T_I - 1,\ I \ge 1, \\ \tilde\eta_{t-1} \times \frac{\eta^*_I}{\eta^*_{I-1}} \times (\alpha^*_I)^{-1}(\alpha^*_{I-1})^{-1} & \text{if } t = T_I,\ I \ge 1, \end{cases} \qquad (6)$$
where $\alpha^*_I = \frac{1 + \gamma - \lambda\eta^*_I + \sqrt{(1 + \gamma - \lambda\eta^*_I)^2 - 4\gamma}}{2}$ and $\tilde\eta_0 = \eta_0 \cdot (\alpha^*_0)^{-1} = \eta^*_0 \cdot (\alpha^*_0)^{-1}$.

The analysis in the previous subsection gives the equivalence within each phase, where the same LR is used throughout the phase. To deal with the difference between the buffered LR and the current LR when entering a new phase, the idea is to pretend that $\eta_{t-1} = \eta_t$ and that $\theta_{t-1}$ becomes whatever it needs to be to maintain $\frac{\theta_t - \theta_{t-1}}{\eta_{t-1}}$, so that we can again apply Lemma 2.8, which requires the current LR of the input state to equal its buffered LR.
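Equation (6) is a simple recursion, so the TEXP learning rates can be precomputed from a Step Decay specification. The sketch below is illustrative only: `texp`, its argument names and the example phases are our own choices, and it produces only the LR values; the one-time momentum correction at the start of each phase in Theorem 2.12 is not implemented here.

```python
import math


def alpha_star(eta, lam, gamma):
    """alpha*_I of Theorem 2.12: the larger root of Equation (1) with LR eta."""
    b = 1.0 + gamma - lam * eta
    return (b + math.sqrt(b * b - 4.0 * gamma)) / 2.0


def texp(phase_lrs, phase_starts, total_steps, lam, gamma):
    """Tapered-Exponential LR schedule of Equation (6).

    phase_lrs[I] is eta*_I of Step Decay and phase_starts[I] is T_I
    (phase_starts[0] == 0). Returns [tilde eta_0, ..., tilde eta_{total_steps-1}].
    """
    alphas = [alpha_star(e, lam, gamma) for e in phase_lrs]
    lrs = [phase_lrs[0] / alphas[0]]  # tilde eta_0 = eta*_0 / alpha*_0
    phase = 0
    for t in range(1, total_steps):
        if phase + 1 < len(phase_starts) and t == phase_starts[phase + 1]:
            # t = T_I: instant decay eta*_I / eta*_{I-1} plus the switch of growth factor.
            phase += 1
            decay = phase_lrs[phase] / phase_lrs[phase - 1]
            lrs.append(lrs[-1] * decay / (alphas[phase] * alphas[phase - 1]))
        else:
            # Inside a phase: multiply by (alpha*)^(-2) every iteration.
            lrs.append(lrs[-1] / alphas[phase] ** 2)
    return lrs


lrs = texp([0.1, 0.01], [0, 100], 200, lam=5e-4, gamma=0.9)
```

Here `lrs` holds 200 per-iteration learning rates for a two-phase Step Decay (0.1, then 0.01 from iteration 100 on), meant to be used with weight decay switched off.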
Because the scaling factor $\alpha$ in the RHS of Equation 5 differs between phases, these factors do not cancel each other at phase transitions (unlike what happens within a phase) and remain as a correction of the momentum. The proofs are deferred to Appendix A, where we prove a more general statement allowing phase-dependent WD, $\{\lambda_I\}_{I=0}^{K-1}$.

Alternative interpretation of Step Decay as an exponential LR schedule: below we present a new LR schedule, TEXP++, which is exactly equivalent to Step Decay without the need for a one-time correction of the momentum when entering each phase. We further show in Appendix A.1 that when translating from Step Decay, the TEXP++ schedule we get is very close to the original TEXP (Equation 9), i.e. the ratio between the LR growth per round, $\frac{\tilde\eta_{t+1}}{\tilde\eta_t} \Big/ \frac{\tilde\eta'_{t+1}}{\tilde\eta'_t}$, converges to 1 exponentially within each phase. For example, with WD 0.0005, max LR 0.1 and momentum factor 0.9, the ratio is within $1 \pm c \cdot r^{\,t - T_I}$ for constants $c, r < 1$, meaning that TEXP and TEXP++ are very close for Step Decay with standard hyperparameters.

Theorem 2.13.
The following two sequences of parameters, $\{\theta_t\}_{t=0}^\infty$ and $\{\tilde\theta_t\}_{t=0}^\infty$, define the same sequence of network functions, i.e. $f_{\theta_t} = f_{\tilde\theta_t}$, $\forall t \in \mathbb{N}$, given the initial conditions $\tilde\theta_0 = P_0\theta_0$, $\tilde\theta_{-1} = P_{-1}\theta_{-1}$:
1. $\frac{\theta_t - \theta_{t-1}}{\eta_{t-1}} = \gamma\,\frac{\theta_{t-1} - \theta_{t-2}}{\eta_{t-2}} - \nabla_\theta\Big(L(\theta_{t-1}) + \frac{\lambda_{t-1}}{2}\|\theta_{t-1}\|_2^2\Big)$, for $t = 1, 2, \ldots$;
2. $\frac{\tilde\theta_t - \tilde\theta_{t-1}}{\tilde\eta_{t-1}} = \gamma\,\frac{\tilde\theta_{t-1} - \tilde\theta_{t-2}}{\tilde\eta_{t-2}} - \nabla_\theta L(\tilde\theta_{t-1})$, for $t = 1, 2, \ldots$,
where $\tilde\eta_t = P_t P_{t+1}\eta_t$, $P_t = \prod_{i=-1}^{t}\alpha_i^{-1}$, $\forall t \ge -1$, and $\alpha_t$ is recursively defined as
$$\alpha_t = -\eta_{t-1}\lambda_{t-1} + 1 + \frac{\eta_{t-1}}{\eta_{t-2}}\,\gamma\big(1 - \alpha_{t-1}^{-1}\big), \quad \forall t \ge 1. \qquad (7)$$
The LR schedule $\{\tilde\eta_t\}_{t=0}^\infty$ is called Tapered Exponential ++, or TEXP++.

3 EXAMPLE ILLUSTRATING INTERPLAY OF WD AND BN

The paper so far has shown that the effects of different hyperparameters in training are not easily separated, since their combined effect on the trajectory is complicated. We give a simple example to illustrate this, where convergence is guaranteed if we use either BatchNorm or weight decay in isolation, but convergence fails if both are used. (Momentum is turned off for clarity of presentation.)
Setting:
Suppose we are fine-tuning the last linear layer of the network, where the input of the last layer is assumed to follow a standard Gaussian distribution $\mathcal{N}(0, I_m)$, and $m$ is the input dimension of the last layer. We also assume this is a binary classification task with logistic loss, $l(u, y) = \ln(1 + \exp(-uy))$, where the label $y \in \{-1, 1\}$ and $u \in \mathbb{R}$ is the output of the neural network. The training algorithm is SGD with constant LR and WD, and without momentum. For simplicity we assume the batch size $B$ is very large, so that the covariance of each batch $\mathcal{B}_t$ concentrates and is approximately equal to the identity, namely $\frac{1}{B}\sum_{b=1}^{B} x_{t,b}x_{t,b}^\top \approx I_m$. We also assume that the inputs of the last layer are already separable and, w.l.o.g., that the label is equal to the sign of the first coordinate of $x \in \mathbb{R}^m$, namely $\mathrm{sign}(x_1)$. Thus the training loss and training error are simply
$$L(w) = \mathbb{E}_{x \sim \mathcal{N}(0, I_m),\, y = \mathrm{sign}(x_1)}\Big[\ln\big(1 + \exp(-x^\top w\, y)\big)\Big], \qquad \Pr_{x \sim \mathcal{N}(0, I_m),\, y = \mathrm{sign}(x_1)}\big[x^\top w\, y \le 0\big] = \frac{1}{\pi}\arccos\frac{w_1}{\|w\|}.$$

Case 1: WD alone:
Since the above objective with L2 regularization is both strongly convex and smooth in $w$, vanilla GD with a suitably small learning rate can get arbitrarily close to the global minimum of this regularized objective. In our case, large-batch SGD behaves similarly to GD and can achieve $O\big(\sqrt{\eta\lambda/B}\big)$ test error, following the standard analysis of convex optimization.

Case 2: BN alone:
Add a BN layer after the linear layer, and fix the scale and bias terms to 1 and 0. The objective becomes
$$L_{BN}(w) = \mathbb{E}_{x \sim \mathcal{N}(0, I_m),\, y = \mathrm{sign}(x_1)}\big[L_{BN}(w, x)\big] = \mathbb{E}_{x \sim \mathcal{N}(0, I_m),\, y = \mathrm{sign}(x_1)}\Big[\ln\Big(1 + \exp\Big(-\frac{x^\top w}{\|w\|}\,y\Big)\Big)\Big].$$
From Appendix A.6, there is some constant $C$ such that $\forall w \in \mathbb{R}^m$, with constant probability, $\|\nabla_w L_{BN}(w, x)\| \ge \frac{C}{\|w\|}$. By the Pythagorean Theorem, $\|w_{t+1}\|^4 = \big(\|w_t\|^2 + \eta^2\|\nabla_w L_{BN}(w_t, x)\|^2\big)^2 \ge \|w_t\|^4 + 2\eta^2\|w_t\|^2\|\nabla_w L_{BN}(w_t, x)\|^2$. As a result, for any fixed learning rate, $\|w_{t+1}\|^4 \ge \sum_{i=1}^{t} 2\eta^2\|w_i\|^2\|\nabla_w L_{BN}(w_i, x)\|^2$ grows at least linearly with high probability. Following the analysis of Arora et al. (2019), this is like reducing the effective learning rate: when $\|w_t\|$ is large enough, the effective learning rate is small enough, and thus SGD can find the local minimum, which is the unique global minimum.

Case 3: Both BN and WD:
When BN and WD are used together, no matter how small the noise from the large batch size is, the following theorem shows that SGD will not converge to any solution with error smaller than $O(\sqrt{\eta\lambda})$, which is independent of the batch size (noise level).

Theorem 3.1 (Nonconvergence). Starting from any iteration $T$, with probability $1 - \delta$ over the randomness of samples, the training error will be larger than $\frac{\varepsilon}{\pi}$ at least once during a window of consecutive iterations whose length is, up to constants, $\frac{1}{\eta\lambda - O(\varepsilon^2)}\ln\frac{\|w_T\|\,\varepsilon\sqrt{B}}{\eta\sqrt{m}} + 9\ln\frac{1}{\delta}$ (see Appendix A for the exact expression).

Sketch. (See the full proof in Appendix A.) The high-level idea of this proof is that if the test error is low, the weight is restricted to a small cone around the global minimum, and thus the amount of the gradient update is bounded by the size of the cone. In this case, the growth of the norm of the weight by the Pythagorean Theorem is not large enough to cancel the shrinkage brought by weight decay. As a result, the norm of the weight converges to 0 geometrically. Again we need to use the lower bound for the size of the gradient, that $\|\nabla_w L_t\| = \Theta\big(\frac{\eta}{\|w_t\|}\sqrt{\frac{m}{B}}\big)$ holds with constant probability. Thus the size of the gradient will grow along with the shrinkage of $\|w_t\|$ until they are comparable, forcing the weight to leave the cone in the next iteration.

4 VIEWING EXP LR VIA CANONICAL OPTIMIZATION FRAMEWORK
This section tries to explain why the efficacy of exponential LR in deep learning is mysterious to us,at least as viewed in the canonical framework of optimization theory.
Canonical framework for analysing 1st order methods
This focuses on proving that each (or most) steps of GD noticeably reduce the objective, by relying on some assumption about the spectral norm of the Hessian of the loss, most frequently the smoothness, denoted by $\beta$. Specifically, for the GD update $\theta_{t+1} = \theta_t - \eta\nabla L(\theta_t)$, we have
$$L(\theta_{t+1}) - L(\theta_t) \le (\theta_{t+1} - \theta_t)^\top\nabla L(\theta_t) + \frac{\beta}{2}\|\theta_{t+1} - \theta_t\|^2 = -\eta\Big(1 - \frac{\beta\eta}{2}\Big)\|\nabla L(\theta_t)\|^2.$$
When $\beta < \frac{2}{\eta}$, the first-order term is larger than the second-order one, guaranteeing that the loss value decreases. Since the analysis framework treats the loss as a black box (apart from the assumed bounds on the derivative norms), and the loss is non-convex, the best one can hope for is to prove speedy convergence to a stationary point (where the gradient is close to 0). An increasing body of work proves such results.

Now we turn to difficulties in understanding the exponential LR in the context of the above framework and with scale invariance in the network.
1. Since the loss is the same for $\theta$ and $c \cdot \theta$ for all $c > 0$, a simple calculation shows that along any straight line through the origin, smoothness is a decreasing function of $c$, and is very high close to the origin: differentiating property (2) of Lemma 1.3 gives $\nabla^2 L(c\theta) = c^{-2}\nabla^2 L(\theta)$. (Note: one can also show the following related fact: in any ball containing the origin, the loss is nonconvex.) Thus if one were trying to apply the canonical framework to argue convergence to a stationary point, the natural idea would be to grow the norm of the parameters until smoothness drops enough that the above-mentioned Canonical Framework starts to apply. Arora et al. (2019) showed this happens in GD with fixed LR (WD turned off), and furthermore the resulting convergence rate to a stationary point is asymptotically similar to analyses of nonconvex optimization with the learning rate set as in the Canonical Framework. Santurkar et al. (2018) observed a similar phenomenon in experiments, which they described as a smoothening of the objective due to BN.
2. The Canonical Framework can be thought of as a discretization of continuous gradient descent (i.e., gradient flow): in principle it is possible to use an arbitrarily small learning rate, but one uses a finite learning rate merely to keep the number of iterations small. The discrete process approximates the continuous process due to smoothness being small. In the case of gradient flow with weight decay (equivalently, with an exponential LR schedule), the discrete process cannot track the continuous process for very long, which suggests that any explanation of the benefits of exponential LR may need to rely on the discrete process being somehow better. The reason is that for gradient flow one can decouple the velocity of $\theta_t$ into tangential and radial components, where the former has no effect on the norm and the latter has no effect on the objective but scales the tangential gradient exponentially. Thus gradient flow with WD gives exactly the same trajectory as vanilla gradient flow does, except for an exponential reparametrization with respect to time $t$.
3. It can be shown that if the local smoothness is upper bounded by $\frac{2}{\eta}$ (as stipulated in the Canonical Framework) during a sequence $\theta_t$ ($t = 1, 2, \ldots$) of GD updates with WD and constant LR, then such a sequence satisfies $\theta_t \to 0$. This contrasts with the usual experimental observation that $\theta_t$ stays bounded away from 0. One should thus conclude that in practice, with constant LR and WD, smoothness does not always stay small (unlike the above analyses where WD is turned off).

5 EXPERIMENTS
The translation to exponential LR schedule is exact except for one-time momentum correction termentering new phases. The experiments explore the effect of this correction term. The TaperedExponential(TEXP) LR schedule contains two parts when entering a new phase I: an instant LRdecay ( η I η I − ) and an adjustment of the growth factor ( α ∗ I − → α ∗ I ). The first part is relative smallcompared to the huge exponential growing. Thus a natural question arises: Can we simplify TEXPLR schedule by dropping the part of instant LR decay?
Also, previously we have only verified our equivalence theorem in Step Decay LR schedules. But it’snot sure how would the Exponential LR schedule behave on more rapid time-varying LR schedulessuch as Cosine LR schedule.
Settings:
We train PreResNet32 on CIFAR10. The initial learning rate is 0.1 and the momentumis 0.9 in all settings. We fix all the scalar and bias of BN, because otherwise they together with thefollowing conv layer grow exponentially, sometimes exceeding the range of Float32 when trainedwith large growth rate for a long time. We fix the parameters in the last fully connected layer forscale invariance of the objective.5.1 T
HE BENEFIT OF INSTANT LR DECAY
We tried the following LR schedule (we call it
TEXP-- ). Interestingly, up to correction of momentumwhen entering a new phase, this schedule is equivalent to a constant LR schedule, but with theweight decay coefficient reduced correspondingly at the start of each phase. (See Theorem A.2 andFigure 5)TEXP--: (cid:101) η t +1 = (cid:26)(cid:101) η t × ( α ∗ I − ) − if T I − + 1 ≤ t ≤ T I − , I ≥ (cid:101) η t × ( α ∗ I ) − ( α ∗ I − ) − if t = T I , I ≥ , (8)where α ∗ I = γ − λη ∗ I + (cid:113) ( γ − λη ∗ I ) − γ , (cid:101) η = η · ( α ∗ ) − = η ∗ · ( α ∗ ) − . Figure 3:
Instant LR decay has only temporary effect when LR growth (cid:101) η t / (cid:101) η t − − is large. The blue line usesan exponential LR schedule with constant exponent. The orange line multiplies its LR by the same constanteach iteration, but also divide LR by 10 at the start of epoch 80 and 120. The instant LR decay only allows theparameter to stay at good local minimum for 1 epoch and then diverges, behaving similarly to the trajectorieswithout no instant LR decay. igure 4: Instant LR decay is crucial when LR growth (cid:101) η t / (cid:101) η t − − is very small. The original LR of Step Decayis decayed by 10 at epoch , respectively. In the third phase, LR growth (cid:101) η t / (cid:101) η t − − is approximately 100times smaller than that in the third phase, it would take TEXP-- hundreds of epochs to reach its equilibrium. Asa result, TEXP achieves better test accuracy than TEXP--. As a comparison, in the second phase, (cid:101) η t / (cid:101) η t − − is only 10 times smaller than that in the first phase and it only takes 70 epochs to return to equilibrium. Figure 5:
The orange line corresponds to PreResNet32 trained with constant LR and WD divided by 10 atepoch 80 and 120. The blue line is TEXP-- corresponding to Step Decay schedule which divides LR by 10 atepoch 80 and 120. They have similar trajectories and performances by a similar argument to Theorem 2.12.(SeeTheorem A.2 and its proof in Appendix A)
ETTER E XPONENTIAL
LR S
CHEDULE WITH C OSINE
LRWe applied the TEXP LR schedule (Theorem 2.12) on the Cosine LR schedule (Loshchilov &Hutter, 2016), where the learning rate changes every epoch, and thus correction terms cannot beignored. The LR at epoch t ≤ T is defined as: η t = η tT π )2 . Our experiments show thishybrid schedule with Cosine LR performs better on CIFAR10 than Step Decay, but this findingneeds to be verified on other datasets. ONCLUSIONS
The paper shows rigorously how BN allows a host of very exotic learning rate schedules in deeplearning, and verifies these effects in experiments. The lr increases exponentially in almost everyiteration during training. The exponential increase derives from use of weight decay, but the precise
Figure 6:
Both Cosine and Step Decay schedule behaves almost the same as their exponential counterpart, aspredicted by our equivalence theorem. The (exponential) Cosine LR schedule achieves better test accuracy,with a entirely different trajectory. R EFERENCES
Sanjeev Arora, Nadav Cohen, and Elad Hazan. On the optimization of deep networks: Implicitacceleration by overparameterization. In
International Conference on Machine Learning , pp.244–253, 2018.Sanjeev Arora, Zhiyuan Li, and Kaifeng Lyu. Theoretical analysis of auto rate-tuning by batchnormalization. In
International Conference on Learning Representations , 2019. URL https://openreview.net/forum?id=rkxQ-nA9FX .Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprintarXiv:1607.06450 , 2016.Nils Bjorck, Carla P Gomes, Bart Selman, and Kilian Q Weinberger. Understanding batch normal-ization. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett(eds.),
Advances in Neural Information Processing Systems 31 , pp. 7705–7716. Curran Asso-ciates, Inc., 2018.Minhyung Cho and Jaehyung Lee. Riemannian approach to batch normalization. In I. Guyon, U. V.Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (eds.),
Advances inNeural Information Processing Systems 30 , pp. 5225–5235. Curran Associates, Inc., 2017.Sanjoy Dasgupta and Anupam Gupta. An elementary proof of a theorem of johnson and linden-strauss.
Random Structures & Algorithms , 22(1):60–65, 2003.Robert Mansel Gower, Nicolas Loizou, Xun Qian, Alibek Sailanbayev, Egor Shulgin, and PeterRicht´arik. Sgd: General analysis and improved rates. arXiv preprint arXiv:1901.09401 , 2019.Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recog-nition. In
Proceedings of the IEEE conference on computer vision and pattern recognition , pp.770–778, 2016a.Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residualnetworks. In
European conference on computer vision , pp. 630–645. Springer, 2016b.Elad Hoffer, Ron Banner, Itay Golan, and Daniel Soudry. Norm matters: efficient and accuratenormalization schemes in deep networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman,N. Cesa-Bianchi, and R. Garnett (eds.),
Advances in Neural Information Processing Systems 31 ,pp. 2164–2174. Curran Associates, Inc., 2018a.Elad Hoffer, Itay Hubara, and Daniel Soudry. Fix your classifier: the marginal value of trainingthe last weight layer. In
International Conference on Learning Representations , 2018b. URL https://openreview.net/forum?id=S1Dh8Tg0- .Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training byreducing internal covariate shift. In
International Conference on Machine Learning , pp. 448–456,2015.Jonas Kohler, Hadi Daneshmand, Aurelien Lucchi, Ming Zhou, Klaus Neymeyr, and Thomas Hof-mann. Exponential convergence rates for batch normalization: The power of length-directiondecoupling in non-convex optimization. arXiv preprint arXiv:1805.10694 , 2018.Ilya Loshchilov and Frank Hutter. SGDR: Stochastic Gradient Descent with Warm Restarts. arXive-prints , art. arXiv:1608.03983, Aug 2016. 11avid Page. How to train your resnet 6: Weight decay? URL https://myrtle.ai/how-to-train-your-resnet-6-weight-decay/ .Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito,Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation inpytorch. 2017.Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and Aleksander Madry. How does batch nor-malization help optimization? In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett (eds.),
Advances in Neural Information Processing Systems 31 , pp. 2488–2498. Curran Associates, Inc., 2018.Leslie N Smith. Cyclical learning rates for training neural networks. In , pp. 464–472. IEEE, 2017.Ilya Sutskever, James Martens, George Dahl, and Geoffrey Hinton. On the importance of ini-tialization and momentum in deep learning. In
Proceedings of the 30th International Confer-ence on International Conference on Machine Learning - Volume 28 , ICML’13, pp. III–1139–III–1147. JMLR.org, 2013. URL http://dl.acm.org/citation.cfm?id=3042817.3043064 .Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization: The missing in-gredient for fast stylization. arXiv preprint arXiv:1607.08022 , 2016.Twan van Laarhoven. L2 regularization versus batch and weight normalization. arXiv preprintarXiv:1706.05350 , 2017.David Wu. L2 regularization and batch norm. URL https://blog.janestreet.com/l2-regularization-and-batch-norm/ .Xiaoxia Wu, Rachel Ward, and L´eon Bottou. WNGrad: Learn the Learning Rate in Gradient De-scent. arXiv preprint arXiv:1803.02865 , 2018.Yuxin Wu and Kaiming He. Group normalization. In
The European Conference on Computer Vision(ECCV) , September 2018.Yang You, Igor Gitman, and Boris Ginsburg. Large Batch Training of Convolutional Networks. arXiv e-prints , art. arXiv:1708.03888, Aug 2017.Guodong Zhang, Chaoqi Wang, Bowen Xu, and Roger Grosse. Three mechanisms of weight decayregularization. In
International Conference on Learning Representations , 2019. URL https://openreview.net/forum?id=B1lz-3Rct7 .12 O MITTED P ROOFS
A.1 O
MITTED P ROOF IN S ECTION Lemma A.1 (Some Facts about Equation 1) . Suppose z , z ( z ≥ z ) are the two real roots of thethe following equation, we have x − (1 + γ − λη ) x + γ = 0 z = γ − λη + √ (1 − γ ) − γ ) λη + λ η , z = γ − λη − √ (1 − γ ) − γ ) λη + λ η z , z are real ⇐⇒ λη ≤ (1 − √ γ ) ;3. z z = γ, z + z = (1 + γ − λη ) ;4. γ ≤ z ≤ z ≤ ;5. Let t = λη − γ , we have z ≥ t ≥ − t = 1 − λη − γ .6. if we view z ( λη ) , z ( λη ) as functions of λη , then z ( λη ) is monotone decreasing, z ( η ) is monotone increasing.Proof.
4. Let f ( x ) = x − (1 + γ − λη ) x + γ , we have f (1) = f ( γ ) = λη ≥ . Note the minimumof f is taken at x = γ − λη ∈ [0 , , the both roots of f ( x ) = 0 must lie between and ,if exists.5 − z = 1 − γ + λη − (cid:112) (1 − γ ) − γ ) λη + λ η
2= (1 − γ ) 1 + t − (cid:113) − γ − γ t + t
2= (1 − γ ) 2 t + 2 γ − γ t t + (cid:113) − γ − γ t + t ) ≤ (1 − γ ) − γ t t )= t (1 + t )
6. Note that ( z − z ) = ( z + z ) − z z = (1 + γ − λη ) − γ is monotone decreasing,since z ( λη ) + z ( λη ) is constant, z ( λη ) ≥ z ( λη ) , z ( λη ) must be decreasing and z ( λη ) must be increasing.A.2 O MITTED PROOFS IN S ECTION
Proof of Lemma 2.2.
For any ( θ , η ) , we haveGD ρt ( θ , η ) = ( ρ θ − η ∇ L t ( θ ) , η ) = [Π ρ ◦ Π ρ ◦ GD t ]( θ , ηρ ) = [Π ρ ◦ Π ρ ◦ GD t ◦ Π ρ − ]( θ , η ) . Proof of Lemma 2.4.
For any ( θ , η ) , we have (cid:104) GD t ◦ Π c ◦ Π c (cid:105) ( θ , η ) = GD t ( c θ , c η ) = ( c θ − c θ ∇ L t ( c θ ) , c η ) ∗ = ( c ( θ − ∇ L t ( θ )) , c η )= (cid:104) Π c ◦ Π c ◦ GD t (cid:105) ( θ , η ) . ( ∗ = : Scale Invariance, Lemma 1.3 ) MITTED PROOFS IN S ECTION
Proof of Lemma 2.7.
For any input ( θ , η, θ (cid:48) , η (cid:48) ) , it’s easy to check both composed maps have thesame outputs on the 2,3,4th coordinates, namely ( c η, c θ , c η (cid:48) ) . For the first coordinate, we have (cid:2) GD ρ ( c θ , c η, c θ (cid:48) , c η ) (cid:3) = ρc θ + c η (cid:18) γ θ − θ (cid:48) η (cid:48) − ∇ L t ( c θ ) (cid:19) ∗ = c (cid:18) θ + η (cid:18) γ θ − θ (cid:48) η (cid:48) − ∇ L t ( θ ) (cid:19)(cid:19) = c [ GD ρ ( θ , η, θ (cid:48) , η )] . ∗ = : Scale Invariance, Lemma 1.3 Proof of Lemma 2.8.
For any input ( θ , η, θ (cid:48) , η (cid:48) ) , it’s easy to check both composed maps have thesame outputs on the 2,3,4th coordinates, namely ( η, θ , η ) . For the first coordinate, we have (cid:104)(cid:104) Π α ◦ Π α ◦ Π α ◦ GD t ◦ Π α − ◦ Π α ◦ Π α (cid:105) ( θ , η, θ (cid:48) , η ) (cid:105) = α (cid:2) GD t ( θ , α − η, α θ (cid:48) , αη ) (cid:3) = α (cid:18) θ + α − η (cid:18) γ θ − θ (cid:48) η − ∇ L t ( θ ) (cid:19)(cid:19) = (cid:0) α + γα − (cid:1) θ − η ∇ L t ( θ ) − ηγ θ (cid:48) η = ( ρ + γ ) θ − η ∇ L t ( θ ) − γ θ (cid:48) = [ GD ρt ( θ , η, θ (cid:48) , η )] A.4 O
MITTED PROOFS OF T HEOREM λ I changing each phase. Theorem A.2 (A stronger version of Theorem 2.12) . There exists a way to correct the momentumonly at the first iteration of each phase, such that the following Tapered-Exponential LR schedule(TEXP) { (cid:101) η t } with momentum factor γ and no WD, leads the same sequence networks in functionspace compared to that of Step Decay LR schedule(Definition 2.11) with momentum factor γ and phase-dependent WD λ ∗ I in phase I , where phase I lasts from iteration T I to iteration T I +1 , T = 0 . (cid:101) η t +1 = (cid:40)(cid:101) η t × ( α ∗ I − ) − if T I − + 1 ≤ t ≤ T I − , I ≥ (cid:101) η t × η ∗ I η ∗ I − × ( α ∗ I ) − ( α ∗ I − ) − if t = T I , I ≥ , (9)where α ∗ I = γ − λ ∗ I η ∗ I + (cid:113) ( γ − λ ∗ I η ∗ I ) − γ , (cid:101) η = η ( α ∗ ) − = η ∗ ( α ∗ ) − .Towards proving Theorem 2.12, we need the following lemma which holds by expanding the defi-nition, and we omit its proof. Lemma A.3 (Canonicalization) . We define the
Canonicalization map as N ( θ , η, θ (cid:48) , η (cid:48) ) = ( θ , η, θ − ηη (cid:48) ( θ − θ (cid:48) ) , η ) , and it holds that1. GD ρt ◦ N = GD ρt , ∀ ρ > , t ≥ .2. N ◦ (cid:104) Π c ◦ Π c ◦ Π c ◦ Π c (cid:105) = (cid:104) Π c ◦ Π c ◦ Π c ◦ Π c (cid:105) ◦ N , ∀ c > . Similar to the case of momentum-free SGD, we define the notion of equivalent map below
Definition A.4 (Equivalent Maps) . For two maps F and G , we say F is equivalent to G iff ∃ c > , F = (cid:104) Π c ◦ Π c ◦ Π c ◦ Π c (cid:105) ◦ G , which is also denoted by F c ∼ G .Note that for any ( θ , η, θ (cid:48) , η (cid:48) ) , [ N ( θ , η, θ (cid:48) , η (cid:48) )] = [ N ( θ , η, θ (cid:48) , η (cid:48) )] . Thus as a direct consequenceof Lemma 2.8, the following lemma holds. Lemma A.5. ∀ ρ, α > , GD ρt ◦ N α ∼ Π α − ◦ Π α − ◦ Π α − ◦ GD t ◦ Π α − ◦ Π α ◦ Π α ◦ N .Proof of Theorem2.12. Starting with initial state ( θ , η , θ − , η − ) where η − = η and a given LRschedule { η t } t ≥ , the parameters generated by GD with WD and momentum satisfies the followingrelationship: 14 θ t +1 , η t +1 , θ t , η t ) = (cid:20) Π ηt +1 ηt ◦ GD − η t λ t t (cid:21) ( θ t , η t , θ t − , η t − ) . Define b (cid:13) t = a F t = F b ◦ F b − ◦ . . . ◦ F a , for a ≤ b . By Lemma A.3 and Lemma A.5, letting α t be theroot of x − ( γ + 1 − η t − λ t − ) x + γ = 0 , we have T − (cid:13) t =0 (cid:20) Π ηt +1 ηt ◦ GD − η t λ t t (cid:21) = T − (cid:13) t =0 (cid:20) Π ηt +1 ηt ◦ GD − η t λ t t ◦ N (cid:21) T − (cid:81) i =0 α i ∼ T − (cid:13) t =0 (cid:20) Π ηt +1 ηt ◦ Π α − t +1 ◦ Π α − t +1 ◦ Π α − t +1 ◦ GD t ◦ Π α − t +1 ◦ Π α t +1 ◦ Π α t +1 ◦ N (cid:21) =Π ηTηT − ◦ Π α − T − ◦ Π α − T − ◦ Π α − T ◦ GD T − ◦ (cid:18) T − (cid:13) t =1 (cid:20) Π α − t +1 α − t ◦ H t ◦ GD t − (cid:21)(cid:19) ◦ Π α − ◦ Π α ◦ Π α ◦ N, (10)where T − (cid:81) i =0 α i ∼ is because of Lemma A.5, and H t is defined as H t = Π α t ◦ Π ηt − ηt ◦ Π α t +1 ◦ Π α t +1 ◦ N ◦ Π α − t ◦ Π α − t ◦ Π α − t ◦ Π ηtηt − . Since the canonicalization map N only changes the momentum part of the state, it’s easy to checkthat H t doesn’t touch the current parameter θ and the current LR η . Thus H t only changes themomentum part of the input state. Now we claim that H t ◦ GD t − = GD t − whenever η t = η t − .This is because when η t = η t − , α t = α t +1 , thus H t ◦ GD t − = GD t − . 
In detail, H t ◦ GD t − =Π α t ◦ Π α t ◦ Π α t ◦ N ◦ Π α − t ◦ Π α − t ◦ Π α − t ◦ GD t − ∗ =Π α t ◦ Π α t ◦ Π α t ◦ Π α − t ◦ Π α − t ◦ Π α − t ◦ GD t − = GD t − , where ∗ = is because GD update GD t sets η (cid:48) the same as η , and thus ensures the input of N has thesame momentum factor in buffer as its current momentum factor, which makes N an identity map.Thus we could rewrite Equation 10 with a “sloppy”version of H t , H (cid:48) t = (cid:26) H t η t (cid:54) = η t − ; Id o.w. : T − (cid:13) t =0 (cid:20) Π ηt +1 ηt ◦ GD − η t λ t t (cid:21) =Π ηTηT − ◦ Π α − T − ◦ Π α − T − ◦ Π α − T ◦ GD T − ◦ (cid:18) T − (cid:13) t =1 (cid:20) Π α − t +1 α − t ◦ H (cid:48) t ◦ GD t − (cid:21)(cid:19) ◦ Π α − ◦ Π α ◦ Π α ◦ N =Π ηTηT − ◦ Π α − T − ◦ Π α − T − ◦ Π α − T ◦ (cid:18) T − (cid:13) t =1 (cid:20) GD t ◦ Π α − t +1 α − t ◦ H (cid:48) t (cid:21)(cid:19) ◦ GD ◦ Π α − ◦ Π α ◦ Π α ◦ N, (11)Now we construct the desired sequence of parameters achieved by using the Tapered Exp LRschedule 9 and the additional one-time momentum correction per phase. Let ( (cid:101) θ , (cid:101) η , (cid:101) θ − , (cid:101) η − ) =( θ , η , θ − , η ) , and 15 (cid:101) θ , (cid:101) η , (cid:101) θ , (cid:101) η ) = (cid:104) GD ◦ Π α − ◦ Π α ◦ Π α ◦ N (cid:105) ( (cid:101) θ , (cid:101) η , (cid:101) θ − , (cid:101) η − )= (cid:104) GD ◦ Π α − ◦ Π α ◦ Π α (cid:105) ( (cid:101) θ , (cid:101) η , (cid:101) θ − , (cid:101) η − );( (cid:101) θ t +1 , (cid:101) η t +1 , (cid:101) θ t , (cid:101) η t ) = (cid:20) GD t ◦ Π α − t +1 α − t ◦ H (cid:48) t (cid:21) ( (cid:101) θ t , (cid:101) η t , (cid:101) θ t − , (cid:101) η t − ) . we claim { (cid:101) θ t } t =0 is the desired sequence of parameters. We’ve already shown that θ t ∼ (cid:101) θ t , ∀ t .Clearly { (cid:101) θ t } t =0 is generated using only vanilla GD, scaling LR and modifying the momentum partof the state. When t (cid:54) = T I for any I , η t = η t − and thus H (cid:48) t = Id . 
Thus the modificationon the momentum could only happen at T I ( I ≥ . Also it’s easy to check that α t = α ∗ I , if T I + 1 ≤ t ≤ T I +1 .A.5 O MITTED PROOFS OF T HEOREM
Theorem A.6.
The following two sequences of parameters , { θ t } ∞ t =0 and { (cid:101) θ t } ∞ t =0 , define the samesequence of network functions, i.e. f θ t = f (cid:101) θ t , ∀ t ∈ N , given the initial conditions, (cid:101) θ = P θ , (cid:101) θ − = P − θ − .1. θ t − θ t − η t − = γ θ t − − θ t − η t − − ∇ θ (cid:16) ( L ( θ t − ) + λ t − (cid:107) θ t − (cid:107) (cid:17) , for t = 1 , , . . . ;2. (cid:101) θ t − (cid:101) θ t − (cid:101) η t − = γ (cid:101) θ t − − (cid:101) θ t − (cid:101) η t − − ∇ θ L ( (cid:101) θ t − ) , for t = 1 , , . . . ,where (cid:101) η t = P t P t +1 η t , P t = t (cid:81) i = − α − i , ∀ t ≥ − and α t recursively defined as α t = − η t − λ t − + 1 + η t − η t − γ (1 − α − t − ) , ∀ t ≥ . (12)needs to be always positive. Here α , α − are free parameters. Different choice of α , α − wouldlead to different trajectory for { (cid:101) θ t } , but the equality that (cid:101) θ t = P t θ t is always satisfied. If the initialcondition is given via v , then it’s also free to choose η − , θ − , as long as θ − θ − η − = v . Proof of Theorem 2.13.
We will prove by induction. By assumption S ( t ) : P t θ t = (cid:101) θ t for t = − , .Now we will show that S ( t ) = ⇒ S ( t + 1) , ∀ t ≥ . θ t − θ t − η t − = γ θ t − − θ t − η t − − ∇ θ (cid:18) ( L ( θ t − ) + λ t − (cid:107) θ t − (cid:107) (cid:19) Take gradient ====== ⇒ θ t − θ t − η t − = γ θ t − − θ t − η t − − ∇ θ L ( θ t − ) + λ t − θ t − Scale Invariance ======== ⇒ θ t − θ t − η t − = γ θ t − − θ t − η t − − P t − ∇ θ L ( (cid:101) θ t − ) + λ t − θ t − Rescaling ===== ⇒ P t ( θ t − θ t − ) P t P t − η t − = γ P t − ( θ t − − θ t − ) P t − P t − η t − − ∇ θ L ( (cid:101) θ t − ) − λ t − θ t − P t − Simplfying ====== ⇒ P t θ t − α − t (cid:101) θ t − (cid:101) η t − = γ α t − (cid:101) θ t − − (cid:101) θ t − (cid:101) η t − − ∇ θ L ( (cid:101) θ t − ) − η t − λ t − P t θ t − η t − P t − P t Simplfying ====== ⇒ P t θ t − α − t (cid:101) θ t − (cid:101) η t − = γ α t − (cid:101) θ t − − (cid:101) θ t − (cid:101) η t − − ∇ θ L ( (cid:101) θ t − ) − η t − λ t − α − t (cid:101) θ t − (cid:101) η t − Simplfying ====== ⇒ P t θ t − α − t (1 − η t − λ t − ) (cid:101) θ t − (cid:101) η t − = γ α t − (cid:101) θ t − − (cid:101) θ t − (cid:101) η t − − ∇ θ L ( (cid:101) θ t − )
16o conclude that P t θ t = (cid:101) θ t , it suffices to show that the coefficients before (cid:101) θ t − is the same to thatin (2) . In other words, we need to show − α − t (1 − η t − λ t − ) (cid:101) η t − = γ (1 − α t − ) (cid:101) η t − , which is equivalent to the definition of α t , Equation 12. Lemma A.7 (Sufficient Conditions for positivity of α t ) . Let λ max = max t λ t , η max = max t η t .Define z min is the larger root of the equation x − (1 + γ − λ max η max ) x + γ = 0 . To guaranteethe existence of z max we also assume η max λ max ≤ (1 − √ γ ) . Then we have ∀ α − , α = 1 = ⇒ z min ≤ α t ≤ , ∀ t ≥ (13) Proof.
We will prove the above theorem with a strengthened induction — S ( t ) : ∀ ≤ t (cid:48) ≤ t, z min ≤ α t (cid:48) ≤ (cid:94) α − t (cid:48) − η t (cid:48) − ≤ z − min − η max . Since α = 1 , S (0) is obviously true. Now suppose S ( t ) is true for some t ∈ N , we will prove S ( t + 1) .First, since < α t ≤ , α t +1 = − η t λ t + 1 + η t η t − γ (1 − α − t ) ≤ . Again by Equation 12, we have − α t +1 = η t λ t + α − t − η t − η t γ = η t λ t + z − min − η max η t γ ≤ η t λ t + ( z − min − γ = 1 − z min , which shows α t +1 ≥ z min . Here the last step is by definition of z min .Because of α t +1 ≥ z min , we have α − t +1 − η t ≤ z − min − α t +1 η t ≤ z − min ( λ t + α − t − η t − γ ) ≤ z − min ( λ max + z − min − η max γ ) = z − min − z min η max = z − min − η max . Now we are ready to give the formal statement about the closeness of Equation 9 and the reducedLR schedule by Theorem 2.13.
Theorem A.8.
Given a Step Decay LR schedule with { T I } K − I =0 , { η ∗ I } K − I =0 , { λ ∗ I } K − I =0 , the TEXP++LR schedule in Theorem 2.13 is the following( α = α − = 1 , T = 0 ): α t = (cid:40) − η ∗ I λ ∗ I + 1 + γ (1 − α − t − ) , ∀ T I + 2 ≤ t ≤ T I +1 , I ≥ − η ∗ I λ ∗ I + 1 + η ∗ I η ∗ I − γ (1 − α − t − ) , ∀ t = T I + 1 , I ≥ P t = t (cid:89) i = − α − t ;ˆ η t = P t P t +1 η t . It’s the same as the TEXP LR schedule( { ˜ η t } ) in Theorem 2.12 throughout each phase I , in the sensethat (cid:12)(cid:12)(cid:12)(cid:12) ˆ η t − ˆ η t (cid:30) (cid:101) η t − (cid:101) η t − (cid:12)(cid:12)(cid:12)(cid:12) < λ max η max − γ (cid:18) γz min (cid:19) t − T I − ≤ λ max η max − γ (cid:20) γ (1 + λ max η max − γ ) (cid:21) ( t − T I − , ∀ T I +1 ≤ t ≤ T I +1 . z min is the larger root of x − (1 + γ − λ max η max ) x + γ = 0 . In Appendix A, we showthat z − min ≤ η max λ max − γ . When λ max η max is small compared to − γ , which is usually thecase in practice, one could approximate z min by 1. For example, when γ = 0 . , λ max = 0 . , η max = 0 . , the above upper bound becomes (cid:12)(cid:12)(cid:12)(cid:12) ˆ η t − ˆ η t (cid:30) (cid:101) η t − (cid:101) η t − (cid:12)(cid:12)(cid:12)(cid:12) ≤ . × . t − T I − . Proof of Theorem A.8.
Assuming z I and z I ( z I ≥ z I ) are the roots of Equation 1 with η = η I and λ = λ I , we have γ ≤ z I (cid:48) ≤ √ γ ≤ z min ≤ z I ≤ , ∀ I, I (cid:48) ∈ [ K − by Lemma A.1.We can rewrite the recursion in Theorem 2.13 as the following: α t = − η I λ I + 1 + γ (1 − α − t − ) = − ( z I + z I ) + z I z I α − t − . (14)In other words, we have α t − z I = z I α t − ( α t − − z I ) , t ≥ . (15)By Lemma A.7, we have α t ≥ z min , ∀ t ≥ . Thus | α t z I − | = z I α t − | α t − z I − | ≤ γz min | α t − z I − | = γz min | α t − z I − | ≤ γ (1 + λη − γ ) | α t − z I | , which means α t geometrically converges to its stable fixedpoint z I . and (cid:101) η t − (cid:101) η t = ( z I ) . Since that z min ≤ α t ≤ , z min ≤ z I ≤ , we have | α TI z I − | ≤ − z min z min = λ max η max − γ ≤ , and thus | α t z I − | ≤ λ max η max − γ ( γz min ) t − T I − ≤ , ∀ T I + 1 ≤ t ≤ T I +1 .Note that α ∗ I = z I , ˆ η t − ˆ η t = α t α t +1 By definition of TEXP and TEXP++, we have (cid:101) η t − (cid:101) η t = (cid:40) ( z I − ) if T I − + 1 ≤ t ≤ T I − η ∗ I − η ∗ I z I z I − if t = T I , I ≥ (16) ˆ η t − ˆ η t = η t − η t α t +1 α t = (cid:40) α t +1 α t if T I − + 1 ≤ t ≤ T I − η ∗ I − η ∗ I α T I +1 α T I if t = T I , I ≥ (17)Thus we have when t = T I , (cid:12)(cid:12)(cid:12)(cid:12) ˆ η t − ˆ η t (cid:30) (cid:101) η t − (cid:101) η t − (cid:12)(cid:12)(cid:12)(cid:12) ≤ (cid:12)(cid:12)(cid:12)(cid:12) α T I +1 z I α T I z I − − (cid:12)(cid:12)(cid:12)(cid:12) ≤ (cid:12)(cid:12)(cid:12)(cid:12) α T I +1 z I − (cid:12)(cid:12)(cid:12)(cid:12) + (cid:12)(cid:12)(cid:12)(cid:12) α T I z I − − (cid:12)(cid:12)(cid:12)(cid:12) + (cid:12)(cid:12)(cid:12)(cid:12) α T I +1 z I − (cid:12)(cid:12)(cid:12)(cid:12) (cid:12)(cid:12)(cid:12)(cid:12) α T I z I − − (cid:12)(cid:12)(cid:12)(cid:12) ≤ λ max η max − γ . 
When T I + 1 ≤ t ≤ T I +1 , we have (cid:12)(cid:12)(cid:12)(cid:12) ˆ η t − ˆ η t (cid:30) (cid:101) η t − (cid:101) η t − (cid:12)(cid:12)(cid:12)(cid:12) = (cid:12)(cid:12)(cid:12)(cid:12) α t +1 z I − α t z I − − (cid:12)(cid:12)(cid:12)(cid:12) ≤ (cid:12)(cid:12)(cid:12)(cid:12) α t +1 z I − − (cid:12)(cid:12)(cid:12)(cid:12) + (cid:12)(cid:12)(cid:12)(cid:12) α t z I − − (cid:12)(cid:12)(cid:12)(cid:12) + (cid:12)(cid:12)(cid:12)(cid:12) α t +1 z I − − (cid:12)(cid:12)(cid:12)(cid:12) (cid:12)(cid:12)(cid:12)(cid:12) α t z I − − (cid:12)(cid:12)(cid:12)(cid:12) ≤ λ max η max − γ ( γz min ) t − T I − . Thus we conclude ∀ I ∈ [ K − , T I + 1 ≤ t ≤ T I +1 , we have (cid:12)(cid:12)(cid:12)(cid:12) ˆ η t − ˆ η t (cid:30) (cid:101) η t − (cid:101) η t − (cid:12)(cid:12)(cid:12)(cid:12) ≤ λ max η max − γ (cid:18) γz min (cid:19) t − T I − ≤ λ max η max − γ · γ t − T I − (1+ λ max η max − γ ) t − T I − . MITTED P ROOFS IN S ECTION ˆ w to denote w (cid:107) w (cid:107) and ∠ uw to arccos( ˆ u (cid:62) ˆ w ) . Note that training error ≤ επ is equivalentto ∠ e w t < ε . Case 1: WD alone
Since the objective is strongly convex, it has unique argmin w ∗ . By symmetry, w ∗ = β e , for some β > . By KKT condition, we have λβ = E x ∼N (0 , (cid:20) | x | β | x | ) (cid:21) ≤ E x ∼N (0 , [ | x | ] = (cid:114) π , which implies (cid:107) w ∗ (cid:107) = O ( λ ) .By Theorem 3.1 of Gower et al. (2019), for sufficiently large t , we have E (cid:107) w t − w ∗ (cid:107) = O ( ηBλ ) .Note that ∠ e w t = ∠ w ∗ w t ≤ ∠ w ∗ w t ≤ (cid:107) w ∗ − w t (cid:107)(cid:107) w ∗ (cid:107) , we have E ( ∠ e w t ) = O ( ηλB ) , so theexpected error = E ( ∠ e w t ) /π ≤ (cid:112) E ( ∠ e w t ) /π = O ( (cid:113) ηλB ) . Case 3: Both BN and WD
We will need the following lemma when lower bounding the norm ofthe stochastic gradient.
Lemma A.9 (Concentration of Chi-Square) . Suppose X , . . . , X k i.i.d. ∼ N (0 , , thenPr (cid:34) k (cid:88) i =1 X i < kβ (cid:35) ≤ (cid:0) βe − β (cid:1) k . (18) Proof.
This Chernoff-bound based proof is a special case of Dasgupta & Gupta (2003).Pr (cid:34) k (cid:88) i =1 X i < kβ (cid:35) ≤ (cid:0) βe − β (cid:1) k = Pr (cid:34) exp (cid:32) ktβ − t k (cid:88) i =1 X i (cid:33) ≥ (cid:35) ≤ E (cid:34) exp (cid:32) ktβ − t k (cid:88) i =1 X i (cid:33)(cid:35) ( Markov Inequality )= e ktβ (1 + 2 t ) − k . (19)The last equality uses the fact that E (cid:2) tX i (cid:3) = √ − t for t < . The proof is completed by taking t = − β β . Setting for Theorem A.6:
Suppose WD factor is λ , LR is η , the width of the last layer is m ≥ ,Now the SGD updates have the form w t +1 = w t − ηB B (cid:88) b =1 ∇ (cid:18) ln(1 + exp( − x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) ) + λ (cid:107) w t (cid:107) (cid:19) =(1 − λη ) w t − ηB B (cid:88) b =1 y t,b x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) Π ⊥ w t x t,b (cid:107) w t (cid:107) , where x t,b i.i.d. ∼ N (0 , I m ) , y t,b = sign ([ x t,b ] ) , and Π ⊥ w t = I − w t w (cid:62) t (cid:107) w t (cid:107) . Proof of Theorem A.6. tep 1: Let T = ηλ − ε ) ln (cid:107) w T (cid:107) ε √ Bη √ m − , and T = 9 ln δ . Thus if we assume the trainingerror is smaller than ε from iteration T to T + T + T , then by spherical triangle inequality, ∠ w t w t (cid:48) ≤ ∠ e w t (cid:48) + ∠ e w t = 2 ε , for T ≤ t, t (cid:48) ≤ T + T + T .Now let’s define w (cid:48) t = (1 − ηλ ) w t and for any vector w , and we have the following two relation-ships:1. (cid:107) w (cid:48) t (cid:107) = (1 − ηλ ) (cid:107) w (cid:107) .2. (cid:107) w t +1 (cid:107) ≤ (cid:107) w (cid:48) t (cid:107) cos 2 ε .The second property is because by Lemma 1.3, ( w t +1 − w (cid:48) t ) ⊥ w (cid:48) t and by assumption of smallerror, ∠ w t +1 w (cid:48) t ≤ ε .Therefore (cid:107) w T + T (cid:107) (cid:107) w T (cid:107) ≤ (cid:18) − ηλ cos 2 ε (cid:19) T ≤ (cid:18) − ηλ − ε (cid:19) T ≤ (cid:0) − ( ηλ − ε ) (cid:1) T ≤ e − T ( ηλ − ε ) = η (cid:107) w T (cid:107) ε (cid:114) m − B . (20)In other word, (cid:107) w T + T (cid:107) ≤ η ε (cid:113) m − B . Since (cid:107) w T + t (cid:107) is monotone decreasing, (cid:107) w T + t (cid:107) ≤ η ε (cid:113) m − B holds for any t = T , . . . , T + T . Step 2:
We show that the norm of the stochastic gradient is lower bounded with constant probability.In other words, we want to show the norm of ξ t = (cid:80) Bb =1 y t,b x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) Π ⊥ w t x t,b (cid:107) w t (cid:107) is lowerbounded with high probability.Let Π ⊥ w t , e be the projection matrix for the orthogonal space spanned by w t and e . W.L.O.G, wecan assume the rank of Π ⊥ w t , e is 2. In case w t = e , we just exclude a random direction to make Π ⊥ w t , e rank 2. Now we have Π ⊥ w t , e x t,b are still i.i.d. multivariate gaussian random variables, for b = 1 , . . . , B , and moreover, Π ⊥ w t , e x t,b is independent to y t,b x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) . When m ≥ , wecan lower bound (cid:107) ξ t (cid:107) by dealing with (cid:107) Π ⊥ w t , e ξ t (cid:107) .It’s not hard to show that conditioned on { x t,b (cid:62) w t (cid:107) w t (cid:107) , [ x t,b ] } Bb =1 , B (cid:88) b =1 y t,b x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) Π ⊥ w t x t,b d = (cid:118)(cid:117)(cid:117)(cid:116) B (cid:88) b =1 (cid:32) y t,b x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) (cid:33) Π ⊥ w t , e x , (21)where x ∼ N ( , I m ) . We further note that (cid:107) Π ⊥ w t , e x (cid:107) ∼ χ ( m − . By Lemma A.9,Pr (cid:20) (cid:107) Π ⊥ w t , e x t (cid:107) ≥ m − (cid:21) ≥ − ( 18 e ) m − ≥ − ( 18 e ) ≥ . (22)Now we will give a high probability lower bound for (cid:80) Bb =1 (cid:18) y t,b x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) (cid:19) . Note that x (cid:62) t w t (cid:107) w t (cid:107) ∼ N (0 , , we have Pr (cid:20) | x (cid:62) t,b w t (cid:107) w t (cid:107) | < (cid:21) ≥ , (23)which implies the following, where A t,b is defined as (cid:104) | x (cid:62) t,b w t (cid:107) w t (cid:107) | < ≥ (cid:105) :20 A t,b = Pr (cid:34) (cid:107) y t,b x (cid:62) t,b w t (cid:107) w t (cid:107) y t ) (cid:107) ≥
11 + e (cid:35) ≥ . (24)Note that (cid:80) Bb =1 A t,b ≤ B , and E (cid:80) Bb =1 A t,b ≥ B , we have Pr (cid:104)(cid:80) Bb =1 A t,b < B (cid:105) ≤ . Thus,Pr B (cid:88) b =1 (cid:32) y t,b x t,b (cid:62) w t (cid:107) w t (cid:107) y t,b ) (cid:33) ≥ B e ) ≥ Pr (cid:34) B (cid:88) b =1 A t,b ≥ B (cid:35) ≥ . (25)Thus w.p. at least , equation 25 and equation 22 happen together, which implies (cid:107) ηB B (cid:88) b =1 ∇ ln(1+exp( − x (cid:62) t,b w t (cid:107) w t (cid:107) y t,b )) (cid:107) = (cid:107) ηB B (cid:88) b =1 y t,b x (cid:62) t,b w t (cid:107) w t (cid:107) y t ) Π ⊥ w t x t,b (cid:107) w t (cid:107) (cid:107) ≥ η e √ m − (cid:107) w t (cid:107) ≥ η (cid:107) w t (cid:107) (cid:114) m − B (26) Step 3 . To stay in the cone { w | ∠ we ≤ ε } , the SGD update (cid:107) w t +1 − w (cid:48) t (cid:107) = (cid:107) ηB (cid:80) Bb =1 ∇ ln(1 +exp( − x (cid:62) t,b w t (cid:107) w t (cid:107) y t,b )) (cid:107) has to be smaller than (cid:107) w t (cid:107) sin 2 ε for any t = T + T , . . . , T + T + T .However, step 1 and 2 together show that (cid:107)∇ ln(1 + exp( − x (cid:62) t w t (cid:107) w t (cid:107) y t )) (cid:107) ≥ (cid:107) w t (cid:107) ε w.p. periteration. Thus the probability that w t always stays in the cone for every t = T + T , . . . , T + T + T is less than (cid:0) (cid:1) T ≤ δ .It’s interesting that the only property of the global minimum we use is that the if both w t , w t +1 are ε − optimal, then the angle between w t and w t +1 is at most ε . Thus we indeed have proved astronger statement: At least once in every ηλ − ε ) ln (cid:107) w T (cid:107) ε √ Bη √ m − + 9 ln δ iterations, the anglebetween w t and w t +1 will be larger than (cid:15) . In other words, if the the amount of the updatestabilizes to some direction in terms of angle, then the fluctuation in terms of angle must be largerthan √ ηλ for this simple model, no matter how small the noise is.A.7 O MITTED P ROOFS IN S ECTION Lemma A.10.
Suppose the loss $L$ is scale invariant. Then $L$ is non-convex in the following two senses:
1. The domain is non-convex: a scale-invariant loss cannot be defined at the origin;
2. There exists no ball containing the origin on which the loss is convex, unless the loss is a constant function.
Proof. Suppose $L(\theta^*) = \sup_{\theta \in B} L(\theta)$. W.L.O.G., we assume $\|\theta^*\| < 1$. By convexity, every line segment passing through $\theta^*$ must have constant loss, which implies the loss is constant over the set $B \setminus \left\{ c\frac{\theta^*}{\|\theta^*\|} \,\middle|\, -1 \le c \le 1 \right\}$. Applying the above argument to any other maximum point $\theta'$ implies the loss is constant over $B \setminus \{0\}$.
Theorem A.11.
Suppose the momentum factor $\gamma = 0$, the LR $\eta_t = \eta$ is constant, and the loss function $L$ is lower bounded. If there exist $c > 0$ and $T \ge 0$ such that for all $t \ge T$, $f(\theta_{t+1}) - f(\theta_t) \le -c\eta\|\nabla L(\theta_t)\|^2$, then $\lim_{t\to\infty}\|\theta_t\| = 0$. Proof in Item 3.
By Lemma 1.3 and the update rule of GD with WD, we have
$$\|\theta_t\|^2 = \|(1-\lambda\eta)\theta_{t-1} - \eta\nabla L(\theta_{t-1})\|^2 = (1-\lambda\eta)^2\|\theta_{t-1}\|^2 + \eta^2\|\nabla L(\theta_{t-1})\|^2,$$
which implies
$$\|\theta_t\|^2 = \sum_{i=T}^{t-1}(1-\lambda\eta)^{2(t-i-1)}\eta^2\|\nabla L(\theta_i)\|^2 + (1-\lambda\eta)^{2(t-T)}\|\theta_T\|^2.$$
Thus, for any $T' > T$, using $\lambda\eta \le 1$ so that $1-(1-\lambda\eta)^2 \ge \lambda\eta$,
$$\sum_{t=T}^{T'}\|\theta_t\|^2 \le \frac{1}{1-(1-\lambda\eta)^2}\left(\eta^2\sum_{t=T}^{T'-1}\|\nabla L(\theta_t)\|^2 + \|\theta_T\|^2\right) \le \frac{\eta}{\lambda}\sum_{t=T}^{T'-1}\|\nabla L(\theta_t)\|^2 + \frac{\|\theta_T\|^2}{\lambda\eta}.$$
Note that by assumption we have $\sum_{t=T}^{T'-1}\|\nabla L(\theta_t)\|^2 \le \frac{f(\theta_T)-f(\theta_{T'})}{c\eta}$. As a conclusion, we have
$$\sum_{t=T}^{\infty}\|\theta_t\|^2 \le \frac{\eta}{\lambda}\cdot\frac{f(\theta_T)-\inf_\theta f(\theta)}{c\eta} + \frac{\|\theta_T\|^2}{\lambda\eta},$$
and since this series converges, $\lim_{t\to\infty}\|\theta_t\| = 0$.
B OTHER RESULTS
Now we rigorously analyze norm growth in this algorithm. This greatly extends previous analyses of the effect of normalization schemes (Wu et al., 2018; Arora et al., 2018) for vanilla SGD.
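The norm dynamics analyzed in this section can be checked numerically. The following is a minimal sketch, not the authors' code. It assumes the heavy-ball form of update rule 1.2 with a constant LR, $\theta_{t+1}-\theta_t = \gamma(\theta_t-\theta_{t-1}) - \eta(\nabla L(\theta_t) + \lambda\theta_t)$, which is the form used in the proof of Theorem B.1, and it replaces the network by a hypothetical toy scale-invariant loss $L(\theta) = \frac12\|\theta/\|\theta\| - u\|^2$ with a fresh random unit vector $u$ at every step standing in for gradient noise. The only property the theorems need, $\nabla L(\theta)^\top\theta = 0$, holds for this toy gradient by construction.

```python
import numpy as np

def toy_grad(theta, u):
    """Gradient of the hypothetical scale-invariant loss 0.5 * ||theta/||theta|| - u||^2.

    The result is orthogonal to theta and scales like 1/||theta||.
    """
    r = np.linalg.norm(theta)
    v = theta / r
    g = v - u                          # gradient with respect to the unit vector v
    return (g - v * (v @ g)) / r       # drop the radial part, then divide by ||theta||

def run(eta, lam, gamma, steps, dim=10, seed=0):
    """Heavy-ball SGD with weight decay (assumed form of update rule 1.2).

    Returns arrays R_t = ||theta_t||^2 and D_t = ||theta_{t+1} - theta_t||^2,
    plus the final squared norm ||theta_steps||^2.
    """
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal(dim)
    prev = theta.copy()                # theta_{-1} = theta_0, i.e. zero initial momentum
    R, D = [], []
    for _ in range(steps):
        u = rng.standard_normal(dim)
        u /= np.linalg.norm(u)         # fresh random target plays the role of gradient noise
        step = gamma * (theta - prev) - eta * (toy_grad(theta, u) + lam * theta)
        R.append(theta @ theta)
        D.append(step @ step)
        prev, theta = theta, theta + step
    return np.array(R), np.array(D), theta @ theta

# Theorem B.1 with gamma = 0 and lambda = 0: R_{t+1} = R_0 + sum_i D_i.
R, D, final = run(eta=0.1, lam=0.0, gamma=0.0, steps=2000)
print(np.isclose(final, R[0] + D.sum()))

# Theorem B.1 with momentum, lambda = 0 and theta_{-1} = theta_0: the norm never decreases.
R, D, _ = run(eta=0.1, lam=0.0, gamma=0.9, steps=2000)
print(np.all(np.diff(R) >= -1e-12 * R[1:]))

# Theorem B.2: the time-averaged ratio D / R approaches 2 * eta * lambda / (1 + gamma).
eta, lam, gamma = 0.1, 0.1, 0.9
R, D, _ = run(eta, lam, gamma, steps=30000)
burn = 5000                            # discard the transient before equilibrium
print(D[burn:].mean() / R[burn:].mean(), 2 * eta * lam / (1 + gamma))
```

Averaging over a long window after a burn-in is a design choice: Equation 27 holds step by step, so the windowed averages differ from the limits in Theorem B.2 only by boundary terms that shrink with the window length.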
Theorem B.1.
Under the update rule 1.2 with $\lambda_t = 0$, the norm of a scale-invariant parameter $\theta_t$ satisfies the following properties:
• Almost monotone increasing: $\|\theta_{t+1}\|^2 - \|\theta_t\|^2 \ge \gamma^{t+1}\frac{\eta_t}{\eta_{-1}}\left(\|\theta_0\|^2 - \|\theta_{-1}\|^2\right)$.
• Assuming $\eta_t = \eta$ is a constant,
$$\|\theta_{t+1}\|^2 = \|\theta_0\|^2 + \sum_{i=0}^{t}\frac{1-\gamma^{t-i+1}}{1-\gamma}\left(\|\theta_i - \theta_{i+1}\|^2 + \gamma\|\theta_{i-1} - \theta_i\|^2\right) + \gamma\frac{1-\gamma^{t+1}}{1-\gamma}\left(\|\theta_0\|^2 - \|\theta_{-1}\|^2\right).$$
Proof.
Let us use $R_t$, $D_t$, $C_t$ to denote $\|\theta_t\|^2$, $\|\theta_{t+1} - \theta_t\|^2$, and $\theta_t^\top(\theta_{t+1} - \theta_t)$, respectively. The only property of the loss we will use is $\nabla_\theta L(\theta_t)^\top\theta_t = 0$.
Expanding the square of $\|\theta_{t+1}\| = \|(\theta_{t+1} - \theta_t) + \theta_t\|$, we have
$$\forall t \ge -1, \quad S(t):\; R_{t+1} - R_t = D_t + 2C_t.$$
We also have
$$\frac{C_t}{\eta_t} = \theta_t^\top\frac{\theta_{t+1} - \theta_t}{\eta_t} = \theta_t^\top\left(\gamma\frac{\theta_t - \theta_{t-1}}{\eta_{t-1}} - \lambda_t\theta_t\right) = \frac{\gamma}{\eta_{t-1}}(D_{t-1} + C_{t-1}) - \lambda_t R_t,$$
namely,
$$\forall t \ge 0, \quad P(t):\; \frac{C_t}{\eta_t} - \frac{\gamma D_{t-1}}{\eta_{t-1}} = \frac{\gamma}{\eta_{t-1}}C_{t-1} - \lambda_t R_t.$$
Simplifying $\frac{S(t)}{\eta_t} - \gamma\frac{S(t-1)}{\eta_{t-1}} + 2P(t)$, we have
$$\frac{R_{t+1} - R_t}{\eta_t} - \gamma\frac{R_t - R_{t-1}}{\eta_{t-1}} = \frac{D_t}{\eta_t} + \gamma\frac{D_{t-1}}{\eta_{t-1}} - 2\lambda_t R_t. \tag{27}$$
When $\lambda_t = 0$, we have
$$\frac{R_{t+1} - R_t}{\eta_t} = \gamma^{t+1}\frac{R_0 - R_{-1}}{\eta_{-1}} + \sum_{i=0}^{t}\gamma^{t-i}\left(\frac{D_i}{\eta_i} + \gamma\frac{D_{i-1}}{\eta_{i-1}}\right) \ge \gamma^{t+1}\frac{R_0 - R_{-1}}{\eta_{-1}}.$$
If further $\eta_t = \eta$ is a constant, we have
$$R_{t+1} = R_0 + \sum_{i=0}^{t}\frac{1-\gamma^{t-i+1}}{1-\gamma}(D_i + \gamma D_{i-1}) + \gamma\frac{1-\gamma^{t+1}}{1-\gamma}(R_0 - R_{-1}),$$
and in particular, when $\gamma = 0$, $R_{t+1} = R_0 + \sum_{i=0}^{t}D_i$.
For general deep nets, we have the following result, suggesting that the mean square of the update is a constant multiple of the mean square of the norm. The constant is mainly determined by $\eta\lambda$, which explains why weight decay prevents the parameters from converging in direction.
Theorem B.2.
For SGD with constant LR $\eta$, weight decay $\lambda$, and momentum $\gamma$, when the limits $R_\infty = \lim_{T\to\infty}\frac{1}{T}\sum_{t=0}^{T-1}\|w_t\|^2$ and $D_\infty = \lim_{T\to\infty}\frac{1}{T}\sum_{t=0}^{T-1}\|w_{t+1} - w_t\|^2$ exist, we have
$$D_\infty = \frac{2\eta\lambda}{1+\gamma}R_\infty.$$
Proof of Theorem B.2.
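Take the average of Equation 27 over $t = 0, \dots, T-1$. The left-hand side telescopes to $\frac{(R_T - R_0) - \gamma(R_{T-1} - R_{-1})}{\eta T}$, which tends to $0$ because the existence of $R_\infty$ implies $R_T/T \to 0$. When the limits $R_\infty$ and $D_\infty$ exist, the right-hand side tends to $\frac{1+\gamma}{\eta}D_\infty - 2\lambda R_\infty$, so we have $\frac{1+\gamma}{\eta}D_\infty = 2\lambda R_\infty$.
C SCALE INVARIANCE IN MODERN NETWORK ARCHITECTURES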
In this section, we discuss how normalization layers make the output of the network scale invariant to its parameters. Viewing a neural network as a DAG, we give a sufficient condition for scale invariance which can be checked easily in topological order, and apply it to several standard network architectures such as Fully Connected (FC) Networks, Plain CNN, ResNet (He et al., 2016a), and PreResNet (He et al., 2016b). For simplicity, we restrict our discussion to networks with ReLU activation only. Throughout this section, we assume the linear layers and the bias after the last normalization layer are fixed at their random initialization, which empirically does not harm the performance of the network (Hoffer et al., 2018b).
C.1 NOTATIONS
Definition C.1 (Degree of Homogeneity). Suppose $k$ is an integer and $\theta$ denotes all the parameters of the network. Then $f$ is said to be homogeneous of degree $k$, or $k$-homogeneous, if $\forall c > 0$, $f(c\theta) = c^k f(\theta)$. The output of $f$ can be multi-dimensional. In particular, scale invariance means the degree of homogeneity is 0.
Suppose the network only contains the following modules; we list the degree of homogeneity of these basic modules given the degree of homogeneity of their input.
(I) Input
(L) Linear Layer, e.g. Convolutional Layer or Fully Connected Layer
(B) Bias Layer (adding a trainable bias to the output of the previous layer)
(+) Addition Layer (adding the outputs of two layers with the same dimension)
(N) Normalization Layer without affine transformation (including BN, GN, LN, IN, etc.)
(NA) Normalization Layer with affine transformation
(Page) had a similar argument for this phenomenon by connecting it to LARS (You et al., 2017), though it is not rigorous in the way it deals with momentum and the equilibrium of the norm.
The Addition Layer (+) is mainly used in ResNet and other similar architectures. In this section, we also use it as an alternative definition of the Bias Layer (B). See Figure 7.
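Definition C.1 can be checked numerically for small compositions of the modules above. The sketch below is a hypothetical NumPy illustration, not code from the paper: it scales all parameters by $c$, estimates $k$ from the ratio of output norms, and then verifies $f(c\theta) = c^k f(\theta)$ elementwise. As an assumption, the normalization module (N) is implemented as batch normalization without affine transformation and with `eps = 0`, so that the invariance holds exactly; practical layers add a small `eps`, which makes the invariance only approximate.

```python
import numpy as np

def norm_no_affine(z, eps=0.0):
    """(N) module: batch normalization without affine transform, per-feature batch stats.

    eps = 0 keeps the map exactly invariant to rescaling z; practical layers use eps > 0.
    """
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

def relu(z):
    return np.maximum(z, 0.0)

def degree_of_homogeneity(f, params, c=3.0):
    """Return the integer k with f(c * params) == c**k * f(params), or None if f is not homogeneous."""
    out, out_scaled = f(params), f([c * p for p in params])
    k = int(round(np.log(np.linalg.norm(out_scaled) / np.linalg.norm(out)) / np.log(c)))
    return k if np.allclose(out_scaled, c**k * out) else None

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 5))          # fixed input batch (module I): degree 0
W1, W2 = rng.standard_normal((5, 8)), rng.standard_normal((8, 4))
b, g, beta = (rng.standard_normal(8) for _ in range(3))

# Hypothetical small networks; the comment gives the expected degree.
examples = {
    "L": (lambda p: x @ p[0], [W1]),                                          # 1
    "L-ReLU-L": (lambda p: relu(x @ p[0]) @ p[1], [W1, W2]),                  # 2
    "L-B": (lambda p: x @ p[0] + p[1], [W1, b]),                              # 1
    "L-N": (lambda p: norm_no_affine(x @ p[0]), [W1]),                        # 0
    "L-N-ReLU-L-N": (lambda p: norm_no_affine(relu(norm_no_affine(x @ p[0])) @ p[1]),
                     [W1, W2]),                                               # 0
    "L-NA": (lambda p: norm_no_affine(x @ p[0]) * p[1] + p[2], [W1, g, beta]),  # 1
    "L-N-B": (lambda p: norm_no_affine(x @ p[0]) + p[1], [W1, b]),            # None
}
for name, (f, params) in examples.items():
    print(name, degree_of_homogeneity(f, params))
```

The last example illustrates why a Bias Layer needs a 1-homogeneous input: adding a trainable bias to a 0-homogeneous normalized output mixes two degrees, so the result is not homogeneous at all.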
Table 1: Table showing how the degree of homogeneity of the output of the basic modules depends on the degree of homogeneity of the input. In the Input row, the entry '-' means the input of the network (I) has no extra input; the entry '1' for the Bias Layer means that if the input is 1-homogeneous then the output is 1-homogeneous; '(x, x)' for '+' means that if the inputs of the Addition Layer have the same degree of homogeneity, the output has the same degree of homogeneity. ReLU and pooling (and other fixed linear maps) are ignored because they preserve the degree of homogeneity and can be omitted when creating the DAG in Theorem C.3.
Remark C.2.
For the purpose of deciding the degree of homogeneity of a network, there is no difference among convolutional layers, fully connected layers, and the diagonal linear layer in the affine transformation of a normalization layer, since they are all linear and the degree of homogeneity increases by 1 after applying them. On the other hand, BN and IN have a benefit which GN and LN do not have: the (per-channel) bias term immediately before BN or IN has zero effect on the network output and thus can be removed (see Figure 15). We also illustrate the degree of homogeneity of the output of each module in the following figures, which will be reused later to define network architectures.
Figure 7:
Degree of homogeneity of the output of basic modules given the degree of homogeneity of the input. Panels: (a) Input (I); (b) Linear (L); (c) Addition (+); (d) Normalization (N); (e) Bias (B); (f) Alternative Definition of Bias (B); (g) Normalization with Affine (NA); (h) Definition of Normalization with Affine (NA).
Theorem C.3. For a network consisting only of the modules defined above and ReLU activation, we can view it as a directed acyclic graph (DAG) and check its scale invariance by the following algorithm.
Algorithm 1
Input: DAG $G = (V, E)$ translated from a neural network; the module type of each node $v_i \in V$.
for $v$ in topological order of $G$ do
    Compute the degree of homogeneity of $v$ using Table 1;
    if $v$ is not homogeneous then return False;
end for
if $v_{\text{output}}$ is 0-homogeneous then return True; else return False.
C.2 NETWORKS WITHOUT AFFINE TRANSFORMATION AND BIAS
We start with the simple case where all bias terms (including those of linear layers and normalization layers) and the scaling terms of normalization layers are fixed to 0 and 1 elementwise, respectively, which means the biases and scalings can be dropped from the network structure. We empirically find this does not affect the performance of the network in a noticeable way. We will discuss the full case in the next subsection.
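Algorithm 1 is simple enough to sketch directly. The following Python version is a hypothetical illustration, not code from the paper. Nodes are listed in topological order as `(name, module type, parents)`, and the per-module rules are our reading of Table 1's caption and Remark C.2: the input has degree 0, L adds 1, B requires a 1-homogeneous input, + requires equal degrees, N gives 0, NA gives 1 (normalization, then a diagonal linear layer, then a bias), and ReLU keeps the degree. The examples are a plain FC stack, a downsampling ResNet block with and without the extra normalization on the shortcut, and, in the spirit of Figure 15, a trainable linear bias placed after an affine normalization layer.

```python
from typing import Dict, List, Optional, Tuple

# A node is (name, module type, parent names); node lists must be in topological order.
Node = Tuple[str, str, List[str]]

def module_degree(kind: str, inputs: List[int]) -> Optional[int]:
    """Degree of homogeneity of a module's output, or None if it is not homogeneous."""
    if kind == "I":
        return 0                                   # the data does not depend on the parameters
    if kind == "L":
        return inputs[0] + 1                       # conv / FC / diagonal linear: degree + 1
    if kind == "B":
        return 1 if inputs[0] == 1 else None       # a trainable bias is itself 1-homogeneous
    if kind == "+":
        return inputs[0] if inputs[0] == inputs[1] else None
    if kind == "N":
        return 0                                   # normalization without affine
    if kind == "NA":
        return 1                                   # N, then diagonal linear (+1), then bias
    if kind == "ReLU":
        return inputs[0]                           # positively homogeneous: keeps the degree
    raise ValueError(f"unknown module type: {kind}")

def is_scale_invariant(nodes: List[Node], output: str) -> bool:
    """Algorithm 1: propagate degrees in topological order, then test the output node."""
    degree: Dict[str, int] = {}
    for name, kind, parents in nodes:
        d = module_degree(kind, [degree[p] for p in parents])
        if d is None:
            return False                           # some node is not homogeneous
        degree[name] = d
    return degree[output] == 0

plain_fc = [("x", "I", []), ("l1", "L", ["x"]), ("n1", "N", ["l1"]), ("r1", "ReLU", ["n1"]),
            ("l2", "L", ["r1"]), ("n2", "N", ["l2"])]

def downsampling_block(norm_on_shortcut: bool) -> List[Node]:
    """Hypothetical ResNet-style block with a 1x1 downsampling conv on the shortcut."""
    nodes = [("x", "I", []), ("stem", "L", ["x"]), ("h", "N", ["stem"]),  # 0-homogeneous block input
             ("a", "L", ["h"]), ("a_n", "N", ["a"]), ("a_r", "ReLU", ["a_n"]),
             ("b", "L", ["a_r"]), ("b_n", "N", ["b"]),                     # residual branch
             ("down", "L", ["h"])]                                          # downsampling shortcut
    shortcut = "down"
    if norm_on_shortcut:
        nodes.append(("down_n", "N", ["down"]))
        shortcut = "down_n"
    return nodes + [("sum", "+", ["b_n", shortcut]), ("out", "ReLU", ["sum"])]

affine_then_biased_linear = [("x", "I", []), ("l1", "L", ["x"]), ("na", "NA", ["l1"]),
                             ("r", "ReLU", ["na"]), ("l2", "L", ["r"]),
                             ("bias", "B", ["l2"]), ("n2", "N", ["bias"])]

print(is_scale_invariant(plain_fc, "n2"))                       # True
print(is_scale_invariant(downsampling_block(False), "out"))     # False: adds degrees 0 and 1
print(is_scale_invariant(downsampling_block(True), "out"))      # True
print(is_scale_invariant(affine_then_biased_linear, "n2"))      # False: bias on a 2-homogeneous input
```

The downsampling example shows why an extra normalization layer is placed on the shortcut: without it, the Addition Layer sums a 0-homogeneous residual branch and a 1-homogeneous shortcut.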
Plain CNN/FC networks:
See Figure 8.
Figure 8:
Degree of homogeneity for all modules in vanilla CNNs/FC networks.
Figure 9:
An example of the full network structure of ResNet/PreResNet represented by the composite modules defined in Figures 10, 11, 13, and 14, where 'S' denotes the starting part of the network, 'Block' denotes a normal block with a residual link, 'D-Block' denotes the block with downsampling, and 'N' denotes the normalization layer defined previously. The integer x, which takes one of three values, depends on the type of network. See details in Figures 10, 11, 13, and 14.
ResNet:
See Figure 10. To ensure scale invariance, we add an additional normalization layer in the shortcut after downsampling. This implementation is sometimes used in practice and does not affect the performance in a noticeable way.
Preactivation ResNet:
See Figure 11. Preactivation means swapping the order of the convolutional layer and the normalization layer. For a similar reason, we add an additional normalization layer in the shortcut before downsampling.
Figure 10:
Degree of homogeneity for all modules in ResNet without affine transformation in the normalization layers. The last normalization layer is omitted. Panels: (a) The starting part of ResNet; (b) A block of ResNet; (c) A block of ResNet with downsampling.
C.3 NETWORKS WITH AFFINE TRANSFORMATION
Now we discuss the full case where the affine transformation part of the normalization layer is trainable. Because the bias of a linear layer before BN has zero gradient, as mentioned in Remark C.2, the bias term is usually dropped from the network architecture in practice to save memory and accelerate training, even with other normalization methods (see the PyTorch implementation (Paszke et al., 2017)). However, when LN or GN is used and the bias term of the linear layer is trainable, the network could be scale variant (see Figure 15).
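The contrast between BN and LN drawn above and in Remark C.2 can be seen with a few lines of NumPy. This is a minimal sketch assuming 2-D activations (batch × channels) and normalization without affine transformation; GN with a single group coincides with this LN, while IN would need spatial dimensions, which this toy omits.

```python
import numpy as np

def batch_norm(z):
    """BN without affine: per-channel statistics over the batch axis (axis 0)."""
    return (z - z.mean(axis=0)) / z.std(axis=0)

def layer_norm(z):
    """LN without affine: per-example statistics over the channel axis (axis 1)."""
    return (z - z.mean(axis=1, keepdims=True)) / z.std(axis=1, keepdims=True)

rng = np.random.default_rng(0)
z = rng.standard_normal((16, 6))    # linear-layer outputs: batch of 16 examples, 6 channels
bias = rng.standard_normal(6)       # per-channel bias of that linear layer

# BN subtracts each channel's batch mean, so a per-channel bias cancels exactly;
# since the output does not depend on the bias, its gradient is zero.
print(np.allclose(batch_norm(z + bias), batch_norm(z)))   # True
# LN normalizes across channels, so a non-constant per-channel bias changes the output.
print(np.allclose(layer_norm(z + bias), layer_norm(z)))   # False
```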
Plain CNN/FC networks:
See Figure 12.
ResNet:
See Figure 13. To ensure scale invariance, we add an additional normalization layer in the shortcut after downsampling. This implementation is sometimes used in practice and does not affect the performance in a noticeable way.
Preactivation ResNet:
See Figure 14. Preactivation means swapping the order of the convolutional layer and the normalization layer. For a similar reason, we add an additional normalization layer in the shortcut before downsampling.
Figure 11:
Degree of homogeneity for all modules in PreResNet without affine transformation in the normalization layers. The last normalization layer is omitted. Panels: (a) The starting part of PreResNet; (b) A block of PreResNet; (c) A block of PreResNet with downsampling.
Figure 12:
Degree of homogeneity for all modules in vanilla CNNs/FC networks.
Figure 13:
Degree of homogeneity for all modules in ResNet with trainable affine transformation. The last normalization layer is omitted. Panels: (a) The starting part of ResNet; (b) A block of ResNet; (c) A block of ResNet with downsampling.
Figure 14:
Degree of homogeneity for all modules in PreResNet with trainable affine transformation. The last normalization layer is omitted. Panels: (a) The starting part of PreResNet; (b) A block of PreResNet; (c) A block of PreResNet with downsampling.
Figure 15:
The network may not be scale invariant if GN or LN is used and the bias of the linear layer is trainable. The red 'F' means that Algorithm 1 returns False here.