Dynamically Stable Infinite-Width Limits of Neural Classifiers
Eugene A. Golikov
Neural Networks and Deep Learning Lab.
Moscow Institute of Physics and Technology
Moscow, Russia
[email protected]
Abstract
Recent research has been focused on two different approaches to studying neural network training in the limit of infinite width: (1) the mean-field (MF) approximation and (2) the constant neural tangent kernel (NTK) approximation. These two approaches have different scalings of hyperparameters with the width of a network layer and, as a result, different infinite-width limit models. We propose a general framework to study how the limit behavior of neural models depends on the scaling of hyperparameters with network width. Our framework allows us to derive the scalings for the existing MF and NTK limits, as well as an uncountable number of other scalings that lead to dynamically stable limit behavior of the corresponding models. However, only a finite number of distinct limit models are induced by these scalings. Each distinct limit model corresponds to a unique combination of such properties as boundedness of logits and tangent kernels at initialization, or stationarity of tangent kernels. The existing MF and NTK limit models, as well as one novel limit model, satisfy most of the properties demonstrated by finite-width models. We also propose a novel initialization-corrected mean-field limit that satisfies all the properties noted above; its corresponding model is a simple modification of a finite-width model. Source code to reproduce all the reported results is available on GitHub: https://github.com/deepmipt/research/tree/master/Infinite_Width_Limits_of_Neural_Classifiers

Introduction

For a couple of decades, neural networks have proved to be useful in a variety of applications. However, their theoretical understanding is still lacking. Several recent works have tried to simplify the object of study by approximating the training dynamics of a finite-width neural network with its limit counterpart in the limit of a large number of hidden units; we refer to it as the "infinite-width" limit. The exact type of the limit training dynamics depends on how the hyperparameters of the training dynamics scale with width. In particular, two different types of limit models have already been extensively discussed in the literature: the NTK model [1] and the mean-field limit model [2, 3, 4, 5, 6, 7]. A recent work [8] attempted to provide a link between these two different types of limit models by building a framework for choosing a scaling of hyperparameters that leads to a "well-defined" limit model. Our work is the next step in this direction. Our contributions are as follows.

1. We develop a framework for reasoning about the scaling of hyperparameters, which allows one to infer scaling exponents that allow for a dynamically stable model evolution in the limit of infinite width. This framework allows us to derive both the mean-field and NTK limits that have been extensively studied in the literature, as well as the "intermediate limit" introduced in [8].

[Figure 1 here. Left panel: the $(q_\sigma, \tilde q)$-plane with the band of dynamically stable model evolutions and regions labeled "evolving kernels", "finite logits at initialization" (sym-default), "finite tangent kernels at initialization" (MF), and "logits and kernels are of the same order at initialization" (NTK). Right panel: training curves of the limit models.]
Figure 1: The diagram on the left specifies several properties demonstrated by finite-width models. As the plots on the right demonstrate, our novel IC-MF limit model satisfies all of these properties, while the MF and NTK limit models, as well as the sym-default limit model presented in this paper, violate some of them.
Left:
A band of scaling exponents $(q_\sigma, \tilde q)$ that lead to dynamically stable model evolutions in the limit of infinite width, together with dashed lines of special properties that the corresponding limits satisfy. Three colored points correspond to limit models that satisfy most of these properties. Right:
Training dynamics of the three models that correspond to the colored points on the left plot, as well as of the initialization-corrected mean-field model (IC-MF), which does not correspond to any point of the left plot. These models are the results of scaling a reference model of width d = 2 (black line) up to width d = 2 (colored lines). Solid lines correspond to the test set, while dashed lines are for the train set. See SM E for details.

2. Our framework demonstrates that there are only 13 distinct stable model evolution equations in the limit of infinite width that can be induced by scaling the hyperparameters of a finite-width model. Each distinct limit model corresponds to a region (two-, one-, or zero-dimensional) of the green band in Figure 1, left.

3. We consider a list of properties that are satisfied by the evolution of finite-width models, but generally not by their infinite-width limits. We demonstrate that the mean-field and NTK limit models, as well as the "sym-default" limit model, which was not previously discussed in the literature, are special in the sense that they satisfy most of these properties among all limit models induced by hyperparameter scalings. We propose a model modification that allows for all of these properties in the limit of infinite width and call the corresponding limit the "initialization-corrected mean-field limit (IC-MF)".

4. We discuss the ability of limit models to approximate the training dynamics of finite-width ones. We show that our proposed IC-MF limit model is the best among all possible limit models.

Consider a one-hidden-layer network:

    f(x; a, W) = a^T \phi(W^T x) = \sum_{r=1}^d a_r \phi(w_r^T x),    (1)

where $x \in R^{d_x}$, $W = [w_1, \ldots, w_d] \in R^{d_x \times d}$, and $a = [a_1, \ldots, a_d]^T \in R^d$. We assume the nonlinearity to be real analytic and asymptotically linear: $\phi(z) = \Theta_{z \to \infty}(z)$. Such a nonlinearity can be, e.g., the "leaky softplus": $\phi(z) = \ln(1 + e^z) - \alpha \ln(1 + e^{-z})$ for $\alpha > 0$. This is a technical assumption introduced to ease proofs. For simplicity, we assume the loss function $\ell(y, z)$ to be the standard binary cross-entropy loss: $\ell(y, z) = \ln(1 + e^{-yz})$, where labels $y \in \{-1, 1\}$. The data distribution loss is defined as $L(a, W) = E_{x, y \sim D}\, \ell(y, f(x; a, W))$.

Weights are initialized with isotropic Gaussians with zero means: $w_r^{(0)} \sim N(0, \sigma_w^2 I)$, $a_r^{(0)} \sim N(0, \sigma_a^2)$ $\forall r = 1 \ldots d$. The evolution of the weights is driven by stochastic gradient descent (SGD):

    \Delta\theta^{(k)} = \theta^{(k+1)} - \theta^{(k)} = -\eta_\theta \frac{\partial \ell(y_\theta^{(k)}, f(x_\theta^{(k)}; a, W))}{\partial \theta}, \qquad (x_\theta^{(k)}, y_\theta^{(k)}) \sim D,    (2)

where $\theta$ is either $a$ or $W$. We assume that the gradients for $a$ and $W$ are estimated using independent data samples $(x_a^{(k)}, y_a^{(k)})$ and $(x_w^{(k)}, y_w^{(k)})$. We introduce this assumption for ease of proofs. Note that the corresponding stochastic gradients still give unbiased estimates of the true gradients. Define:

    \hat a_r^{(k)} = a_r^{(k)} / \sigma_a, \qquad \hat w_r^{(k)} = w_r^{(k)} / \sigma_w, \qquad \hat\eta_a = \eta_a / \sigma_a^2, \qquad \hat\eta_w = \eta_w / \sigma_w^2.    (3)

Then the dynamics transforms to:

    \Delta\hat\theta_r^{(k)} = -\hat\eta_\theta \frac{\partial \ell(y_\theta^{(k)}, f(x_\theta^{(k)}; \sigma_a \hat a, \sigma_w \hat W))}{\partial \hat\theta_r}, \qquad (x_\theta^{(k)}, y_\theta^{(k)}) \sim D,    (4)

while the scaled initial conditions become $\hat a_r^{(0)} \sim N(0, 1)$, $\hat w_r^{(0)} \sim N(0, I)$ $\forall r = 1 \ldots d$. By expanding the gradients, we get the following:

    \Delta\hat a_r^{(k)} = -\hat\eta_a \sigma_a \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\sigma_w \hat w_r^{(k),T} x_a^{(k)}), \qquad \hat a_r^{(0)} \sim N(0, 1),    (5)

    \Delta\hat w_r^{(k)} = -\hat\eta_w \sigma_a \sigma_w \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a_r^{(k)} \phi'(\sigma_w \hat w_r^{(k),T} x_w^{(k)})\, x_w^{(k)}, \qquad \hat w_r^{(0)} \sim N(0, I),    (6)

    \nabla_f^{(k)}\ell(x, y) = \frac{\partial \ell(y, z)}{\partial z}\Big|_{z = f_d^{(k)}(x)} = \frac{-y}{1 + e^{f_d^{(k)}(x)\, y}}, \qquad f_d^{(k)}(x) = \sigma_a \sum_{r=1}^d \hat a_r^{(k)} \phi(\sigma_w \hat w_r^{(k),T} x).

Without loss of generality assume $\sigma_w = 1$ (we can rescale the inputs $x$ otherwise). We shall omit the subscript of $\sigma_a$ from now on. Assume the hyperparameters that drive the dynamics obey a power-law dependence on $d$: $\sigma(d) = \sigma^* d^{q_\sigma}$, $\hat\eta_a(d) = \hat\eta_a^* d^{\tilde q_a}$, $\hat\eta_w(d) = \hat\eta_w^* d^{\tilde q_w}$. This assumption is quite natural. Indeed, for the He initialization [9] commonly used in practice, $\sigma \propto d^{-1/2}$, while by default we keep the learning rates in the original parameterization constant while changing the width: $\eta_{a/w} = const$, which implies $\hat\eta_a \propto d$ and $\hat\eta_w \propto 1$. On the other hand, the NTK scaling [1, 10] requires the scaled learning rates to be constant: $\hat\eta_{a/w} \propto 1$. Here and below we write "$a/w$" meaning "$a$ or $w$".

Since $\phi$ is smooth, we have:

    \Delta f_d^{(k)}(x) = f_d^{(k+1)}(x) - f_d^{(k)}(x) = \sum_{r=1}^d \frac{\partial f_d(x)}{\partial \hat\theta_r}\Big|_{\hat\theta_r = \hat\theta_r^{(k)}} \Delta\hat\theta_r^{(k)} + O_{\hat\eta^*_{a/w} \to 0}(\hat\eta_a^* \hat\eta_w^* + \hat\eta_w^{*,2})
    = -\hat\eta_a^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, K_{a,d}^{(k)}(x, x_a^{(k)}) - \hat\eta_w^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, K_{w,d}^{(k)}(x, x_w^{(k)}) + O(\hat\eta_a^* \hat\eta_w^* + \hat\eta_w^{*,2}),    (7)

where we have defined the kernels:

    K_{a,d}^{(k)}(x, x') = d^{\tilde q_a} \sigma^2 \sum_{r=1}^d \phi(\hat w_r^{(k),T} x)\, \phi(\hat w_r^{(k),T} x'),    (8)

    K_{w,d}^{(k)}(x, x') = d^{\tilde q_w} \sigma^2 \sum_{r=1}^d |\hat a_r^{(k)}|^2 \phi'(\hat w_r^{(k),T} x)\, \phi'(\hat w_r^{(k),T} x')\, x^T x'.    (9)
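To make the setup concrete, below is a minimal NumPy sketch of the reparameterized model, one SGD step of (5)-(6), and the empirical tangent kernels (8)-(9). The helper names (phi, hparams, sgd_step, tangent_kernels) are ours and not from the paper's released code; this is an illustration under the stated assumptions (sigma_w = 1, power-law scaling passed in explicitly), not the reference implementation.

```python
import numpy as np

def phi(z, alpha=0.2):
    # "Leaky softplus" from the text: ln(1 + e^z) - alpha * ln(1 + e^(-z));
    # real analytic and asymptotically linear: phi(z) = Theta(z) as z -> infinity.
    return np.logaddexp(0.0, z) - alpha * np.logaddexp(0.0, -z)

def phi_prime(z, alpha=0.2):
    # phi'(z) = sigmoid(z) + alpha * sigmoid(-z).
    s = 1.0 / (1.0 + np.exp(-z))
    return s + alpha * (1.0 - s)

def hparams(d, q_sigma, qt_a, qt_w, sigma_star=1.0, eta_a_star=1.0, eta_w_star=1.0):
    # Power-law scaling: sigma(d) = sigma* d^{q_sigma}, eta_hat_{a/w}(d) = eta* d^{q_tilde_{a/w}}.
    return sigma_star * d ** q_sigma, eta_a_star * d ** qt_a, eta_w_star * d ** qt_w

def logits(X, a_hat, W_hat, sigma):
    # f_d(x) = sigma * sum_r a_hat_r phi(w_hat_r^T x); X: (n, d_x), W_hat: (d_x, d), a_hat: (d,).
    return sigma * phi(X @ W_hat) @ a_hat

def sgd_step(a_hat, W_hat, xa, ya, xw, yw, sigma, eta_a, eta_w):
    # One step of eqs. (5)-(6); (xa, ya) and (xw, yw) are independent samples.
    ga = -ya / (1.0 + np.exp(ya * logits(xa[None, :], a_hat, W_hat, sigma)[0]))
    gw = -yw / (1.0 + np.exp(yw * logits(xw[None, :], a_hat, W_hat, sigma)[0]))
    a_new = a_hat - eta_a * sigma * ga * phi(W_hat.T @ xa)
    W_new = W_hat - eta_w * sigma * gw * np.outer(xw, a_hat * phi_prime(W_hat.T @ xw))
    return a_new, W_new

def tangent_kernels(X, Xp, a_hat, W_hat, d, qt_a, qt_w, sigma):
    # Empirical tangent kernels of eqs. (8)-(9) between input batches X and Xp.
    Ka = d ** qt_a * sigma ** 2 * phi(X @ W_hat) @ phi(Xp @ W_hat).T
    Dx, Dxp = phi_prime(X @ W_hat), phi_prime(Xp @ W_hat)
    Kw = d ** qt_w * sigma ** 2 * ((Dx * a_hat ** 2) @ Dxp.T) * (X @ Xp.T)
    return Ka, Kw
```

For instance, hparams(d, -0.5, 0.0, 0.0) reproduces the NTK scaling, while hparams(d, -1.0, 1.0, 1.0) reproduces the mean-field one discussed below.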
Define the following quantity:

    \Delta f_{d,a/w}^{(k),\prime}(x) = \frac{\partial \Delta f_d^{(k)}(x)}{\partial \hat\eta^*_{a/w}}\Big|_{\hat\eta^*_a = 0,\, \hat\eta^*_w = 0} = -\nabla_f^{(k)}\ell(x_{a/w}^{(k)}, y_{a/w}^{(k)})\, K_{a/w,d}^{(k)}(x, x_{a/w}^{(k)}).    (10)

We use this quantity to rewrite the model increment:

    \Delta f_d^{(k)}(x) = \hat\eta_a^* \Delta f_{d,a}^{(k),\prime}(x) + \hat\eta_w^* \Delta f_{d,w}^{(k),\prime}(x) + O(\hat\eta_a^* \hat\eta_w^* + \hat\eta_w^{*,2}).    (11)

Define $p_{err,d}^{(k)} = P_{(y, x, y_a^{(:k-1)}, x_a^{(:k-1)}, y_w^{(:k-1)}, x_w^{(:k-1)}) \sim D^{2k-1}}\{y f_d^{(k)}(x) < 0\}$, the probability of giving a wrong answer at step $k$. Let $k_{term,d} \in N \cup \{+\infty\}$ be the maximal $k$ such that $\forall k' < k$: $p_{err,d}^{(k')} > 0$. Generally, $k_{term,d}$ depends on the hyperparameters, as well as on the data distribution $D$.

Scaling exponents $(q_\sigma, \tilde q_a, \tilde q_w)$ together with proportionality factors $(\sigma^*, \hat\eta_a^*, \hat\eta_w^*)$ define a limit model $f_\infty^{(k)}(x) = \lim_{d \to \infty} f_d^{(k)}(x)$. We call a model "dynamically stable in the limit of large width" if it satisfies the following condition:

Condition 1. $\exists k_{balance} \in N$: $\forall k \in [k_{balance}, k_{term,\infty}) \cap N$, $y_a^{(k)} f_\infty^{(k)}(x_a^{(k)}) < 0$ and $y_w^{(k)} f_\infty^{(k)}(x_w^{(k)}) < 0$ imply $\Delta f_{d,a/w}^{(k),\prime}(x) = \Theta_{d \to \infty}(f_d^{(k_{balance})}(x))$ $x$-a.e. $(y_{a/w}^{(:k)}, x_{a/w}^{(:k)})$-a.s.

Roughly speaking, this condition states that the change of the logits after a single step (when the learning rates are sufficiently small) is comparable to the logits themselves.
This means that the model learns. Note that this condition is weaker than the one used in [8], because it allows the logits to vanish or diverge with width. Such situations are fine, because only the signs of the logits matter for binary classification. This condition puts a constraint on the exponents $(q_\sigma, \tilde q_a, \tilde q_w)$; this constraint generally depends on the train data distribution $D$ and on the proportionality factors $\sigma^*, \hat\eta_{a/w}^*$. In order to obtain a data-independent and hyperparameter-independent constraint, we need the condition above to hold for any value of $k_{term,\infty}$ and any values of $\sigma^*, \hat\eta_{a/w}^*$. Without loss of generality we can assume $k_{term,\infty}$ to be infinite, which gives the following condition:

Condition 2 (Dynamical stability). Given $k_{term,\infty} = +\infty$, $\exists k_{balance} \in N$: $\forall \sigma^* > 0$ $\forall \hat\eta_{a/w}^* > 0$ $\forall k \geq k_{balance}$, $y_a^{(k)} f_\infty^{(k)}(x_a^{(k)}) < 0$ and $y_w^{(k)} f_\infty^{(k)}(x_w^{(k)}) < 0$ imply $\Delta f_{d,a/w}^{(k),\prime}(x) = \Theta_{d \to \infty}(f_d^{(k_{balance})}(x))$ $x$-a.e. $(y_{a/w}^{(:k)}, x_{a/w}^{(:k)})$-a.s.

For simplicity assume $\tilde q_a = \tilde q_w = \tilde q$. We prove the following in SM A.1:

Proposition 1.
Suppose $\tilde q_a = \tilde q_w = \tilde q$ and $D$ is a continuous distribution. Then Condition 2 requires $q_\sigma + \tilde q \in [-1/2, 0]$ to hold.

This statement gives a necessary condition on the growth rates of $\sigma$ and $\hat\eta$ that lead to a well-defined limit model evolution. This condition corresponds to a band in the $(q_\sigma, \tilde q)$-plane: see Figure 1, left. We refer to it as the "band of dynamical stability".

Each point of this band corresponds to a dynamically stable limit model evolution. We present several conditions that separate the dynamical stability band into regions. We then show that each region corresponds to a single limit model evolution.

Condition 3.
The following conditions separate the band of dynamical stability (Figure 1, left):

1. The limit model at initialization is finite: $f_d^{(0)}(x) = \Theta_{d \to \infty}(1)$ $x$-a.e.
2. The tangent kernels at initialization are finite: $K_{d,a/w}^{(0)}(x, x') = \Theta_{d \to \infty}(1)$ $(x, x')$-a.e.
3. The tangent kernels and the limit model are of the same order at initialization: $K_{d,a/w}^{(0)}(x, x') = \Theta_{d \to \infty}(f_d^{(0)}(x))$ $(x, x')$-a.e.
4. The tangent kernels start to evolve: $\Delta K_{d,wa/w}^{(0),\prime}(x, x') = \Theta_{d \to \infty}(K_{d,w}^{(0)}(x, x'))$ $(x, x')$-a.e. and $\Delta K_{d,aw}^{(0),\prime}(x, x') = \Theta_{d \to \infty}(K_{d,a}^{(0)}(x, x'))$ $(x, x')$-a.e.

Here the kernel increments $\Delta K_{d,wa/w}^{(k),\prime}$ are defined similarly to the model increments $\Delta f_{d,a/w}^{(k),\prime}$:

    \Delta K_{d,wa/w}^{(k),\prime}(x, x') = \frac{\partial (K_{d,w}^{(k+1)}(x, x') - K_{d,w}^{(k)}(x, x'))}{\partial \hat\eta_{a/w}^*}\Big|_{\hat\eta_a^* = 0,\, \hat\eta_w^* = 0}.    (12)

The other kernel increment $\Delta K_{d,aw}^{(k),\prime}$ is defined similarly; see SM A.2 for explicit formulae. We prove the following in SM A.2:

Proposition 2 (Separating conditions). Given Condition 2, Condition 3 reads as, point by point:

1. The limit model at initialization is finite: $q_\sigma + 1/2 = 0$.
2. The tangent kernels at initialization are finite: $2 q_\sigma + \tilde q + 1 = 0$.
3. The tangent kernels and the limit model are of the same order at initialization: $q_\sigma + \tilde q + 1/2 = 0$.
4. The tangent kernels start to evolve: $q_\sigma + \tilde q = 0$.

We have also checked this proposition numerically for the limit models discussed below: see Figure 1, right. Each condition corresponds to a straight line in the $(q_\sigma, \tilde q)$-plane: see Figure 1, left. These four lines divide the dynamical stability band into 13 regions: three are two-dimensional, seven are one-dimensional, and three are zero-dimensional. In SM B we show that each region corresponds to a single distinct limit model evolution; we also list the corresponding evolution equations. Note that the segment (a one-dimensional region) that corresponds to Condition 3-2 exactly coincides with the family of "intermediate scalings" introduced in [8].

Note that a typical finite-width model satisfies all four statements of Condition 3 (if we exclude the word "limit" from them). Indeed, neural nets are typically initialized with the He initialization [9], which guarantees a finite $f_d^{(0)}$ even for large width $d$. Since the learning rates of finite nets are finite, the tangent kernels are finite as well. Nevertheless, the neural tangent kernel of a typical finite-width network evolves significantly: [11] have shown that freezing the NTK of practical convolutional nets noticeably reduces their generalization ability; [12] also observed that the evolution of the NTK is important for good performance.

Consequently, if we want a limit model to capture the dynamics of a finite-width net, we have to satisfy all four statements of Condition 3. However, as one can see from Figure 1, we cannot satisfy all of them simultaneously. We say that one limit model captures the behavior of a finite-width one better than another if all statements of Condition 3 satisfied by the latter are satisfied by the former too; in this case we say that the former dominates the latter. One can easily notice that there are only three "non-dominated" limit models, which we discuss in the upcoming section (a sketch that classifies a scaling against these conditions is given below). After that, we will introduce a model modification that allows for a limit satisfying all four statements.
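As a sanity check, here is a small Python sketch that tests a scaling $(q_\sigma, \tilde q)$ against Proposition 1 and the four separating conditions of Proposition 2. It is our own illustration; classify_scaling is a hypothetical helper, not part of the paper's code.

```python
def classify_scaling(q_sigma, q_tilde, tol=1e-12):
    """Locate a scaling (q_sigma, q_tilde) within the band of dynamical stability."""
    if not (-0.5 - tol <= q_sigma + q_tilde <= 0.0 + tol):
        return "outside the band of dynamical stability (Proposition 1)"
    return {
        "finite logits at init (3-1)":          abs(q_sigma + 0.5) < tol,
        "finite tangent kernels at init (3-2)": abs(2 * q_sigma + q_tilde + 1) < tol,
        "logits ~ kernels at init (3-3)":       abs(q_sigma + q_tilde + 0.5) < tol,
        "kernels evolve (3-4)":                 abs(q_sigma + q_tilde) < tol,
    }

print(classify_scaling(-0.5, 0.0))   # NTK point: 3-1, 3-2, 3-3 hold; kernels constant
print(classify_scaling(-1.0, 1.0))   # mean-field point: 3-2 and 3-4 hold
print(classify_scaling(-0.5, 0.5))   # sym-default point: 3-1 and 3-4 hold
```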
Obviously, the three "non-dominated" limit models are exactly the three zero-dimensional regions (points) in Figure 1, left. First suppose statements 1, 2 and 3 hold; then the tangent kernels are constant throughout training (see Figure 1, right). The corresponding point $q_\sigma = -1/2$, $\tilde q = 0$ reads as $\sigma \propto d^{-1/2}$ and $\hat\eta = const$, which is the case considered in the seminal paper on the NTK [1]. The limit dynamics is then given as (see SM B.1.1, and SM B for the general derivation):

    f_{ntk,\infty}^{(k+1)}(x) = f_{ntk,\infty}^{(k)}(x) - \hat\eta_a^* \nabla_{f_{ntk}}^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, K_{a,\infty}^{(0)}(x, x_a^{(k)}) - \hat\eta_w^* \nabla_{f_{ntk}}^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, K_{w,\infty}^{(0)}(x, x_w^{(k)}),
    \qquad f_{ntk,\infty}^{(0)}(x) \sim N(0, \sigma^{*,2} \sigma^{(0),2}(x)),    (13)

where $(x_{a/w}^{(k)}, y_{a/w}^{(k)}) \sim D$, and the limit tangent kernels $K_{a/w,\infty}^{(0)}$ and the standard deviations at initialization $\sigma^{(0)}(x)$ can be calculated along the same lines as in [10].

Next, suppose statements 2 and 4 hold. In this case $K_\infty^{(k)}$ does not coincide with $K_\infty^{(0)}$ (see Figure 1, right), hence a dynamics analogous to (13) is not closed. However, the limit dynamics can be expressed as an evolution of a weight-space measure (see [4, 6] for a similar dynamics for the gradient flow, and SM B.2.1 and SM B for the general derivation):

    \mu_\infty^{(k+1)} = \mu_\infty^{(k)} + div(\mu_\infty^{(k)} \Delta\theta_{mf}^{(k)}), \qquad \mu_\infty^{(0)} = N(0, I_{d_x+1}),    (14)

    f_{mf,\infty}^{(k)}(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu_\infty^{(k)}(d\hat a, d\hat w),    (15)

where the vector field $\Delta\theta_{mf}^{(k)}$ is defined as follows:

    \Delta\theta_{mf}^{(k)}(\hat a, \hat w) = -[\nabla_{f_{mf}}^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w^T x_a^{(k)}),\ \nabla_{f_{mf}}^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a\, \phi'(\hat w^T x_w^{(k)})\, x_w^{(k),T}]^T,    (16)

where we write "$[u, v]$" meaning the concatenation of two row vectors $u$ and $v$. Here we have $q_\sigma = -1$, $\tilde q = 1$, hence $\sigma \propto d^{-1}$ and $\hat\eta \propto d$; this hyperparameter scaling was used in [4, 6]. Note that since the measure at initialization $\mu_\infty^{(0)}$ has zero mean, the limit model vanishes at initialization, $f_{mf,\infty}^{(0)} = 0$ (see Figure 1, right), thus violating statements 1 and 3 of Condition 3.

Finally, consider the point for which statements 1 and 4 hold: $q_\sigma = -1/2$, $\tilde q = 1/2$. This situation is very similar to what we call the "default" scaling. Consider the He initialization [9], typically used in practice: $\sigma_a \propto d^{-1/2}$ and $\sigma_w \propto d_x^{-1/2}$. Assume the learning rates (in the original parameterization) are not modified with width: $\eta_a = const$ and $\eta_w = const$. This implies $\hat\eta_a \propto d$ and $\hat\eta_w \propto 1$, or $\tilde q_a = 1$ and $\tilde q_w = 0$. We refer to the scaling $q_\sigma = -1/2$, $\tilde q_a = 1$, $\tilde q_w = 0$ as "default", and to the scaling $q_\sigma = -1/2$, $\tilde q = 1/2$ as "sym-default". The limit model evolution for the sym-default scaling looks as follows (see SM B.2.2 for an equivalent formulation and SM B for the general derivation):

    \mu_\infty^{(k+1)} = \mu_\infty^{(k)} + div(\mu_\infty^{(k)} \Delta\theta_{sym\text{-}def}^{(k)}), \qquad \mu_\infty^{(0)} = N(0, I_{d_x+1}),    (17)

    f_{sym\text{-}def,\infty}^{(0)}(x) \sim N(0, \sigma^{*,2} \sigma^{(0),2}(x)), \qquad z_{sym\text{-}def,\infty}^{(k)}(x) = sign\Big(\int \hat a\, \phi(\hat w^T x)\, \mu_\infty^{(k)}(d\hat a, d\hat w)\Big),    (18)

where the vector field $\Delta\theta_{sym\text{-}def}^{(k)}$ is defined similarly to the MF case (16):

    \Delta\theta_{sym\text{-}def}^{(k)}(\hat a, \hat w) = -[\nabla_{f_{sym\text{-}def}}^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w^T x_a^{(k)}),\ \nabla_{f_{sym\text{-}def}}^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a\, \phi'(\hat w^T x_w^{(k)})\, x_w^{(k),T}]^T.

As we show in SM C, the default scaling leads to essentially the same limit dynamics as the sym-default scaling. The quantity $z_{sym\text{-}def,\infty}^{(k)}$ should be perceived as the sign of $f_{sym\text{-}def,\infty}^{(k)} = \sigma^* \lim_{d \to \infty}\big(d^{q_\sigma+1} \int \hat a\, \phi(\hat w^T x)\, \mu_d^{(k)}(d\hat a, d\hat w)\big)$. The reason why we have to switch from the logits to their signs is that the limit model diverges for $k \geq 1$: $\lim_{d \to \infty} f_d^{(k)}(x) = \infty$. Nevertheless, the gradient of the cross-entropy loss is well-defined even for infinite logits: it simply degenerates into the gradient of a hinge-type loss: $\lim_{f \to +\infty \times z} \partial\ell(y, f)/\partial f = -y\, 1[yz < 0]$. For this reason, we redefine the loss gradient for $k \geq 1$ in terms of logit signs: $\nabla_{f_{sym\text{-}def}}^{(k)}\ell(x, y) = -y\, 1[y\, z_{sym\text{-}def,\infty}^{(k)}(x) < 0]$. Note that although the logits diverge in the limit of large width, the measure $\mu_d^{(k)}$ in parameter space stays well-defined.
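The measure dynamics (14)-(16) can be simulated with a finite particle discretization: for an empirical measure of d atoms it is just the finite-width network in mean-field parameterization. A minimal sketch under that assumption (our own illustration, reusing phi and phi_prime from the sketch above; the eta*sigma* step factors, which the appendix versions of the vector field make explicit, are written out here because the finite-d update needs them):

```python
import numpy as np

def mf_particle_step(a, W, xa, ya, xw, yw, sigma_star, eta_a, eta_w):
    # One SGD step on d particles (a_r, w_r) whose empirical measure approximates mu^(k).
    # The logits are the integral of eq. (15) against the empirical measure:
    # f(x) = sigma* * mean_r a_r phi(w_r^T x).
    f = lambda x: sigma_star * np.mean(a * phi(W.T @ x))
    ga = -ya / (1.0 + np.exp(ya * f(xa)))   # nabla_f ell on the a-sample
    gw = -yw / (1.0 + np.exp(yw * f(xw)))   # nabla_f ell on the w-sample
    # Transport each particle along the vector field of eq. (16):
    a_new = a - eta_a * sigma_star * ga * phi(W.T @ xa)
    W_new = W - eta_w * sigma_star * gw * np.outer(xw, a * phi_prime(W.T @ xw))
    return a_new, W_new

# mu^(0) = N(0, I): draw d i.i.d. particles at initialization.
d, d_x = 4096, 32
rng = np.random.default_rng(0)
a, W = rng.standard_normal(d), rng.standard_normal((d_x, d))
```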
Here we propose a dynamics that satisfies all four statements of Condition 3. We then show how to modify the network training at finite width in order to ensure that, in the limit of infinite width, its training dynamics converges to the proposed limit one. Consider the following:

    \mu_\infty^{(k+1)} = \mu_\infty^{(k)} + div(\mu_\infty^{(k)} \Delta\theta_{icmf}^{(k)}), \qquad \mu_\infty^{(0)} = N(0, I_{d_x+1}),    (19)

    f_{icmf,\infty}^{(k)}(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu_\infty^{(k)}(d\hat a, d\hat w) + f_{ntk,\infty}^{(0)}(x),    (20)

where $f_{ntk,\infty}^{(0)}$ is defined similarly to the above:

    f_{ntk,\infty}^{(0)}(x) \sim N(0, \sigma^{*,2} \sigma^{(0),2}(x)),    (21)

and the vector field $\Delta\theta_{icmf}^{(k)}$ is defined analogously to the mean-field case:

    \Delta\theta_{icmf}^{(k)}(\hat a, \hat w) = -[\nabla_{f_{icmf}}^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w^T x_a^{(k)}),\ \nabla_{f_{icmf}}^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a\, \phi'(\hat w^T x_w^{(k)})\, x_w^{(k),T}]^T.    (22)

The only difference between this dynamics and the mean-field dynamics is the bias term $f_{ntk,\infty}^{(0)}$ in the definition of the logits. This bias term does not depend on $k$ and stays finite for large $d$, in contrast to $f_{mf,\infty}^{(0)}$, which vanishes for large $d$; it ensures that Condition 3-1 holds. As for Condition 3-4, the tangent kernels evolve with $k$ simply because the measure $\mu_\infty^{(k)}$ evolves with $k$, as in the mean-field case (see Figure 1, right). Indeed,

    K_{w,\infty}^{(k)}(x, x') = \sigma^{*,2} \int |\hat a|^2 \phi'(\hat w^T x)\, \phi'(\hat w^T x')\, x^T x'\, \mu_\infty^{(k)}(d\hat a, d\hat w),    (23)

and the limit of $K_{a,d}^{(k)}$ is written in a similar way. The kernels at initialization $K_{a/w,\infty}^{(0)}$ are finite due to the Law of Large Numbers (Condition 3-2); this, together with the finiteness of $f_{ntk}^{(0)}$, ensures Condition 3-3.

As we show in SM D, the dynamics (19) is the limit of the GD dynamics of the following model with learning rates $\hat\eta_{a/w} = \hat\eta_{a/w}^* d$:

    f_{icmf,d}(x; \hat a, \hat W) = \sigma^* d^{-1} \sum_{r=1}^d \hat a_r \phi(\hat w_r^T x) + \sigma^* (d^{-1/2} - d^{-1}) \sum_{r=1}^d \hat a_r^{(0)} \phi(\hat w_r^{(0),T} x).    (24)

The reason for using the factor $(d^{-1/2} - d^{-1})$ in front of the second term instead of $d^{-1/2}$ will be made clear in the next section. Note that this does not alter the limit.
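A forward pass of the initialization-corrected model (24) is a one-line change relative to the mean-field model: a frozen copy of the initial weights supplies the finite bias term. A minimal sketch (our naming, reusing phi from the sketch above):

```python
def icmf_logits(X, a_hat, W_hat, a0, W0, sigma_star):
    # f_icmf,d(x) = sigma*/d * sum_r a_r phi(w_r^T x)
    #   + sigma* * (d^(-1/2) - d^(-1)) * sum_r a0_r phi(w0_r^T x),   eq. (24).
    d = a_hat.shape[0]
    mf_term = sigma_star / d * phi(X @ W_hat) @ a_hat
    ic_term = sigma_star * (d ** -0.5 - 1.0 / d) * phi(X @ W0) @ a0  # frozen copy of the init
    return mf_term + ic_term
```

Note that at initialization (a_hat = a0, W_hat = W0) the two terms sum to $\sigma^* d^{-1/2} \sum_r \hat a_r^{(0)} \phi(\hat w_r^{(0),T} x)$, i.e. the usual $d^{-1/2}$-scaled network, which is the role of the $(d^{-1/2} - d^{-1})$ factor discussed next.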
Consider a network of width $d^*$ initialized with standard deviation $\sigma^*$ and trained with learning rates $\hat\eta_{a/w}^*$. We call this model a "reference". Consider a family of models indexed by a width $d$, initialized with standard deviation $\sigma(d)$ and trained with learning rates $\hat\eta_{a/w}(d)$, with the following properties: (1) $\sigma(d^*) = \sigma^*$, $\hat\eta_{a/w}(d^*) = \hat\eta_{a/w}^*$; (2) $\sigma(d) = \Theta_{d \to \infty}(d^{q_\sigma})$, $\hat\eta_{a/w}(d) = \Theta_{d \to \infty}(d^{\tilde q})$ for some pre-defined scaling exponents $(q_\sigma, \tilde q)$.

The first property ensures that the model of width $d^*$ coincides with the reference model, while the second property ensures that the models converge to the limit model defined by the corresponding scaling exponents. Additionally assume $\sigma(d) \propto d^{q_\sigma}$, $\hat\eta_{a/w}(d) \propto d^{\tilde q}$. This ensures that the model of reference width at initialization, $f_{d^*}^{(0)}$, provides an unbiased estimate of the limit model at initialization $f_\infty^{(0)}$, and that the kernels at initialization $K_{a/w,d^*}^{(0)}$ provide unbiased estimates of the limit kernels $K_{a/w,\infty}^{(0)}$. Given this, we slightly abuse the notation and consider $\sigma(d) = \sigma^* (d/d^*)^{q_\sigma}$ and $\hat\eta_{a/w}(d) = \hat\eta_{a/w}^* (d/d^*)^{\tilde q}$. Here we note that the model (24) does not alter the limit behavior as $d \to \infty$, while at the same time ensuring that the model with $d = d^*$ coincides with the reference model.

We train a reference network of width $d^* = 128$ for binary classification with the cross-entropy loss on the CIFAR2 dataset (a subset of the first two classes of CIFAR10). We track the divergence of a limit network from the reference one using the quantity $E_{x \sim D_{test}}\, D_{logits}(f_\infty^{(k)}(x)\, \|\, f_{d^*}^{(k)}(x))$, where

    D_{logits}(\xi\, \|\, \xi^*) = KL(N(E\xi, Var\,\xi)\, \|\, N(E\xi^*, Var\,\xi^*)).    (25)

The results are shown in Figure 2. As we see, the NTK limit tracks the reference network well only for the first 20 training steps; a similar observation has already been made by [10]. At the same time, the mean-field limit starts with a high divergence (since the initial limit model is zero in this case); however, after the 80-th step it becomes smaller than that of the NTK limit. This can be an implication of the non-stationary kernels. As for the default case, the divergence of the logits results in a blow-up of the KL divergence. The best overall is the proposed IC-MF limit, which retains a small KL divergence relative to the reference model throughout the training process.
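The divergence measure (25) compares Gaussian fits of the two logit distributions, so the closed-form KL between univariate Gaussians gives it in a few lines. A sketch (our naming; the randomness over initializations is estimated from samples, and averaging over test inputs is done outside):

```python
import numpy as np

def d_logits(xi, xi_ref):
    """KL( N(mean(xi), var(xi)) || N(mean(xi_ref), var(xi_ref)) ), eq. (25).

    xi, xi_ref: 1-D arrays of logit samples (e.g. over random seeds) for a
    fixed test input x.
    """
    m1, v1 = xi.mean(), xi.var()
    m2, v2 = xi_ref.mean(), xi_ref.var()
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)
```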
Figure 2: The initialization-corrected mean-field (IC-MF) limit captures the behavior of a given finite-width network best among the limit models considered.
We plot the KL divergence between the logits of different infinite-width limits of a fixed finite-width reference model and the logits of this reference model.
Setup: we train a one-hidden-layer network with SGD on the CIFAR2 dataset; see SM E for details. KL divergences are estimated using Gaussian fits with 10 samples.

Related work

The pioneering work [1] showed that gradient descent training of a neural net can be viewed as kernel gradient descent in the space of predictors. The corresponding kernel is called the neural tangent kernel (NTK). Generally, the NTK is random and non-stationary; however, [1] showed that in the limit of infinite width it becomes constant, given that the network is parameterized appropriately. In this case the evolution of the model is determined by this constant kernel; see eq. (13). The training regime in which the NTK hardly varies has been coined "lazy training", as opposed to the "rich" training regime, in which the NTK evolves significantly [12]. While theoretically appealing, the "laziness" assumption turns out to have a number of limitations in explaining the success of deep learning [11, 13, 14].

Another line of work considers the evolution of the weights as the evolution of a weight-space measure, similar to eq. (14) [2, 3, 5, 6, 4, 7]. This weight-space measure becomes deterministic in the limit of infinite width, given that the network is parameterized appropriately; the corresponding limit dynamics is called "mean-field". Note that the parameterization required here for convergence to a limit dynamics differs from the one used in the NTK literature.

Our framework for reasoning about the scaling of hyperparameters is very similar in spirit to the one used in [8]. However, there are several crucial differences. First, we do not consider weight increments or a model decomposition, and we do not estimate exponents for the former and for the terms of the latter, which arguably complicates the work of [8]. Instead, we present our derivations in terms of the limit behavior of logits and kernels, which appears to be simpler and clearer. Second, we do not assume constancy of $\nabla_f \ell$. Third, our criterion of "dynamical stability" of a scaling is weaker than that of [8] and more suitable for classification problems, since it allows for diverging or vanishing logits, as long as they give meaningful classification responses. Note that the "intermediate limits" investigated in [8] exactly correspond to the limit models that satisfy Condition 3-2. Finally, neither the "sym-default" nor the IC-MF limit model proposed in the present work has been discussed in [8].

Conclusion

The current work follows a direction started in [8]: we study how one should scale the hyperparameters of a neural network with a single hidden layer in order to converge to a "dynamically stable" limit training dynamics. A weaker dynamical stability condition leads us to a richer class of possible limit models compared to [8]. In particular, the class of limit models we consider includes a "default" limit model that corresponds to a network with an infinitely large number of hidden units and finite learning rates in the original parameterization. This "default" limit model does not satisfy the "dynamical stability" condition of [8].

Moreover, we show that the class of limit models that can be achieved by scaling the hyperparameters of finite-width nets is finite. The space of hyperparameter scalings is divided into regions by certain conditions on the training dynamics, and each region corresponds to a single limit model. All of these conditions are satisfied by finite-width networks, but they cannot all be satisfied simultaneously by a limit model. We propose a modification of a finite-width model; the limit of this modification corresponds to a limit model that satisfies all of the conditions mentioned above and tracks the dynamics of a "reference" finite-width net better than other limit models.

Broader Impact
This is a theoretical work. For this reason, it does not present any foreseeable societal consequences.
Acknowledgments
This work was supported by the National Technology Initiative and PAO Sberbank project ID0000000007417F630002. We thank Mikhail Burtsev and Biswarup Das for valuable discussions and suggestions, as well as for help in improving the final version of the text.
References

[1] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pages 8571–8580, 2018.
[2] Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
[3] Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. arXiv preprint arXiv:1902.06015, 2019.
[4] Grant M Rotskoff and Eric Vanden-Eijnden. Trainability and accuracy of neural networks: an interacting particle system approach. stat, 1050:30, 2019.
[5] Justin Sirignano and Konstantinos Spiliopoulos. Mean field analysis of neural networks: A law of large numbers. SIAM Journal on Applied Mathematics, 80(2):725–752, 2020.
[6] Lenaic Chizat and Francis Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in Neural Information Processing Systems, pages 3036–3046, 2018.
[7] Dmitry Yarotsky. Collective evolution of weights in wide neural networks. arXiv preprint arXiv:1810.03974, 2018.
[8] Eugene A Golikov. Towards a general theory of infinite-width limits of neural classifiers. arXiv preprint arXiv:2003.05884, 2020.
[9] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. In Proceedings of the IEEE International Conference on Computer Vision, pages 1026–1034, 2015.
[10] Jaehoon Lee, Lechao Xiao, Samuel Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, and Jeffrey Pennington. Wide neural networks of any depth evolve as linear models under gradient descent. In Advances in Neural Information Processing Systems, pages 8570–8581, 2019.
[11] Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, Russ R Salakhutdinov, and Ruosong Wang. On exact computation with an infinitely wide neural net. In Advances in Neural Information Processing Systems, pages 8139–8148, 2019.
[12] Blake Woodworth, Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Kernel and deep regimes in overparametrized models. arXiv preprint arXiv:1906.05827, 2019.
[13] Yu Bai and Jason D Lee. Beyond linearization: On quadratic and higher-order approximation of wide neural networks. arXiv preprint arXiv:1910.01619, 2019.
[14] Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Limitations of lazy training of two-layers neural network. In Advances in Neural Information Processing Systems, pages 9108–9118, 2019.
[15] Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in PyTorch. 2017.
A Proofs of propositions
We restate all necessary definitions here. We assume the nonlinearity $\phi$ to be real analytic and asymptotically linear: $\phi(z) = \Theta_{z \to \infty}(z)$. We assume the loss function $\ell(y, z)$ to be the standard binary cross-entropy loss: $\ell(y, z) = \ln(1 + e^{-yz})$, where labels $y \in \{-1, 1\}$.

The training dynamics is given as:

    \Delta\hat a_r^{(k)} = -\hat\eta_a \sigma \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w_r^{(k),T} x_a^{(k)}), \qquad \hat a_r^{(0)} \sim N(0, 1),    (26)

    \Delta\hat w_r^{(k)} = -\hat\eta_w \sigma \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a_r^{(k)} \phi'(\hat w_r^{(k),T} x_w^{(k)})\, x_w^{(k)}, \qquad \hat w_r^{(0)} \sim N(0, I)\ \forall r \in [d],    (27)

    \nabla_f^{(k)}\ell(x, y) = \frac{\partial \ell(y, z)}{\partial z}\Big|_{z = f_d^{(k)}(x)} = \frac{-y}{1 + e^{f_d^{(k)}(x)\, y}}, \qquad f_d^{(k)}(x) = \sigma \sum_{r=1}^d \hat a_r^{(k)} \phi(\hat w_r^{(k),T} x),

where $(x_{a/w}^{(k)}, y_{a/w}^{(k)}) \sim D$, with $D$ the data distribution. We assume the hyperparameters to be scaled with width as power laws: $\sigma(d) = \sigma^* d^{q_\sigma}$, $\hat\eta_a(d) = \hat\eta_a^* d^{\tilde q_a}$, $\hat\eta_w(d) = \hat\eta_w^* d^{\tilde q_w}$.

Since $\phi$ is smooth, we have:

    \Delta f_d^{(k)}(x) = f_d^{(k+1)}(x) - f_d^{(k)}(x) = \sum_{r=1}^d \frac{\partial f_d(x)}{\partial \hat\theta_r}\Big|_{\hat\theta_r = \hat\theta_r^{(k)}} \Delta\hat\theta_r^{(k)} + O_{\hat\eta^*_{a/w} \to 0}(\hat\eta_a^* \hat\eta_w^* + \hat\eta_w^{*,2})
    = -\hat\eta_a^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, K_{a,d}^{(k)}(x, x_a^{(k)}) - \hat\eta_w^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, K_{w,d}^{(k)}(x, x_w^{(k)}) + O(\hat\eta_a^* \hat\eta_w^* + \hat\eta_w^{*,2}),    (28)

where we have defined the kernels:

    K_{a,d}^{(k)}(x, x') = d^{\tilde q_a} \sigma^2 \sum_{r=1}^d \phi(\hat w_r^{(k),T} x)\, \phi(\hat w_r^{(k),T} x'),    (29)

    K_{w,d}^{(k)}(x, x') = d^{\tilde q_w} \sigma^2 \sum_{r=1}^d |\hat a_r^{(k)}|^2 \phi'(\hat w_r^{(k),T} x)\, \phi'(\hat w_r^{(k),T} x')\, x^T x'.    (30)

Define the following quantity:

    \Delta f_{d,a/w}^{(k),\prime}(x) = \frac{\partial \Delta f_d^{(k)}(x)}{\partial \hat\eta^*_{a/w}}\Big|_{\hat\eta^*_a = 0,\, \hat\eta^*_w = 0} = -\nabla_f^{(k)}\ell(x_{a/w}^{(k)}, y_{a/w}^{(k)})\, K_{a/w,d}^{(k)}(x, x_{a/w}^{(k)}).    (31)

We use this quantity to rewrite the model increment:

    \Delta f_d^{(k)}(x) = \hat\eta_a^* \Delta f_{d,a}^{(k),\prime}(x) + \hat\eta_w^* \Delta f_{d,w}^{(k),\prime}(x) + O(\hat\eta_a^* \hat\eta_w^* + \hat\eta_w^{*,2}).    (32)

Define $p_{err,d}^{(k)} = P_{(y, x, y_a^{(:k-1)}, x_a^{(:k-1)}, y_w^{(:k-1)}, x_w^{(:k-1)}) \sim D^{2k-1}}\{y f_d^{(k)}(x) < 0\}$, the probability of giving a wrong answer at step $k$. Let $k_{term,d} \in N \cup \{+\infty\}$ be the maximal $k$ such that $\forall k' < k$: $p_{err,d}^{(k')} > 0$. Generally, $k_{term,d}$ depends on the hyperparameters, as well as on the data distribution $D$. We formulate the following condition of dynamical stability:

Condition 4 (Condition 2, restated). Given $k_{term,\infty} = +\infty$, $\exists k_{balance} \in N$: $\forall \sigma^* > 0$ $\forall \hat\eta_{a/w}^* > 0$ $\forall k \geq k_{balance}$, $y_a^{(k)} f_\infty^{(k)}(x_a^{(k)}) < 0$ and $y_w^{(k)} f_\infty^{(k)}(x_w^{(k)}) < 0$ imply $\Delta f_{d,a/w}^{(k),\prime}(x) = \Theta_{d \to \infty}(f_d^{(k_{balance})}(x))$ $x$-a.e. $(y_{a/w}^{(:k)}, x_{a/w}^{(:k)})$-a.s.

A.1 Proof of Proposition 1

Define:

    q_\theta^{(k)} = \inf\{q : \theta_1^{(k)} = O_{d \to \infty}(d^q)\}, \qquad q_{\Delta\theta}^{(k)} = \inf\{q : \Delta\theta_1^{(k)} = O_{d \to \infty}(d^q)\},    (33)

where $\theta$ should be substituted with $a$ or $w$. We define $\inf(\emptyset) = +\infty$.
We introduce similar definitions for other quantities:

    q_f^{(k)}(x) = \inf\{q : f_d^{(k)}(x) = O_{d \to \infty}(d^q)\}, \qquad q_{\nabla\ell}^{(k)}(x, y) = \inf\{q : \nabla_f^{(k)}\ell(x, y) = O_{d \to \infty}(d^q)\},    (34)

    q_{\Delta f}^{(k)}(x) = \inf\{q : \Delta f_d^{(k)}(x) = O_{d \to \infty}(d^q)\}, \qquad q_{\Delta f'_{a/w}}^{(k)}(x) = \inf\{q : \Delta f_{d,a/w}^{(k),\prime}(x) = O_{d \to \infty}(d^q)\}.    (35)

Lemma 1.
Assume $D$ is a continuous distribution. Then the following hold:

1. $\forall k \geq 0$ $\forall x, y$: $q_{\nabla\ell}^{(k)}(x, y) \leq 0$, while $y f_\infty^{(k)}(x) < 0$ implies $q_{\nabla\ell}^{(k)}(x, y) = 0$.
2. $q_{a/w}^{(0)} = 0$, $q_f^{(0)}(x) = q_\sigma + 1/2$ $x$-a.e.
3. $\forall k \geq 0$: $q_{\Delta a/\Delta w}^{(k)} = \tilde q_{a/w} + q_\sigma + q_{w/a}^{(k)} + q_{\nabla\ell}^{(k)}(x_{a/w}^{(k)}, y_{a/w}^{(k)})$ $(x_{a/w}^{(k)}, y_{a/w}^{(k)})$-a.s.
4. $\forall k \geq 0$: $q_{\Delta f'_{a/w}}^{(k)}(x) = 2 q_\sigma + 1 + \tilde q_{a/w} + 2 q_{w/a}^{(k)} + q_{\nabla\ell}^{(k)}(x_{a/w}^{(k)}, y_{a/w}^{(k)})$ $x$-a.e. $(x_{a/w}^{(k)}, y_{a/w}^{(k)})$-a.s.
5. $\forall k \geq 0$: $q_\sigma + \tilde q_w + q_a^{(k)} \leq 0$ implies that for sufficiently small $\hat\eta_a^*$ and $\hat\eta_w^*$, $q_{\Delta f}^{(k)}(x) = \max(q_{\Delta f'_a}^{(k)}(x), q_{\Delta f'_w}^{(k)}(x))$ $x$-a.e. $(x_a^{(k)}, y_a^{(k)}, x_w^{(k)}, y_w^{(k)})$-a.s.
6. $\forall k \geq 0$: $q_{a/w}^{(k+1)} = \max(q_{a/w}^{(k)}, q_{\Delta a/\Delta w}^{(k)})$ $(x^{(k)}, y^{(k)})$-a.s., and $q_f^{(k+1)}(x) = \max(q_f^{(k)}(x), q_{\Delta f}^{(k)}(x))$ $x$-a.e. $(x_a^{(k)}, y_a^{(k)}, x_w^{(k)}, y_w^{(k)})$-a.s.

Proof. (1) follows from the fact that $\partial\ell(y, z)/\partial z$ is bounded $\forall y$, while $|\partial\ell(y, z)/\partial z| \in [1/2, 1)$ when $yz < 0$.

$\hat a_r^{(0)} \sim N(0, 1)$, which is not zero and does not depend on $d$; hence $q_a^{(0)} = 0$, and the same holds for $w$. For $x \neq 0$ we have $f_d^{(0)}(x) = \sigma \sum_{r=1}^d \hat a_r^{(0)} \phi(\hat w_r^{(0),T} x) = \Theta_{d \to \infty}(d^{q_\sigma + 1/2})$ due to the Central Limit Theorem. Hence (2) holds.

Since $D$ is absolutely continuous w.r.t. the Lebesgue measure on $R^{d_x}$, and $\phi$ is real analytic and non-zero, $\phi(\hat w_r^{(k),T} x_{a/w}^{(k)}) \neq 0$ and $\phi'(\hat w_r^{(k),T} x_{a/w}^{(k)})$ is well-defined $(x_{a/w}^{(k)}, y_{a/w}^{(k)})$-a.s. This implies that $q_{\Delta a/\Delta w}^{(k)} = \tilde q_{a/w} + q_\sigma + q_{w/a}^{(k)} + q_{\nabla\ell}^{(k)}(x_{a/w}^{(k)}, y_{a/w}^{(k)})$ $(x_{a/w}^{(k)}, y_{a/w}^{(k)})$-a.s., which is exactly (3).

Consider $\Delta f_{d,a}^{(k),\prime}$:

    \Delta f_{d,a}^{(k),\prime}(x) = -\nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, K_{a,d}^{(k)}(x, x_a^{(k)}) = -\nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, d^{\tilde q_a} \sigma^2 \sum_{r=1}^d \phi(\hat w_r^{(k),T} x)\, \phi(\hat w_r^{(k),T} x_a^{(k)}).    (36)

For the same reason as discussed above, $\phi(\hat w_r^{(k),T} x_a^{(k)}) \neq 0$ $(x_a^{(k)}, y_a^{(k)})$-a.s., and $\phi(\hat w_r^{(k),T} x) \neq 0$ $x$-a.e. Since the summands are distributed identically and are generally non-zero, the sum introduces a factor of $d$ by the Law of Large Numbers. Since $\phi$ is asymptotically linear, each $\phi$-term scales as $d^{q_w^{(k)}}$. Collecting all terms together, we obtain $q_{\Delta f'_a}^{(k)}(x) = 2 q_\sigma + 1 + \tilde q_a + 2 q_w^{(k)} + q_{\nabla\ell}^{(k)}(x_a^{(k)}, y_a^{(k)})$ $x$-a.e. $(x_a^{(k)}, y_a^{(k)})$-a.s. Following the same steps for $\Delta f_{d,w}^{(k),\prime}$, we get (4).

Let us inspect $\Delta f_d^{(k)}(x)$ in detail (with implicit summation over repeated coordinate indices $i_1, \ldots, i_j$):

    \Delta f_d^{(k)}(x) = \sum_{r=1}^d \Bigg( \sum_{j=1}^\infty \frac{1}{j!} \frac{\partial^j f_d(x)}{\partial \hat w_r^{i_1} \cdots \partial \hat w_r^{i_j}}\Big|_{\hat w_r = \hat w_r^{(k)},\, \hat a_r = \hat a_r^{(k)}} \Delta\hat w_r^{(k),i_1} \cdots \Delta\hat w_r^{(k),i_j}
    + \sum_{j=1}^\infty \frac{1}{(j-1)!} \frac{\partial^j f_d(x)}{\partial \hat a_r\, \partial \hat w_r^{i_1} \cdots \partial \hat w_r^{i_{j-1}}}\Big|_{\hat w_r = \hat w_r^{(k)},\, \hat a_r = \hat a_r^{(k)}} \Delta\hat a_r \Delta\hat w_r^{(k),i_1} \cdots \Delta\hat w_r^{(k),i_{j-1}} \Bigg)
    = \sum_{r=1}^d \Bigg( \sum_{j=1}^\infty \frac{1}{j!} (-1)^j \hat\eta_w^j \sigma^{j+1} (\nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}))^j (\hat a_r^{(k)})^{j+1} (\phi'(\hat w_r^{(k),T} x_w^{(k)}))^j \phi^{(j)}(\hat w_r^{(k),T} x)\, (x_w^{(k),T} x)^j
    + \sum_{j=1}^\infty \frac{1}{(j-1)!} (-1)^j \hat\eta_a \hat\eta_w^{j-1} \sigma^{j+1} \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) (\nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}))^{j-1} (\hat a_r^{(k)})^{j-1} \phi(\hat w_r^{(k),T} x_a^{(k)}) (\phi'(\hat w_r^{(k),T} x_w^{(k)}))^{j-1} \phi^{(j-1)}(\hat w_r^{(k),T} x)\, (x_w^{(k),T} x)^{j-1} \Bigg).    (37)

The assumption $q_\sigma + \tilde q_w + q_a^{(k)} \leq 0$ implies $\hat\eta_w^j \sigma^{j+1} (\hat a_r^{(k)})^{j+1} = O_{d \to \infty}(\hat\eta_w \sigma^2 (\hat a_r^{(k)})^2)$ and $\hat\eta_a \hat\eta_w^{j-1} \sigma^{j+1} (\hat a_r^{(k)})^{j-1} = O_{d \to \infty}(\hat\eta_a \sigma^2)$.

Since $q_{\nabla\ell}^{(k)}(x, y) \leq 0$ $\forall x, y$ due to (1), $(\nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}))^j = O_{d \to \infty}(\nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}))$ and $\nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) (\nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}))^{j-1} = O_{d \to \infty}(\nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}))$.

Since $\phi(z) = \Theta_{z \to \infty}(z)$, $\phi'(\hat w_r^{(k),T} x_w^{(k)}) = O_{d \to \infty}(1)$ and $(\phi'(\hat w_r^{(k),T} x_w^{(k)}))^j = O_{d \to \infty}(\phi'(\hat w_r^{(k),T} x_w^{(k)}))$ for $j \geq 1$.

Hence for small enough $\hat\eta_a^*$ and $\hat\eta_w^*$, the first term of each sum, which corresponds to $j = 1$, dominates all the others, even in the limit of infinite $d$:

    \Delta f_d^{(k)}(x) = -\sum_{r=1}^d \Big( \hat\eta_w \sigma^2 \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) (\hat a_r^{(k)})^2 \phi'(\hat w_r^{(k),T} x_w^{(k)})\, \phi'(\hat w_r^{(k),T} x)\, x_w^{(k),T} x
    + \hat\eta_a \sigma^2 \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w_r^{(k),T} x_a^{(k)})\, \phi(\hat w_r^{(k),T} x) \Big)
    + o_{\hat\eta^*_{a/w} \to 0}\Big(O_{d \to \infty}\big((\hat\eta_a \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) + \hat\eta_w \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) (\hat a_r^{(k)})^2)\, \sigma^2 d\big)\Big)
    = \hat\eta_w^* \Delta f_{d,w}^{(k),\prime}(x) + \hat\eta_a^* \Delta f_{d,a}^{(k),\prime}(x)
    + o_{\hat\eta^*_{a/w} \to 0}\Big(O_{d \to \infty}\big((\hat\eta_a \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) + \hat\eta_w \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) (\hat a_r^{(k)})^2)\, \sigma^2 d\big)\Big).    (38)

Note that the two summands depend on $(x_w^{(k)}, y_w^{(k)})$ and $(x_a^{(k)}, y_a^{(k)})$ respectively, which do not depend on each other. Hence $q_{\Delta f}^{(k)}(x) = \max(q_{\Delta f'_a}^{(k)}(x), q_{\Delta f'_w}^{(k)}(x))$ $x$-a.e. $(x_{a/w}^{(k)}, y_{a/w}^{(k)})$-a.s., which is (5). Note that the $o$-term does not alter the exponent. Indeed,

    (\hat\eta_a \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) + \hat\eta_w \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) (\hat a_r^{(k)})^2)\, \sigma^2 d
    = O_{d \to \infty}\big(d^{1 + 2 q_\sigma + \max(\tilde q_a + q_{\nabla\ell}^{(k)}(x_a^{(k)}, y_a^{(k)}),\ \tilde q_w + q_{\nabla\ell}^{(k)}(x_w^{(k)}, y_w^{(k)}) + 2 q_a^{(k)})}\big)
    = O_{d \to \infty}\big(d^{1 + 2 q_\sigma + \max(\tilde q_a + q_{\nabla\ell}^{(k)}(x_a^{(k)}, y_a^{(k)}) + 2 q_w^{(k)},\ \tilde q_w + q_{\nabla\ell}^{(k)}(x_w^{(k)}, y_w^{(k)}) + 2 q_a^{(k)})}\big)
    = O_{d \to \infty}\big(d^{\max(q_{\Delta f'_a}^{(k)}(x),\ q_{\Delta f'_w}^{(k)}(x))}\big).    (39)

The penultimate equality holds because $q_w^{(k)} \geq 0$ due to (2) and (6), while the last equality holds due to (4).

By definition we have $\hat a_r^{(k+1)} = \hat a_r^{(k)} + \Delta\hat a_r^{(k)}$. Since the second term depends on $(x_a^{(k)}, y_a^{(k)})$ while the first term does not, we get $q_a^{(k+1)} = \max(q_a^{(k)}, q_{\Delta a}^{(k)})$. The same holds for $\hat w_r$ and for $f_d^{(k)}(x)$ $x$-a.e., which gives (6).

Lemma 2.
Assume $D$ is a continuous distribution, $k_{term} = +\infty$ and $\tilde q_a = \tilde q_w = \tilde q$. Then:

1. If $q_\sigma + \tilde q \leq 0$, then $\forall k \geq 0$, $q_{a/w}^{(k)} = 0$ $(x_a^{(:k-1)}, y_a^{(:k-1)}, x_w^{(:k-1)}, y_w^{(:k-1)})$-a.s.
2. If $q_\sigma + \tilde q > 0$, then $\forall k \geq 1$, $q_{a/w}^{(k)} = k (q_\sigma + \tilde q)$ with positive probability w.r.t. $(x_a^{(:k-1)}, y_a^{(:k-1)}, x_w^{(:k-1)}, y_w^{(:k-1)})$.

Proof. Here and in subsequent proofs we write "almost surely" meaning "almost surely w.r.t. $(x_a^{(:k)}, y_a^{(:k)}, x_w^{(:k)}, y_w^{(:k)})$" for the appropriate $k$; we apply a similar shortening to "with positive probability w.r.t. $(x_a^{(:k)}, y_a^{(:k)}, x_w^{(:k)}, y_w^{(:k)})$".

If $q_\sigma + \tilde q \leq 0$, then statements 1, 2, 3 and 6 of Lemma 1 imply $\forall k \geq 0$, $q_{a/w}^{(k)} = 0$ a.s.

Assume $q_\sigma + \tilde q > 0$. We prove by induction that $\forall k \geq 0$, $q_{a/w}^{(k)} = \max(0, k (q_\sigma + \tilde q))$ with positive probability. The induction base is given by Lemma 1-2. Combining the induction assumption and Lemma 1-3, we get $q_{\Delta a/\Delta w}^{(k)} = (k+1)(q_\sigma + \tilde q) + q_{\nabla\ell}^{(k)}(x^{(k)}, y^{(k)})$ with positive probability w.r.t. $(x_a^{(:k-1)}, y_a^{(:k-1)}, x_w^{(:k-1)}, y_w^{(:k-1)})$, $(x_{a/w}^{(k)}, y_{a/w}^{(k)})$-a.s. Since $k_{term} = +\infty > k$, $y_{a/w}^{(k)} f_d^{(k)}(x_{a/w}^{(k)}) < 0$ with positive probability w.r.t. $(x_a^{(k)}, y_a^{(k)}, x_w^{(k)}, y_w^{(k)})$, and Lemma 1-1 implies that $q_{\Delta a/\Delta w}^{(k)} = (k+1)(q_\sigma + \tilde q)$ with positive probability w.r.t. $(x_a^{(:k)}, y_a^{(:k)}, x_w^{(:k)}, y_w^{(:k)})$. Finally, Lemma 1-6 concludes the proof of the induction step.

Lemma 3.
Assume $D$ is a continuous distribution, $k_{term,\infty} = +\infty$, $\tilde q_a = \tilde q_w = \tilde q$ and $q_\sigma + \tilde q \leq 0$. Then $\forall k \geq 0$:

1. $y_{a/w}^{(k)} f_\infty^{(k)}(x_{a/w}^{(k)}) < 0$ implies $q_{\Delta f'_{a/w}}^{(k)}(x) = 2 q_\sigma + 1 + \tilde q$ $x$-a.e. $(x_a^{(:k-1)}, y_a^{(:k-1)}, x_w^{(:k-1)}, y_w^{(:k-1)})$-a.s.
2. $y_a^{(k)} f_\infty^{(k)}(x_a^{(k)}) < 0$ and $y_w^{(k)} f_\infty^{(k)}(x_w^{(k)}) < 0$ imply $q_{\Delta f}^{(k)}(x) = 2 q_\sigma + 1 + \tilde q$ $x$-a.e. $(x_a^{(:k-1)}, y_a^{(:k-1)}, x_w^{(:k-1)}, y_w^{(:k-1)})$-a.s. for sufficiently small $\hat\eta_a^*$ and $\hat\eta_w^*$.

Proof. By Lemma 2, $\forall k \geq 0$, $q_{a/w}^{(k)} = 0$ a.s. Since $y_{a/w}^{(k)} f_\infty^{(k)}(x_{a/w}^{(k)}) < 0$, $q_{\nabla\ell}^{(k)} = 0$ due to Lemma 1-1. Given this, Lemma 1-4 implies $\forall k \geq 0$, $q_{\Delta f'_{a/w}}^{(k)}(x) = 2 q_\sigma + 1 + \tilde q$ $x$-a.e. a.s. Hence, by virtue of Lemma 1-5, $\forall k \geq 0$, $q_{\Delta f}^{(k)}(x) = 2 q_\sigma + 1 + \tilde q$ $x$-a.e. a.s. for sufficiently small $\hat\eta_a^*$ and $\hat\eta_w^*$.

Proposition 3.
Suppose $\tilde q_a = \tilde q_w = \tilde q$ and $D$ is a continuous distribution. Then Condition 4 requires $q_\sigma + \tilde q \in [-1/2, 0]$ to hold.

Proof. By Lemma 2, if $q_\sigma + \tilde q > 0$ then $q_{a/w}^{(k)} = k (q_\sigma + \tilde q)$ with positive probability. At the same time, by virtue of Lemma 1-1, $k_{term,\infty} = +\infty$ implies $q_{\nabla\ell}^{(k)} = 0$ with positive probability. Given this, Lemma 1-4 implies $q_{\Delta f'_{a/w}}^{(k)}(x) = q_\sigma + 1 + (2k+1)(q_\sigma + \tilde q)$ $x$-a.e. with positive probability. This means that the last quantity cannot be almost surely equal to $q_f^{(k_{balance})}(x)$ for any $k_{balance}$ independent of $k$. Since $\Delta f_{d,a/w}^{(k),\prime}(x) = \Theta_{d \to \infty}(f_d^{(k_{balance})}(x))$ requires $q_{\Delta f'_{a/w}}^{(k)}(x) = q_f^{(k_{balance})}(x)$, we conclude that Condition 4 cannot be satisfied if $q_\sigma + \tilde q > 0$.

Hence $q_\sigma + \tilde q \leq 0$. Then by Lemma 3, $\forall k \geq 0$, $y_a^{(k)} f_\infty^{(k)}(x_a^{(k)}) < 0$ and $y_w^{(k)} f_\infty^{(k)}(x_w^{(k)}) < 0$ imply $q_{\Delta f}^{(k)}(x) = 2 q_\sigma + 1 + \tilde q$ $x$-a.e. $(x_a^{(:k-1)}, y_a^{(:k-1)}, x_w^{(:k-1)}, y_w^{(:k-1)})$-a.s. for sufficiently small $\hat\eta_a^*$ and $\hat\eta_w^*$. We show that Condition 4 requires $q_\sigma + \tilde q \in [-1/2, 0]$ already for these sufficiently small $\hat\eta_a^*$ and $\hat\eta_w^*$.

Suppose $y_a^{(k)} f_\infty^{(k)}(x_a^{(k)}) < 0$ and $y_w^{(k)} f_\infty^{(k)}(x_w^{(k)}) < 0$. Given this, points 1 and 6 of Lemma 1 imply $\forall k_{balance} \geq 1$: $q_f^{(k_{balance})}(x) = \max(q_f^{(0)}(x),\ 2 q_\sigma + 1 + \tilde q) = \max(q_\sigma + 1/2,\ 2 q_\sigma + 1 + \tilde q)$ $x$-a.e. a.s. Hence $q_{\Delta f'_{a/w}}^{(k)}(x) = q_f^{(k_{balance})}(x)$ $x$-a.e. a.s. if and only if $q_\sigma + 1/2 \leq 2 q_\sigma + 1 + \tilde q$, which is $q_\sigma + \tilde q \geq -1/2$; we can take $k_{balance} = 1$ without loss of generality. Having $q_{\Delta f'_{a/w}}^{(k)}(x) = q_f^{(k_{balance})}(x)$ is necessary for $\Delta f_{d,a/w}^{(k),\prime}(x) = \Theta_{d \to \infty}(f_d^{(k_{balance})}(x))$.

Summing all together, Condition 4 requires $q_\sigma + \tilde q \in [-1/2, 0]$ to hold.

A.2 Proof of Proposition 2

Proposition 4.
Let Condition 4 hold; then:

1. $f_d^{(0)}(x) = \Theta_{d \to \infty}(1)$ $x$-a.e. is equivalent to $q_\sigma + 1/2 = 0$.
2. $K_{d,a/w}^{(0)}(x, x') = \Theta_{d \to \infty}(1)$ $(x, x')$-a.e. is equivalent to $2 q_\sigma + \tilde q + 1 = 0$.
3. $K_{d,a/w}^{(0)}(x, x') = \Theta_{d \to \infty}(f_d^{(0)}(x))$ $(x, x')$-a.e. is equivalent to $q_\sigma + \tilde q + 1/2 = 0$.
4. $\Delta K_{d,wa/w}^{(0),\prime}(x, x') = \Theta_{d \to \infty}(K_{d,w}^{(0)}(x, x'))$ $(x, x')$-a.e. and $\Delta K_{d,aw}^{(0),\prime}(x, x') = \Theta_{d \to \infty}(K_{d,a}^{(0)}(x, x'))$ $(x, x')$-a.e. are equivalent to $q_\sigma + \tilde q = 0$.

Proof. Statement (1) directly follows from Lemma 1-2:

    f_d^{(0)}(x) = \sigma \sum_{r=1}^d \hat a_r^{(0)} \phi(\hat w_r^{(0),T} x) = \Theta_{d \to \infty}(d^{q_\sigma + 1/2})    (40)

$x$-a.e. due to the Central Limit Theorem. Statement (2) follows from the definition of the kernels and the Law of Large Numbers:

    K_{a,d}^{(0)}(x, x') = d^{\tilde q} \sigma^2 \sum_{r=1}^d \phi(\hat w_r^{(0),T} x)\, \phi(\hat w_r^{(0),T} x') = \Theta_{d \to \infty}(d^{\tilde q + 2 q_\sigma + 1})    (41)

$(x, x')$-a.e.; the same logic holds for the other kernel: $K_{w,d}^{(0)}(x, x') = \Theta_{d \to \infty}(d^{\tilde q + 2 q_\sigma + 1})$ $(x, x')$-a.e. Combining the derivations of the two previous statements, we get statement (3). Now we proceed to the last statement. Consider again the kernel $K_{a,d}^{(0)}$; its increment is given by:

    \Delta K_{a,d}^{(0)}(x, x') = -\hat\eta_w^* d^{2\tilde q} \sigma^3 \sum_{r=1}^d \big(\phi(\hat w_r^{(0),T} x)\, \phi'(\hat w_r^{(0),T} x') + \phi'(\hat w_r^{(0),T} x)\, \phi(\hat w_r^{(0),T} x')\big)
    \times \nabla_f^{(0)}\ell(x_w^{(0)}, y_w^{(0)})\, \hat a_r^{(0)} \phi'(\hat w_r^{(0),T} x_w^{(0)})\, (x + x')^T x_w^{(0)} + O_{\hat\eta_w^* \to 0,\, d \to \infty}(\hat\eta_w^{*,2} d^{3\tilde q + 4 q_\sigma + 1}).    (42)

Consider the part of this increment that is linear in the proportionality factors of the learning rates:

    \Delta K_{aw,d}^{(0),\prime}(x, x') = \frac{\partial \Delta K_{a,d}^{(0)}(x, x')}{\partial \hat\eta_w^*}\Big|_{\hat\eta_w^* = 0}
    = -d^{2\tilde q} \sigma^3 \sum_{r=1}^d \big(\phi(\hat w_r^{(0),T} x)\, \phi'(\hat w_r^{(0),T} x') + \phi'(\hat w_r^{(0),T} x)\, \phi(\hat w_r^{(0),T} x')\big) \nabla_f^{(0)}\ell(x_w^{(0)}, y_w^{(0)})\, \hat a_r^{(0)} \phi'(\hat w_r^{(0),T} x_w^{(0)})\, (x + x')^T x_w^{(0)}.    (43)

Hence $\Delta K_{aw,d}^{(0),\prime} = \Theta_{d \to \infty}(K_{a,d}^{(0)})$ is equivalent to $q_\sigma + \tilde q = 0$. Considering the second kernel $K_{w,d}^{(k)}$ and its increments leads to the same condition.

B The number of distinct limit models is finite
It is easy to see that, due to Proposition 4, Condition 3 divides the dynamical stability band into 13 regions. We now show that when the proportionality factors $\sigma^*$ and $\hat\eta_{a/w}^*$ are fixed, choosing a limit model evolution is equivalent to picking a single region out of these 13.

Indeed, for any width $d$ the model evolution can be written as follows:

    \Delta f_d^{(k)}(x) = -\hat\eta_w^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) \big(K_{w,d}^{(k)}(x, x_w^{(k)}) + O_{\hat\eta^*_{a/w} \to 0,\, d \to \infty}(\hat\eta_w^* \Delta K_{ww,d}^{(k),\prime}(x, x_w^{(k)}) + \hat\eta_a^* \Delta K_{wa,d}^{(k),\prime}(x, x_w^{(k)}))\big)
    - \hat\eta_a^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) \big(K_{a,d}^{(k)}(x, x_a^{(k)}) + O_{\hat\eta_w^* \to 0,\, d \to \infty}(\hat\eta_w^* \Delta K_{aw,d}^{(k),\prime}(x, x_a^{(k)}))\big),    (44)

    f_d^{(k+1)}(x) = f_d^{(k)}(x) + \Delta f_d^{(k)}(x), \qquad \nabla_f^{(k)}\ell(x, y) = \frac{-y}{1 + e^{f_d^{(k)}(x)\, y}},    (45)

    f_d^{(0)}(x) = \sigma^* d^{q_\sigma} \sum_{r=1}^d \hat a_r^{(0)} \phi(\hat w_r^{(0),T} x), \qquad (\hat a_r^{(0)}, \hat w_r^{(0)}) \sim N(0, I_{d_x+1}).    (46)

Now we introduce normalized kernels:

    \tilde K_{a,d}^{(k)}(x, x') = d^{-1-\tilde q-2 q_\sigma} K_{a,d}^{(k)}(x, x') = \sigma^{*,2} d^{-1} \sum_{r=1}^d \phi(\hat w_r^{(k),T} x)\, \phi(\hat w_r^{(k),T} x'),    (47)

    \tilde K_{w,d}^{(k)}(x, x') = d^{-1-\tilde q-2 q_\sigma} K_{w,d}^{(k)}(x, x') = \sigma^{*,2} d^{-1} \sum_{r=1}^d |\hat a_r^{(k)}|^2 \phi'(\hat w_r^{(k),T} x)\, \phi'(\hat w_r^{(k),T} x')\, x^T x'.    (48)

Note that after normalization the kernels stay finite in the limit of large width due to the Law of Large Numbers. Similarly, we normalize the logits, as well as the kernel and logit increments:

    \Delta\tilde K_{**,d}^{(k),\prime} = d^{-1-\tilde q-2 q_\sigma} \Delta K_{**,d}^{(k),\prime}, \qquad \Delta\tilde f_d^{(k)} = d^{-1-\tilde q-2 q_\sigma} \Delta f_d^{(k)}, \qquad \tilde f_d^{(k)} = d^{-1-\tilde q-2 q_\sigma} f_d^{(k)}.    (49)

We then rewrite the model evolution as:

    \Delta\tilde f_d^{(k)}(x) = -\hat\eta_w^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) \big(\tilde K_{w,d}^{(k)}(x, x_w^{(k)}) + O_{\hat\eta^*_{a/w} \to 0,\, d \to \infty}(\hat\eta_w^* \Delta\tilde K_{ww,d}^{(k),\prime}(x, x_w^{(k)}) + \hat\eta_a^* \Delta\tilde K_{wa,d}^{(k),\prime}(x, x_w^{(k)}))\big)
    - \hat\eta_a^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) \big(\tilde K_{a,d}^{(k)}(x, x_a^{(k)}) + O_{\hat\eta_w^* \to 0,\, d \to \infty}(\hat\eta_w^* \Delta\tilde K_{aw,d}^{(k),\prime}(x, x_a^{(k)}))\big),    (50)

    \tilde f_d^{(k+1)}(x) = \tilde f_d^{(k)}(x) + \Delta\tilde f_d^{(k)}(x) \quad \forall k \geq 0,    (51)

    \tilde f_d^{(0)}(x) = \sigma^* d^{-1-\tilde q-q_\sigma} \sum_{r=1}^d \hat a_r^{(0)} \phi(\hat w_r^{(0),T} x), \qquad (\hat a_r^{(0)}, \hat w_r^{(0)}) \sim N(0, I_{d_x+1}),    (52)

    f_d^{(k)}(x) = d^{1+\tilde q+2 q_\sigma} \tilde f_d^{(k)}(x), \qquad \nabla_f^{(k)}\ell(x, y) = \frac{-y}{1 + e^{f_d^{(k)}(x)\, y}} \quad \forall k \geq 0.    (53)

B.1 Constant normalized kernels case

The kernels $\tilde K_{a/w,d}^{(k)}$ either are constant (hence $\Delta\tilde K_{**,d}^{(k),\prime} \to 0$ as $d \to \infty$) or evolve with $k$ in the limit of large $d$. First assume they are constant; in this case $q_\sigma + \tilde q < 0$ due to Proposition 4-4, and

    \Delta\tilde f_d^{(k)}(x) = -\hat\eta_w^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) \big(\tilde K_{w,d}^{(0)}(x, x_w^{(k)}) + o_{\hat\eta^*_{a/w} \to 0,\, d \to \infty}(1)\big) - \hat\eta_a^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) \big(\tilde K_{a,d}^{(0)}(x, x_a^{(k)}) + o_{\hat\eta_w^* \to 0,\, d \to \infty}(1)\big).    (54)
Since the normalized kernels $\tilde K_{a/w,d}^{(0)}$ converge to non-zero limit kernels $\tilde K_{a/w,\infty}^{(0)}$, we can rewrite the formula above as:

    \Delta\tilde f_d^{(k)}(x) = -\hat\eta_w^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) (\tilde K_{w,\infty}^{(0)}(x, x_w^{(k)}) + o_{d \to \infty}(1)) - \hat\eta_a^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) (\tilde K_{a,\infty}^{(0)}(x, x_a^{(k)}) + o_{d \to \infty}(1)),    (55)

    \tilde f_d^{(0)}(x) = \sigma^* d^{-1/2-\tilde q-q_\sigma} (N(0, \sigma^{(0),2}(x)) + o_{d \to \infty}(1)),    (56)

where $\sigma^{(0)}(x)$ can be calculated in the same manner as in [10]. As required by Proposition 3, $1/2 + \tilde q + q_\sigma \geq 0$, hence $\tilde f_d^{(0)}(x) = O_{d \to \infty}(1)$. This implies the following:

    \nabla_{f_\infty}^{(0)}\ell(x, y) = \lim_{d \to \infty} \nabla_{f_d}^{(0)}\ell(x, y) = \lim_{d \to \infty} \frac{-y}{1 + e^{d^{1+\tilde q+2 q_\sigma} \tilde f_d^{(0)}(x)\, y}}
    = \begin{cases} -y\, 1[N(0, \sigma^{(0),2}(x))\, y < 0] & \text{for } 1/2 + q_\sigma > 0; \\ -y / (1 + e^{\sigma^* N(0, \sigma^{(0),2}(x))\, y}) & \text{for } 1/2 + q_\sigma = 0; \\ -y/2 & \text{for } 1/2 + q_\sigma < 0. \end{cases}    (57)

On the other hand, $\Delta\tilde f_d^{(0)}(x) = \Theta_{d \to \infty}(1)$ with positive probability over $(x_{a/w}^{(0)}, y_{a/w}^{(0)})$. Hence $\tilde f_d^{(0)} = O_{d \to \infty}(\Delta\tilde f_d^{(0)})$ and $\tilde f_d^{(1)} = \tilde f_d^{(0)} + \Delta\tilde f_d^{(0)} = \Theta_{d \to \infty}(1)$. For the same reason, $\tilde f_d^{(k+1)} = \tilde f_d^{(k)} + \Delta\tilde f_d^{(k)} = \Theta_{d \to \infty}(1)$ $\forall k \geq 1$. This implies the following: $\forall k \geq 0$,

    \nabla_{f_\infty}^{(k+1)}\ell(x, y) = \lim_{d \to \infty} \nabla_{f_d}^{(k+1)}\ell(x, y) = \lim_{d \to \infty} \frac{-y}{1 + e^{d^{1+\tilde q+2 q_\sigma} \tilde f_d^{(k+1)}(x)\, y}}
    = \begin{cases} -y\, 1[\lim_{d \to \infty} \tilde f_d^{(k+1)}(x)\, y < 0] & \text{for } 1 + \tilde q + 2 q_\sigma > 0; \\ -y / (1 + e^{\lim_{d \to \infty} \tilde f_d^{(k+1)}(x)\, y}) & \text{for } 1 + \tilde q + 2 q_\sigma = 0; \\ -y/2 & \text{for } 1 + \tilde q + 2 q_\sigma < 0. \end{cases}    (58)

If we define $\tilde f_\infty^{(k)}(x) = \lim_{d \to \infty} \tilde f_d^{(k)}(x)$, we get the following limit dynamics:

    \Delta\tilde f_\infty^{(k)}(x) = -\hat\eta_w^* \nabla_{f_\infty}^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \tilde K_{w,\infty}^{(0)}(x, x_w^{(k)}) - \hat\eta_a^* \nabla_{f_\infty}^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \tilde K_{a,\infty}^{(0)}(x, x_a^{(k)}),    (59)

    \tilde K_{a,\infty}^{(0)}(x, x') = \sigma^{*,2}\, E_{\hat w \sim N(0, I_{d_x})}\, \phi(\hat w^T x)\, \phi(\hat w^T x'),    (60)

    \tilde K_{w,\infty}^{(0)}(x, x') = \sigma^{*,2}\, E_{(\hat a, \hat w) \sim N(0, I_{d_x+1})}\, |\hat a|^2 \phi'(\hat w^T x)\, \phi'(\hat w^T x')\, x^T x',    (61)

    \tilde f_\infty^{(k+1)}(x) = \tilde f_\infty^{(k)}(x) + \Delta\tilde f_\infty^{(k)}(x), \qquad \tilde f_\infty^{(0)}(x) = \begin{cases} \sigma^* N(0, \sigma^{(0),2}(x)) & \text{for } 1/2 + \tilde q + q_\sigma = 0; \\ 0 & \text{for } 1/2 + \tilde q + q_\sigma > 0; \end{cases}    (62)

    \nabla_{f_\infty}^{(0)}\ell(x, y) = \begin{cases} -y\, 1[N(0, \sigma^{(0),2}(x))\, y < 0] & \text{for } 1/2 + q_\sigma > 0; \\ -y / (1 + e^{\sigma^* N(0, \sigma^{(0),2}(x))\, y}) & \text{for } 1/2 + q_\sigma = 0; \\ -y/2 & \text{for } 1/2 + q_\sigma < 0; \end{cases}    (63)

    \nabla_{f_\infty}^{(k+1)}\ell(x, y) = \begin{cases} -y\, 1[\tilde f_\infty^{(k+1)}(x)\, y < 0] & \text{for } 1 + \tilde q + 2 q_\sigma > 0; \\ -y / (1 + e^{f_\infty^{(k+1)}(x)\, y}) & \text{for } 1 + \tilde q + 2 q_\sigma = 0; \\ -y/2 & \text{for } 1 + \tilde q + 2 q_\sigma < 0; \end{cases} \quad \forall k \geq 0.    (64)

This dynamics is defined by the proportionality factors $\sigma^*$, $\hat\eta_{a/w}^*$ and the signs of three exponents: $1/2 + q_\sigma$, $1 + \tilde q + 2 q_\sigma$ and $1/2 + \tilde q + q_\sigma$. Since we assume the proportionality factors to be fixed, choosing the signs of the exponents is equivalent to choosing a limit model. Note that these exponents exactly correspond to those mentioned in Proposition 4, points 1, 2 and 3. One can easily see from Figure 1 (left) that, given $q_\sigma + \tilde q < 0$, there are 8 distinct sign configurations.

Note also that, since we are interested in binary classification problems, only the sign of the logits matters. Since $f_d^{(k)} = d^{1+\tilde q+2 q_\sigma} \tilde f_d^{(k)}$, the signs of $f_d^{(k)}$ and of $\tilde f_d^{(k)}$ coincide for all $d$. Hence $\forall x, y$: $\lim_{d \to \infty} sign(f_d^{(k)}(x)) = \lim_{d \to \infty} sign(\tilde f_d^{(k)}(x)) = sign(\tilde f_\infty^{(k)}(x))$.

B.1.1 NTK limit model
We state here the special case of the NTK scaling ($q_\sigma = -1/2$, $\tilde q = 0$, see [1]) explicitly. Since in this case $1 + \tilde q + 2 q_\sigma = 0$, we can omit the tildes everywhere. This results in the following limit dynamics:

    \Delta f_\infty^{(k)}(x) = -\hat\eta_w^* \nabla_{f_\infty}^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, K_{w,\infty}^{(0)}(x, x_w^{(k)}) - \hat\eta_a^* \nabla_{f_\infty}^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, K_{a,\infty}^{(0)}(x, x_a^{(k)}),    (65)

    K_{a,\infty}^{(0)}(x, x') = \sigma^{*,2}\, E_{\hat w \sim N(0, I_{d_x})}\, \phi(\hat w^T x)\, \phi(\hat w^T x'),    (66)

    K_{w,\infty}^{(0)}(x, x') = \sigma^{*,2}\, E_{(\hat a, \hat w) \sim N(0, I_{d_x+1})}\, |\hat a|^2 \phi'(\hat w^T x)\, \phi'(\hat w^T x')\, x^T x',    (67)

    f_\infty^{(k+1)}(x) = f_\infty^{(k)}(x) + \Delta f_\infty^{(k)}(x), \qquad f_\infty^{(0)}(x) \sim N(0, \sigma^{*,2} \sigma^{(0),2}(x)),    (68)

    \nabla_{f_\infty}^{(k)}\ell(x, y) = \frac{-y}{1 + e^{f_\infty^{(k)}(x)\, y}} \quad \forall k \geq 0.    (69)
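As an illustration of (65)-(69), the following sketch estimates the limit kernels (66)-(67) by Monte Carlo over $N(0, I)$ and then iterates the logit update on a fixed set of inputs. This is our own code (mc_limit_kernels, ntk_limit_run are not from the paper), reusing phi and phi_prime from the earlier sketch; an actual experiment would use the analytic kernel formulas of [10] instead of sampling.

```python
import numpy as np

def mc_limit_kernels(X, sigma_star, n_mc=100_000, seed=0):
    # Monte Carlo estimates of K_a^(0) (eq. 66) and K_w^(0) (eq. 67) on inputs X: (n, d_x).
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((X.shape[1], n_mc))   # w_hat ~ N(0, I_{d_x})
    a = rng.standard_normal(n_mc)                 # a_hat ~ N(0, 1)
    P, D = phi(X @ W), phi_prime(X @ W)
    Ka = sigma_star ** 2 * (P @ P.T) / n_mc
    Kw = sigma_star ** 2 * ((D * a ** 2) @ D.T) / n_mc * (X @ X.T)
    return Ka, Kw

def ntk_limit_run(y, f0, Ka, Kw, eta_a, eta_w, n_steps, seed=0):
    # Iterate eq. (65): the kernels stay frozen at initialization; only the logits evolve.
    rng, f, n = np.random.default_rng(seed), f0.copy(), len(y)
    for _ in range(n_steps):
        ia, iw = rng.integers(n), rng.integers(n)  # independent samples for the two layers
        ga = -y[ia] / (1.0 + np.exp(f[ia] * y[ia]))
        gw = -y[iw] / (1.0 + np.exp(f[iw] * y[iw]))
        f = f - eta_a * ga * Ka[:, ia] - eta_w * gw * Kw[:, iw]
    return f
```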
B.2 Non-stationary normalized kernels case

Suppose now $q_\sigma + \tilde q = 0$. In this case $\Delta K_{d,wa/w}^{(0),\prime}(x, x') = \Theta_{d \to \infty}(K_{d,w}^{(0)}(x, x'))$ $(x, x')$-a.e. and $\Delta K_{d,aw}^{(0),\prime}(x, x') = \Theta_{d \to \infty}(K_{d,a}^{(0)}(x, x'))$ $(x, x')$-a.e. by virtue of Proposition 4-4. Hence the kernels evolve in the limit of large width (at least for sufficiently small $\hat\eta_{a/w}^*$).

If we followed the lines of the previous section, we would get a limit dynamics which is not closed:

    \Delta\tilde f_\infty^{(k)}(x) = -\hat\eta_w^* \nabla_{f_\infty}^{(k)}\ell(x_w^{(k)}, y_w^{(k)}) \big(\tilde K_{w,\infty}^{(k)}(x, x_w^{(k)}) + O_{\hat\eta^*_{a/w} \to 0}(\hat\eta_w^* \Delta\tilde K_{ww,\infty}^{(k),\prime}(x, x_w^{(k)}) + \hat\eta_a^* \Delta\tilde K_{wa,\infty}^{(k),\prime}(x, x_w^{(k)}))\big)
    - \hat\eta_a^* \nabla_{f_\infty}^{(k)}\ell(x_a^{(k)}, y_a^{(k)}) \big(\tilde K_{a,\infty}^{(k)}(x, x_a^{(k)}) + O_{\hat\eta_w^* \to 0}(\hat\eta_w^* \Delta\tilde K_{aw,\infty}^{(k),\prime}(x, x_a^{(k)}))\big),    (70)

    \tilde f_\infty^{(k+1)}(x) = \tilde f_\infty^{(k)}(x) + \Delta\tilde f_\infty^{(k)}(x), \qquad \tilde f_\infty^{(0)}(x) = 0,    (71)

    \nabla_{f_\infty}^{(0)}\ell(x, y) = \begin{cases} -y\, 1[N(0, \sigma^{(0),2}(x))\, y < 0] & \text{for } 1/2 + q_\sigma > 0; \\ -y / (1 + e^{\sigma^* N(0, \sigma^{(0),2}(x))\, y}) & \text{for } 1/2 + q_\sigma = 0; \\ -y/2 & \text{for } 1/2 + q_\sigma < 0; \end{cases}    (72)

    \nabla_{f_\infty}^{(k+1)}\ell(x, y) = \begin{cases} -y\, 1[\tilde f_\infty^{(k+1)}(x)\, y < 0] & \text{for } 1 + q_\sigma > 0; \\ -y / (1 + e^{f_\infty^{(k+1)}(x)\, y}) & \text{for } 1 + q_\sigma = 0; \\ -y/2 & \text{for } 1 + q_\sigma < 0; \end{cases} \quad \forall k \geq 0.    (73)

The reason for this is the non-stationarity of the kernels. As a workaround, we consider a measure in the weight space:

    \mu_d^{(k)} = \frac{1}{d} \sum_{r=1}^d \delta_{\hat a_r^{(k)}} \otimes \delta_{\hat w_r^{(k)}}.    (74)

Recall the stochastic gradient descent dynamics:

    \Delta\hat a_r^{(k)} = -\hat\eta_a^* \sigma^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w_r^{(k),T} x_a^{(k)}), \qquad \hat a_r^{(0)} \sim N(0, 1),    (75)

    \Delta\hat w_r^{(k)} = -\hat\eta_w^* \sigma^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a_r^{(k)} \phi'(\hat w_r^{(k),T} x_w^{(k)})\, x_w^{(k)}, \qquad \hat w_r^{(0)} \sim N(0, I_{d_x}).    (76)

Here we have replaced $\hat\eta_{a/w} \sigma$ with $\hat\eta_{a/w}^* \sigma^*$, because $q_\sigma + \tilde q = 0$. Similar to [4, 6], this dynamics can be expressed in terms of the measure defined above:

    \mu_d^{(k+1)} = \mu_d^{(k)} + div(\mu_d^{(k)} \Delta\theta_d^{(k)}), \qquad \mu_d^{(0)} = \frac{1}{d} \sum_{r=1}^d \delta_{\hat\theta_r^{(0)}}, \quad \hat\theta_r^{(0)} \sim N(0, I_{d_x+1})\ \forall r \in [d],    (77)

    \Delta\theta_d^{(k)}(\hat a, \hat w) = -[\hat\eta_a^* \sigma^* \nabla_f^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w^T x_a^{(k)}),\ \hat\eta_w^* \sigma^* \nabla_f^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a\, \phi'(\hat w^T x_w^{(k)})\, x_w^{(k),T}]^T,    (78)

    f_d^{(k)}(x) = \sigma^* d^{1+q_\sigma} \int \hat a\, \phi(\hat w^T x)\, \mu_d^{(k)}(d\hat a, d\hat w), \qquad \nabla_f^{(k)}\ell(x, y) = \frac{-y}{1 + e^{f_d^{(k)}(x)\, y}} \quad \forall k \geq 0.    (79)

We rewrite the last equation in terms of $\tilde f_d^{(k)}(x) = d^{-1-q_\sigma} f_d^{(k)}(x)$:

    \tilde f_d^{(k)}(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu_d^{(k)}(d\hat a, d\hat w), \qquad \nabla_f^{(k)}\ell(x, y) = \frac{-y}{1 + e^{d^{1+q_\sigma} \tilde f_d^{(k)}(x)\, y}} \quad \forall k \geq 0.    (80)

This dynamics is closed. Taking the limit $d \to \infty$ yields:

    \mu_\infty^{(k+1)} = \mu_\infty^{(k)} + div(\mu_\infty^{(k)} \Delta\theta_\infty^{(k)}), \qquad \mu_\infty^{(0)} = N(0, I_{d_x+1}),    (81)

    \Delta\theta_\infty^{(k)}(\hat a, \hat w) = -[\hat\eta_a^* \sigma^* \nabla_{f_\infty}^{(k)}\ell(x_a^{(k)}, y_a^{(k)})\, \phi(\hat w^T x_a^{(k)}),\ \hat\eta_w^* \sigma^* \nabla_{f_\infty}^{(k)}\ell(x_w^{(k)}, y_w^{(k)})\, \hat a\, \phi'(\hat w^T x_w^{(k)})\, x_w^{(k),T}]^T,    (82)

    \tilde f_\infty^{(k)}(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu_\infty^{(k)}(d\hat a, d\hat w),    (83)

    \nabla_{f_\infty}^{(0)}\ell(x, y) = \begin{cases} -y\, 1[N(0, \sigma^{(0),2}(x))\, y < 0] & \text{for } 1/2 + q_\sigma > 0; \\ -y / (1 + e^{\sigma^* N(0, \sigma^{(0),2}(x))\, y}) & \text{for } 1/2 + q_\sigma = 0; \\ -y/2 & \text{for } 1/2 + q_\sigma < 0; \end{cases}    (84)

    \nabla_{f_\infty}^{(k+1)}\ell(x, y) = \begin{cases} -y\, 1[\tilde f_\infty^{(k+1)}(x)\, y < 0] & \text{for } 1 + q_\sigma > 0; \\ -y / (1 + e^{\tilde f_\infty^{(k+1)}(x)\, y}) & \text{for } 1 + q_\sigma = 0; \\ -y/2 & \text{for } 1 + q_\sigma < 0; \end{cases} \quad \forall k \geq 0.    (85)
Since the proportionality factors $\sigma^*$ and $\hat\eta^*_{a/w}$ are assumed to be fixed, choosing $q_\sigma$ is sufficient to define the dynamics. Signs of the exponents $1/2 + q_\sigma$ and $1 + q_\sigma$ give 5 distinct limit dynamics. Together with the 8 limit dynamics for the constant normalized kernels case, this gives 13 distinct limit dynamics, each corresponding to a region in the band of dynamical stability (Figure 1, left).

As was noted earlier, only the sign of the logits matters, and our $\tilde f^{(k)}_d$ preserves the sign for any $d$:

$$\forall x \quad \lim_{d\to\infty} \mathrm{sign}(f^{(k)}_d(x)) = \lim_{d\to\infty} \mathrm{sign}(\tilde f^{(k)}_d(x)) = \mathrm{sign}(\tilde f^{(k)}_\infty(x)).$$

B.2.1 MF limit model

We state here a special case of the mean-field scaling ($q_\sigma = -1$, $\tilde q = 1$, see [4] or [6]) explicitly. Similar to the NTK case, since $1 + q_\sigma = 0$, we can omit tildes. This results in the following limit dynamics:

$$\mu^{(k+1)}_\infty = \mu^{(k)}_\infty + \mathrm{div}(\mu^{(k)}_\infty\, \Delta\theta^{(k)}_\infty), \qquad \mu^{(0)}_\infty = \mathcal N(0, I_{d_x+1}), \quad (86)$$

$$\Delta\theta^{(k)}_\infty(\hat a, \hat w) = -\left[\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^T x^{(k)}_a),\ \hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a\, \phi'(\hat w^T x^{(k)}_w)\, x^{(k),T}_w\right]^T, \quad (87)$$

$$f^{(k)}_\infty(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu^{(k)}_\infty(d\hat a, d\hat w), \qquad \nabla^{(k)}_{f_\infty}\ell(x, y) = -y\, s(-f^{(k)}_\infty(x)\, y) \quad \forall k \ge 0. \quad (88)$$
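The case split in (84)–(85), and hence the enumeration of regimes above, depends only on the signs of $1/2 + q_\sigma$ and $1 + q_\sigma$. A small sketch (the helper name and regime labels are our own) that makes the five regions of the band explicit:

```python
def gradient_regimes(q_sigma: float) -> tuple:
    """Form of the limit loss gradient at initialization (k = 0) and after
    training steps (k >= 1); cf. eqs. (84)-(85)."""
    def case(exponent: float) -> str:
        if exponent > 0:
            return "saturated: -y * 1[f y < 0]"   # logits blow up with d
        if exponent == 0:
            return "sigmoidal: -y * s(-f y)"      # logits stay Theta(1)
        return "vanishing logits: -y / 2"         # logits go to 0
    return case(0.5 + q_sigma), case(1.0 + q_sigma)

for q in (-1.5, -1.0, -0.75, -0.5, -0.25):        # the 5 regions of the band
    print(q, gradient_regimes(q))
```

The mean-field scaling above is the $q_\sigma = -1$ row (vanishing initial logits, sigmoidal gradients during training), while the sym-default model below is the $q_\sigma = -1/2$ row (sigmoidal at initialization, saturated afterwards).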
B.2.2 Sym-default limit model

Another special case which deserves explicit formulation is what we have called the "sym-default" limit model. The corresponding scaling is $q_\sigma = -1/2$, $\tilde q = 1/2$. The resulting limit dynamics is the following:

$$\mu^{(k+1)}_\infty = \mu^{(k)}_\infty + \mathrm{div}(\mu^{(k)}_\infty\, \Delta\theta^{(k)}_\infty), \qquad \mu^{(0)}_\infty = \mathcal N(0, I_{d_x+1}), \quad (89)$$

$$\Delta\theta^{(k)}_\infty(\hat a, \hat w) = -\left[\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^T x^{(k)}_a),\ \hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a\, \phi'(\hat w^T x^{(k)}_w)\, x^{(k),T}_w\right]^T, \quad (90)$$

$$\tilde f^{(k)}_\infty(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu^{(k)}_\infty(d\hat a, d\hat w), \quad (91)$$

$$\nabla^{(0)}_{f_\infty}\ell(x, y) = -y\, s(-\sigma^*\, \mathcal N(0, \sigma^{(0),2}(x))\, y), \quad (92)$$

$$\nabla^{(k+1)}_{f_\infty}\ell(x, y) = -y\, \mathbb 1\!\left[\tilde f^{(k+1)}_\infty(x)\, y < 0\right] \quad \forall k \ge 0. \quad (93)$$
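The indicator in (93) is simply the $d \to \infty$ limit of the sigmoid applied to logits that grow as $d^{1+q_\sigma} = d^{1/2}$; a quick numerical check (toy values of ours):

```python
import numpy as np

s = lambda z: 1.0 / (1.0 + np.exp(-z))        # logistic sigmoid
f_tilde, y = -0.3, 1.0                        # a normalized logit with f_tilde * y < 0
for d in (10, 10_000, 10_000_000):
    grad = -y * s(-d ** 0.5 * f_tilde * y)    # finite-d gradient, cf. eq. (80)
    print(d, grad)                            # approaches -y * 1[f_tilde * y < 0] = -1
```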
C Default scaling

Consider the special case of the default scaling: $q_\sigma = -1/2$, $\tilde q_a = 1$, $\tilde q_w = 0$. Then the corresponding dynamics can be written as follows:

$$\Delta\hat a^{(k)}_r = -\hat\eta^*_a \sigma^* d^{1/2}\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^{(k),T}_r x^{(k)}_a), \qquad \hat a^{(0)}_r \sim \mathcal N(0, 1), \quad (94)$$

$$\Delta\hat w^{(k)}_r = -\hat\eta^*_w \sigma^* d^{-1/2}\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a^{(k)}_r\, \phi'(\hat w^{(k),T}_r x^{(k)}_w)\, x^{(k)}_w, \qquad \hat w^{(0)}_r \sim \mathcal N(0, I_{d_x}), \quad (95)$$

$$f^{(k)}_d(x) = \sigma^* d^{-1/2} \sum_{r=1}^d \hat a^{(k)}_r\, \phi(\hat w^{(k),T}_r x), \qquad \nabla^{(k)}_{f_d}\ell(x, y) = -y\, s(-f^{(k)}_d(x)\, y) \quad \forall k \ge 0. \quad (96)$$

As one can see, the increments of the output layer weights $\Delta\hat a^{(k)}_r$ diverge with $d$. We introduce their normalized versions: $\Delta\tilde a^{(k)}_r = d^{-1/2}\, \Delta\hat a^{(k)}_r$. Similarly, we normalize the output layer weights themselves: $\tilde a^{(k)}_r = d^{-1/2}\, \hat a^{(k)}_r$. Then the dynamics transforms to:

$$\Delta\tilde a^{(k)}_r = -\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^{(k),T}_r x^{(k)}_a), \qquad \tilde a^{(0)}_r \sim \mathcal N(0, d^{-1}), \quad (97)$$

$$\Delta\hat w^{(k)}_r = -\hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_w, y^{(k)}_w)\, \tilde a^{(k)}_r\, \phi'(\hat w^{(k),T}_r x^{(k)}_w)\, x^{(k)}_w, \qquad \hat w^{(0)}_r \sim \mathcal N(0, I_{d_x}), \quad (98)$$

$$f^{(k)}_d(x) = \sigma^* \sum_{r=1}^d \tilde a^{(k)}_r\, \phi(\hat w^{(k),T}_r x), \qquad \nabla^{(k)}_{f_d}\ell(x, y) = -y\, s(-f^{(k)}_d(x)\, y) \quad \forall k \ge 0. \quad (99)$$

Similar to SM B.2, we have to introduce a weight-space measure in order to take the limit $d \to \infty$:

$$\mu^{(k)}_d = \frac{1}{d} \sum_{r=1}^d \delta_{\tilde a^{(k)}_r} \otimes \delta_{\hat w^{(k)}_r}. \quad (100)$$

In terms of this measure, the dynamics is then expressed as follows:

$$\mu^{(k+1)}_d = \mu^{(k)}_d + \mathrm{div}(\mu^{(k)}_d\, \Delta\theta^{(k)}_d), \quad (101)$$

$$\mu^{(0)}_d = \frac{1}{d} \sum_{r=1}^d \delta_{\tilde a^{(0)}_r} \otimes \delta_{\hat w^{(0)}_r}, \qquad \tilde a^{(0)}_r \sim \mathcal N(0, d^{-1}), \quad \hat w^{(0)}_r \sim \mathcal N(0, I_{d_x}) \ \forall r \in [d], \quad (102)$$

$$\Delta\theta^{(k)}_d(\tilde a, \hat w) = -\left[\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^T x^{(k)}_a),\ \hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_w, y^{(k)}_w)\, \tilde a\, \phi'(\hat w^T x^{(k)}_w)\, x^{(k),T}_w\right]^T, \quad (103)$$

$$f^{(k)}_d(x) = \sigma^* d \int \tilde a\, \phi(\hat w^T x)\, \mu^{(k)}_d(d\tilde a, d\hat w), \qquad \nabla^{(k)}_{f_d}\ell(x, y) = -y\, s(-f^{(k)}_d(x)\, y) \quad \forall k \ge 0. \quad (104)$$

We rewrite the last equation in terms of $\tilde f^{(k)}_d(x) = d^{-1} f^{(k)}_d(x)$:

$$\tilde f^{(k)}_d(x) = \sigma^* \int \tilde a\, \phi(\hat w^T x)\, \mu^{(k)}_d(d\tilde a, d\hat w), \qquad \nabla^{(k)}_{f_d}\ell(x, y) = -y\, s(-d\, \tilde f^{(k)}_d(x)\, y) \quad \forall k \ge 0. \quad (105)$$

The limit dynamics then takes the following form:

$$\mu^{(k+1)}_\infty = \mu^{(k)}_\infty + \mathrm{div}(\mu^{(k)}_\infty\, \Delta\theta^{(k)}_\infty), \qquad \mu^{(0)}_\infty = \delta_0 \otimes \mathcal N(0, I_{d_x}), \quad (106)$$

$$\Delta\theta^{(k)}_\infty(\tilde a, \hat w) = -\left[\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^T x^{(k)}_a),\ \hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_w, y^{(k)}_w)\, \tilde a\, \phi'(\hat w^T x^{(k)}_w)\, x^{(k),T}_w\right]^T, \quad (107)$$

$$\nabla^{(0)}_{f_\infty}\ell(x, y) = -y\, s(-\sigma^*\, \mathcal N(0, \sigma^{(0),2}(x))\, y), \quad (108)$$

$$\tilde f^{(k)}_\infty(x) = \sigma^* \int \tilde a\, \phi(\hat w^T x)\, \mu^{(k)}_\infty(d\tilde a, d\hat w), \qquad \nabla^{(k+1)}_{f_\infty}\ell(x, y) = -y\, \mathbb 1\!\left[\tilde f^{(k+1)}_\infty(x)\, y < 0\right] \quad \forall k \ge 0. \quad (109)$$

As one can notice, the only difference between this limit dynamics and the limit dynamics of the sym-default scaling (SM B.2.2) is the initial measure.

We now check Condition 3. First of all, by the Central Limit Theorem, $f^{(0)}_d(x) = \Theta_{d\to\infty}(1)$, hence the first point of Condition 3 holds.
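Both the divergence of $\Delta\hat a^{(0)}_r$ in (94) and the boundedness of $f^{(0)}_d$ in (96) are easy to confirm numerically. A sketch under our own toy assumptions ($\phi = \tanh$, a single random input, and the bounded loss-gradient factor dropped):

```python
import numpy as np

rng = np.random.default_rng(0)
phi = np.tanh
sigma_star, eta_a, d_x = 1.0, 1.0, 3
x = rng.standard_normal(d_x)

for d in (100, 10_000, 1_000_000):
    a = rng.standard_normal(d)                  # \hat a_r^{(0)} ~ N(0, 1)
    W = rng.standard_normal((d, d_x))           # \hat w_r^{(0)} ~ N(0, I)
    f0 = sigma_star * d ** -0.5 * np.sum(a * phi(W @ x))      # eq. (96): Theta(1) by the CLT
    da0 = eta_a * sigma_star * d ** 0.5 * abs(phi(W[0] @ x))  # |eq. (94)| sans loss factor: Theta(d^{1/2})
    print(f"d={d:>9}  f0={f0:+.2f}  |da_1|={da0:.1f}")
```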
As for the kernels, we have:

$$K^{(k)}_{a,d}(x, x') = \sigma^{*,2} \sum_{r=1}^d \phi(\hat w^{(k),T}_r x)\, \phi(\hat w^{(k),T}_r x'), \quad (110)$$

$$K^{(k)}_{w,d}(x, x') = \sigma^{*,2} d^{-1} \sum_{r=1}^d |\hat a^{(k)}_r|^2\, \phi'(\hat w^{(k),T}_r x)\, \phi'(\hat w^{(k),T}_r x')\, x^T x'. \quad (111)$$

We see that while $K^{(0)}_{w,d}$ converges to a constant due to the Law of Large Numbers, $K^{(0)}_{a,d}$ diverges as $d \to \infty$. This violates the second statement of Condition 3, and the third as well, since $f^{(0)}_\infty$ is finite. Consider now the kernel increments:

$$\Delta K^{(k),'}_{aw,d}(x, x') = -\sigma^{*,3} d^{-1/2} \sum_{r=1}^d \left(\phi(\hat w^{(k),T}_r x)\, \phi'(\hat w^{(k),T}_r x') + \phi'(\hat w^{(k),T}_r x)\, \phi(\hat w^{(k),T}_r x')\right) \nabla^{(k)}_{f_d}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a^{(k)}_r\, \phi'(\hat w^{(k),T}_r x^{(k)}_w)\, (x + x')^T x^{(k)}_w, \quad (112)$$

$$\Delta K^{(k),'}_{ww,d}(x, x') = -\sigma^{*,3} d^{-3/2} \sum_{r=1}^d |\hat a^{(k)}_r|^2 \left(\phi'(\hat w^{(k),T}_r x)\, \phi''(\hat w^{(k),T}_r x') + \phi''(\hat w^{(k),T}_r x)\, \phi'(\hat w^{(k),T}_r x')\right) x^T x'\; \nabla^{(k)}_{f_d}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a^{(k)}_r\, \phi'(\hat w^{(k),T}_r x^{(k)}_w)\, (x + x')^T x^{(k)}_w, \quad (113)$$

$$\Delta K^{(k),'}_{wa,d}(x, x') = -2\sigma^{*,3} d^{-1/2} \sum_{r=1}^d \hat a^{(k)}_r\, \phi'(\hat w^{(k),T}_r x)\, \phi'(\hat w^{(k),T}_r x')\, x^T x'\; \nabla^{(k)}_{f_d}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^{(k),T}_r x^{(k)}_a). \quad (114)$$

For $k = 0$ the terms inside the sums of each increment have zero expectations. Hence the Central Limit Theorem can be applied here. We get: $\Delta K^{(0),'}_{aw,d} = \Theta_{d\to\infty}(1)$, $\Delta K^{(0),'}_{ww,d} = \Theta_{d\to\infty}(d^{-1})$, $\Delta K^{(0),'}_{wa,d} = \Theta_{d\to\infty}(1)$. Since $K^{(0)}_{a,d} = \Theta_{d\to\infty}(d)$ and $K^{(0)}_{w,d} = \Theta_{d\to\infty}(1)$, the last statement of Condition 3 is violated as well.
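The orders claimed for (110)–(111) at initialization can be verified the same way; a sketch under the same toy assumptions (our names, $\phi = \tanh$):

```python
import numpy as np

rng = np.random.default_rng(0)
phi, dphi = np.tanh, lambda z: 1.0 / np.cosh(z) ** 2
sigma_star, d_x = 1.0, 3
x, xp = rng.standard_normal(d_x), rng.standard_normal(d_x)

for d in (100, 10_000, 1_000_000):
    a = rng.standard_normal(d)
    W = rng.standard_normal((d, d_x))
    K_a = sigma_star ** 2 * np.sum(phi(W @ x) * phi(W @ xp))   # eq. (110): Theta(d)
    K_w = sigma_star ** 2 / d * np.sum(a ** 2 * dphi(W @ x) * dphi(W @ xp)) * (x @ xp)  # eq. (111): Theta(1)
    print(f"d={d:>9}  K_a={K_a:.0f}  K_w={K_w:.3f}")
```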
D Initialization-corrected mean-field (IC-MF) limit

Here we consider the same training dynamics as for the mean-field scaling (see SM B.2), but with a modified model definition:

$$\Delta\hat a^{(k)}_r = -\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^{(k),T}_r x^{(k)}_a), \qquad \hat a^{(0)}_r \sim \mathcal N(0, 1), \quad (115)$$

$$\Delta\hat w^{(k)}_r = -\hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a^{(k)}_r\, \phi'(\hat w^{(k),T}_r x^{(k)}_w)\, x^{(k)}_w, \qquad \hat w^{(0)}_r \sim \mathcal N(0, I_{d_x}), \quad (116)$$

$$f^{(k)}_d(x) = \sigma^* d^{-1} \sum_{r=1}^d \hat a^{(k)}_r\, \phi(\hat w^{(k),T}_r x) + \sigma^* d^{-1/2} \sum_{r=1}^d \hat a^{(0)}_r\, \phi(\hat w^{(0),T}_r x), \quad (117)$$

$$\nabla^{(k)}_{f_d}\ell(x, y) = -y\, s(-f^{(k)}_d(x)\, y) \quad \forall k \ge 0. \quad (118)$$

Similar to the mean-field case (SM B.2), we rewrite the dynamics above in terms of the weight-space measure:

$$\mu^{(k+1)}_d = \mu^{(k)}_d + \mathrm{div}(\mu^{(k)}_d\, \Delta\theta^{(k)}_d), \qquad \mu^{(0)}_d = \frac{1}{d} \sum_{r=1}^d \delta_{\hat\theta^{(0)}_r}, \quad \hat\theta^{(0)}_r \sim \mathcal N(0, I_{d_x+1}) \ \forall r \in [d], \quad (119)$$

$$\Delta\theta^{(k)}_d(\hat a, \hat w) = -\left[\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^T x^{(k)}_a),\ \hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_d}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a\, \phi'(\hat w^T x^{(k)}_w)\, x^{(k),T}_w\right]^T, \quad (120)$$

$$f^{(k)}_d(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu^{(k)}_d(d\hat a, d\hat w) + \sigma^* d^{1/2} \int \hat a\, \phi(\hat w^T x)\, \mu^{(0)}_d(d\hat a, d\hat w), \quad (121)$$

$$\nabla^{(k)}_{f_d}\ell(x, y) = -y\, s(-f^{(k)}_d(x)\, y) \quad \forall k \ge 0. \quad (122)$$

Note that here $f^{(k)}_d$ stays finite in the limit $d \to \infty$ for any $k \ge 0$. Hence taking the limit $d \to \infty$ yields:

$$\mu^{(k+1)}_\infty = \mu^{(k)}_\infty + \mathrm{div}(\mu^{(k)}_\infty\, \Delta\theta^{(k)}_\infty), \qquad \mu^{(0)}_\infty = \mathcal N(0, I_{d_x+1}), \quad (123)$$

$$\Delta\theta^{(k)}_\infty(\hat a, \hat w) = -\left[\hat\eta^*_a \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_a, y^{(k)}_a)\, \phi(\hat w^T x^{(k)}_a),\ \hat\eta^*_w \sigma^*\, \nabla^{(k)}_{f_\infty}\ell(x^{(k)}_w, y^{(k)}_w)\, \hat a\, \phi'(\hat w^T x^{(k)}_w)\, x^{(k),T}_w\right]^T, \quad (124)$$

$$f^{(k)}_\infty(x) = \sigma^* \int \hat a\, \phi(\hat w^T x)\, \mu^{(k)}_\infty(d\hat a, d\hat w) + \sigma^*\, \mathcal N(0, \sigma^{(0),2}(x)), \quad (125)$$

$$\nabla^{(k)}_{f_\infty}\ell(x, y) = -y\, s(-f^{(k)}_\infty(x)\, y) \quad \forall k \ge 0. \quad (126)$$
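Since the IC-MF model (117) is an ordinary finite-width network with its initial output re-added at the $d^{-1/2}$ scale, it admits a one-class implementation: the correction keeps the logits $\Theta(1)$ at initialization while training proceeds at the mean-field scale. A minimal PyTorch sketch (our own class and names, $\phi = \tanh$, no biases; the frozen buffers realize the second sum in (117); the training loop and the mean-field learning-rate scaling are omitted):

```python
import torch

class ICMF(torch.nn.Module):
    """f(x) = sigma* d^{-1} sum_r a_r phi(w_r^T x) + sigma* d^{-1/2} sum_r a_r^0 phi(w_r^{0,T} x)."""
    def __init__(self, d_x: int, d: int, sigma_star: float = 1.0):
        super().__init__()
        self.d, self.sigma_star = d, sigma_star
        self.a = torch.nn.Parameter(torch.randn(d))        # \hat a_r^{(0)} ~ N(0, 1)
        self.W = torch.nn.Parameter(torch.randn(d, d_x))   # \hat w_r^{(0)} ~ N(0, I)
        # frozen copies of the initial weights for the correction term in (117)
        self.register_buffer("a0", self.a.detach().clone())
        self.register_buffer("W0", self.W.detach().clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:    # x: (batch, d_x)
        trained = (torch.tanh(x @ self.W.T) * self.a).sum(-1) / self.d
        init = (torch.tanh(x @ self.W0.T) * self.a0).sum(-1) / self.d ** 0.5
        return self.sigma_star * (trained + init)

net = ICMF(d_x=3, d=1024)
print(net(torch.randn(5, 3)))   # logits are Theta(1) at initialization
```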
E Experimental details

We perform our experiments on a feed-forward fully-connected network with a single hidden layer and no biases. We train our network as a binary classifier on a subset of size 1024 of the CIFAR2 dataset (the dataset of the first two classes of CIFAR10). We report results using a test set of size 2000 from the same dataset. We do not perform a hyperparameter search; for this reason, we do not use a validation set.

We train our network for 2000 training steps to minimize the binary cross-entropy loss. We use full-batch GD as the optimization algorithm. We repeat our experiments for 10 random seeds and report means and deviations in the plots for logits and kernels (e.g. Figure 1, left). For the plots of the KL-divergence, we use logits from these 10 random seeds to fit a single Gaussian. Where necessary, we estimate data expectations (e.g. $\mathbb E_{x\sim\mathcal D} |f(x)|$) using 10 samples from the test dataset.

We experiment with other setups (i.e. using a mini-batch gradient estimate instead of the exact one, a larger train dataset, multi-class classification) in SM F. All experiments were conducted on a single NVIDIA GeForce GTX 1080 Ti GPU using the PyTorch framework [15]. Our code is available online: https://github.com/deepmipt/research/tree/master/Infinite_Width_Limits_of_Neural_Classifiers.

Although our analysis assumes initializing variables with samples from a Gaussian, nothing changes if we sample $\sigma\xi$ instead, where $\xi$ can be any symmetric random variable with a distribution independent of hyperparameters.

In our experiments, we took a network of width $d^* = 2^7 = 128$ and applied the Kaiming He uniform initialization [9] to its layers; we call this network a reference network. According to the Kaiming He initialization strategy, initial weights have zero mean and a standard deviation $\sigma^* \propto (d^*)^{-1/2}$ for the output layer, while the standard deviation of the input layer does not depend on the reference width $d^*$. For this network we take equal learning rates in the original parameterization: $\eta^*_a = \eta^*_w$. After that, we scale its initial weights and learning rates with width $d$ according to the scaling at hand:

$$\sigma = \sigma^* \left(\frac{d}{d^*}\right)^{q_\sigma}, \qquad \hat\eta_{a/w} = \hat\eta^*_{a/w} \left(\frac{d}{d^*}\right)^{\tilde q_{a/w}}.$$

Note that we have assumed $\sigma_w = 1$. By definition, $\hat\eta_{a/w} = \eta_{a/w} / \sigma^2_{a/w}$; this implies:

$$\eta_a = \eta^*_a \left(\frac{\sigma}{\sigma^*}\right)^2 \left(\frac{d}{d^*}\right)^{\tilde q_a} = \eta^*_a \left(\frac{d}{d^*}\right)^{\tilde q_a + 2 q_\sigma}, \qquad \eta_w = \eta^*_w \left(\frac{d}{d^*}\right)^{\tilde q_w}.$$
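For concreteness, the scaling rules above can be collected into a small helper (a sketch with our own function name and toy reference values; it merely restates the displayed formulas):

```python
def scale_hyperparams(d, d_star, sigma_star, eta_a_star, eta_w_star,
                      q_sigma, q_a, q_w):
    """Scale the output-layer init std and the learning rates from the
    reference width d_star to width d; q_a, q_w stand for q~_a, q~_w."""
    r = d / d_star
    sigma = sigma_star * r ** q_sigma
    eta_a = eta_a_star * r ** (q_a + 2 * q_sigma)   # since eta_a = eta_a^ * sigma^2
    eta_w = eta_w_star * r ** q_w
    return sigma, eta_a, eta_w

# e.g. the mean-field scaling (q_sigma = -1, q~ = 1) at width 4096,
# with toy reference values sigma* = 0.1, eta*_a = eta*_w = 0.05:
print(scale_hyperparams(4096, 128, 0.1, 0.05, 0.05, q_sigma=-1.0, q_a=1.0, q_w=1.0))
```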
F Experiments for other setups

Although the plots provided in the main body represent full-batch GD on a subset of CIFAR2, we have experimented with other setups as well. In particular, we have varied the batch size and the size of the train dataset. Results are shown in Figures 3-7. The differences are quantitative and marginal.

Figure 3: Test accuracy of different limit models, as well as of the reference model.
Setup: We train a one-hidden-layer network on subsets of the CIFAR2 dataset of different sizes with SGD with varying batch sizes.

Figure 4: Mean kernel diagonals $\mathbb E_{x\sim\mathcal D}\left(\hat\eta^*_a K_{a,d}(x, x) + \hat\eta^*_w K_{w,d}(x, x)\right)$ of different limit models, as well as of the reference model. Setup: We train a one-hidden-layer network on subsets of the CIFAR2 dataset of different sizes with SGD with varying batch sizes. Data expectations are estimated with 10 test data samples.

Figure 5: Mean absolute logits $\mathbb E_{x\sim\mathcal D} |f(x)|$ of different limit models, as well as of the reference model. Setup: We train a one-hidden-layer network on subsets of the CIFAR2 dataset of different sizes with SGD with varying batch sizes. Data expectations are estimated with 10 test data samples.

Figure 6: Mean absolute logits relative to kernel diagonals $\mathbb E_{x\sim\mathcal D}\left|f_d(x) / \left(\hat\eta^*_a K_{a,d}(x, x) + \hat\eta^*_w K_{w,d}(x, x)\right)\right|$ of different limit models, as well as of the reference model. Setup: We train a one-hidden-layer network on subsets of the CIFAR2 dataset of different sizes with SGD with varying batch sizes. Data expectations are estimated with 10 test data samples.

Figure 7: KL-divergence of different limit models relative to the reference model.