Predicting Training Time Without Training
Luca Zancato, Alessandro Achille, Avinash Ravichandran, Rahul Bhotika, Stefano Soatto
University of Padova, Amazon Web Services
{aachille,ravinash,bhotikar,soattos}@amazon.com

Abstract
We tackle the problem of predicting the number of optimization steps that a pre-trained deep network needs to converge to a given value of the loss function. To do so, we leverage the fact that the training dynamics of a deep network during fine-tuning are well approximated by those of a linearized model. This allows us to approximate the training loss and accuracy at any point during training by solving a low-dimensional Stochastic Differential Equation (SDE) in function space. Using this result, we are able to predict the time it takes for Stochastic Gradient Descent (SGD) to fine-tune a model to a given loss without having to perform any training. In our experiments, we are able to predict the training time of a ResNet within a 20% error margin on a variety of datasets and hyper-parameters, at a 30 to 45-fold reduction in cost compared to actual training. We also discuss how to further reduce the computational and memory cost of our method, and in particular we show that by exploiting the spectral properties of the gradients' matrix it is possible to predict training time on a large dataset while processing only a subset of the samples.
Say you are a researcher with many more ideas than available time and compute resources to test them. You are pondering launching thousands of experiments but, as the deadline approaches, you wonder whether they will finish in time, and before your computational budget is exhausted. Could you predict the time it takes for a network to converge, before even starting to train it?

We look to efficiently estimate the number of training steps a Deep Neural Network (DNN) needs to converge to a given value of the loss function, without actually having to train the network. This problem has received little attention thus far, possibly because the initial training dynamics of a randomly initialized DNN are highly non-trivial to characterize and analyze. However, in most practical applications it is common not to start from scratch, but from a pre-trained model. This may simplify the analysis, since the final solution obtained by fine-tuning is typically not too far from the initial solution obtained after pre-training. In fact, it is known that the dynamics of overparametrized DNNs [9, 31, 2] during fine-tuning tend to be more predictable and close to convex [24].

We therefore characterize the training dynamics of a pre-trained network and provide a computationally efficient procedure to estimate the expected profile of the loss curve over time. In particular, we provide a qualitative interpretation and a quantitative prediction of the convergence speed of a DNN as a function of the network pre-training, the target task, and the optimization hyper-parameters.

We use a linearized version of the DNN model around pre-trained weights to study its actual dynamics. In [20] a similar technique is used to describe the learning trajectories of randomly initialized wide neural networks. Such an approach is inspired by the Neural Tangent Kernel (NTK) for infinitely wide networks [14].

Preprint. Under review.

While we note that NTK theory may not correctly predict the dynamics of real (finite-size) randomly initialized networks [12], we show that our linearized approach can be extended to fine-tuning of real networks in a similar vein to [24]. In order to predict fine-tuning Training Time (TT) without training, we introduce a Stochastic Differential Equation (SDE) (similar to [13]) to approximate the behavior of SGD: we do so for a linearized DNN and in function space rather than in weight space. That is, rather than trying to predict the evolution of the weights of the network (a $D$-dimensional vector), we aim to predict the evolution of the outputs of the network on the training set (an $N \times C$-dimensional vector, where $N$ is the size of the dataset and $C$ the number of network outputs). This drastically reduces the dimensionality of the problem for over-parametrized networks (that is, when $NC \ll D$).

A possible limiting factor of our approach is that the memory requirement to predict the dynamics scales as $O(DC^2N^2)$. This would rapidly become infeasible for datasets of moderate size and for real architectures ($D$ is in the order of millions). To mitigate this, we show that we can use random projections to restrict to a much smaller $D'$-dimensional subspace with only minimal loss in prediction accuracy. We also show how to estimate Training Time using a small subset of $N'$ samples, which reduces the total complexity to $O(D'C^2N'^2)$. We do this by exploiting the spectral properties of the Gram matrix of the gradients.
Under mild assumptions the same tools can be used to estimate Training Time on a larger dataset without actually seeing the data.

To summarize, our main contributions are:
(i) We present both a qualitative and a quantitative analysis of the fine-tuning Training Time as a function of the Gram matrix $\Theta$ of the gradients at initialization (the empirical NTK matrix).
(ii) We show how to reduce the cost of estimating the matrix $\Theta$ using random projections of the gradients, which makes the method efficient for common architectures and large datasets.
(iii) We introduce a method to estimate how much longer a network will need to train if we increase the size of the dataset, without actually having to see the data (under the hypothesis that new data is sampled from the same distribution).
(iv) We test the accuracy of our predictions on off-the-shelf state-of-the-art models trained on real datasets. We are able to predict the correct training time within a 20% error with 95% confidence over several different datasets and hyperparameters, at only a small fraction of the time it would take to actually run the training (30-45x faster in our experiments).

Predicting the training time of a state-of-the-art architecture on large-scale datasets is a relatively understudied topic. In this direction, Justus et al. [15] try to estimate the wall-clock time required for a forward and backward pass on given hardware. We focus instead on a complementary aspect: estimating the number of fine-tuning steps necessary for the loss to converge below a given threshold. Once this has been estimated, we can combine it with the average time for the forward and backward pass to get a final estimate of the wall-clock time to fine-tune a DNN model without training it. Hence, we are interested in predicting the learning dynamics of a pre-trained DNN trained with either Gradient Descent (GD) or Stochastic Gradient Descent (SGD). While different results are known to describe training dynamics under a variety of assumptions (e.g.
[16, 28, 26, 6]), in the following we are mainly interested in recent developments which describe the optimization dynamics of a DNN using a linearization approach. Several works [14, 19, 10] suggest that in the over-parametrized regime wide DNNs behave similarly to linear models, and in particular that they are fully characterized by the Gram matrix of the gradients, also known as the empirical Neural Tangent Kernel (NTK). Under these assumptions, [14, 3] derive a simple connection between training time and the spectral decomposition of the NTK matrix. However, their results are limited to Gradient Descent dynamics and to simple architectures, which are not directly applicable to real scenarios. In particular, their arguments hinge on the assumption of a randomly initialized, very wide two-layer or infinitely wide neural network [3, 11, 22]. We take this direction a step further, providing a unified framework which allows us to describe training time for both SGD and GD on common architectures. Again, we rely on a linear approximation of the model; but while the practical validity of such a linear approximation for randomly initialized state-of-the-art architectures (such as ResNets) is still debated [12], we follow Mu et al. [24] and argue that the fine-tuning dynamics of over-parametrized DNNs can be closely described by a linearization. We expect such an approximation to hold, since the network does not move much in parameter space during fine-tuning, and over-parametrization leads to a smooth and regular loss function around the pre-trained weights [9, 31, 2, 21]. Under this premise, we tackle both GD and SGD in a unified framework and build on [13] to model training of a linear model using a Stochastic Differential Equation in function space. We show that, as also hypothesized by [24], linearization can provide an accurate approximation of fine-tuning dynamics and can therefore be used for training time prediction.

Figure 1: Training time prediction. Scatter plots of the predicted training time vs. the real training time when fine-tuning a ResNet-18 pre-trained on ImageNet on several tasks (CIFAR10, CIFAR100, CUB200, Aircrafts, Mit67, Surfaces, Cars): (a) training with Gradient Descent; (b) training with SGD. Each task is obtained by randomly sampling a subset of five classes with 150 images each (when possible) from one popular dataset, with different hyperparameters (batch size, learning rate). The closer the points are to the bisector, the better the TT estimate. Our prediction is (a) within 13% of the real training time 95% of the time when using GD and (b) within 20% of the real training time when using SGD.
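The linearization this approach relies on is $f^{\mathrm{lin}}_w(x) = f_{w_0}(x) + \nabla_w f_{w_0}(x)\,(w - w_0)$. It can be illustrated with a minimal NumPy sketch. The sketch assumes a toy one-hidden-layer tanh network as a stand-in for a pre-trained DNN. `HIDDEN`, `forward`, `jacobian`, `linearized` and `linearization_error` are hypothetical names introduced only for this illustration. The sketch shows that the gap between the network and its first-order expansion shrinks quadratically with the distance from the pre-trained weights. This is why staying close to $w_0$ during fine-tuning matters.

```python
import numpy as np

HIDDEN = 8  # hypothetical toy width, not taken from the paper


def unpack(w, d):
    """Split a flat weight vector into (W1, w2) for the toy network."""
    W1 = w[: d * HIDDEN].reshape(d, HIDDEN)
    w2 = w[d * HIDDEN:]
    return W1, w2


def forward(w, x):
    """Toy scalar-output network f_w(x) = tanh(x W1) w2 (stand-in for a DNN)."""
    W1, w2 = unpack(w, x.shape[1])
    return np.tanh(x @ W1) @ w2


def jacobian(w, x):
    """Per-sample gradients grad_w f_w(x_i), stacked as an (N, D) matrix."""
    W1, w2 = unpack(w, x.shape[1])
    h = np.tanh(x @ W1)                                    # (N, HIDDEN)
    dW1 = x[:, :, None] * ((1.0 - h**2) * w2)[:, None, :]  # (N, d, HIDDEN)
    return np.concatenate([dW1.reshape(len(x), -1), h], axis=1)


def linearized(w, w0, x):
    """First-order Taylor expansion of f around the 'pre-trained' weights w0."""
    return forward(w0, x) + jacobian(w0, x) @ (w - w0)


def linearization_error(w0, direction, x, step):
    """Max |f_w(x_i) - f_lin_w(x_i)| over samples for w = w0 + step * direction."""
    w = w0 + step * direction
    return np.max(np.abs(forward(w, x) - linearized(w, w0, x)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, d = 20, 5
    x = rng.normal(size=(N, d))
    w0 = rng.normal(size=d * HIDDEN + HIDDEN) / np.sqrt(d)
    direction = rng.normal(size=w0.shape)
    for step in (1e-1, 1e-2, 1e-3):
        # Residual should drop roughly 100x for each 10x smaller step (O(step^2)).
        err = linearization_error(w0, direction, x, step)
        print(f"step={step:.0e}  max|f - f_lin|={err:.2e}")
```

In a real network the per-sample gradients would come from automatic differentiation. The hand-written Jacobian here only keeps the sketch dependency-free.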
In this section we look at how to efficiently approximate the training time of a DNN without actual training. By Training Time (TT) we mean the number of optimization steps (of either Gradient Descent (GD) or Stochastic Gradient Descent (SGD)) needed to bring the loss on the training set below a certain threshold.

We start by introducing our main tool. Let $f_w(x)$ denote the output of the network, where $w$ denotes the weights of the network and $x \in \mathbb{R}^d$ denotes its input (e.g., an image). Let $w_0$ be the weight configuration after pre-training. We assume that when fine-tuning a pre-trained network the solution remains close to the pre-trained weights $w_0$ [24, 9, 31, 2]. Under this assumption, which we discuss further in Section 6, we can faithfully approximate the network with its Taylor expansion around $w_0$ [20]. Let $w_t$ be the fine-tuned weights at time $t$. Using big-O notation and $f_t \equiv f_{w_t}$, we have:
$$f_t(x) = f_0(x) + \nabla_w f_0(x)\big|_{w = w_0} (w_t - w_0) + O(\|w_t - w_0\|^2).$$
We now want to use this approximation to characterize the training dynamics of the network during fine-tuning in order to estimate TT. In such theoretical analyses [14, 20, 3] it is common to assume that the network is trained with Gradient Descent (GD) rather than Stochastic Gradient Descent, and in the limit of a small learning rate. In this limit, the dynamics are approximated by the gradient flow differential equation $\dot w_t = -\eta \nabla_{w_t} L$ [14, 20], where $\eta$ denotes the learning rate and $L(w)$ denotes the loss function $L(w) = \sum_{i=1}^{N} \ell(y_i, f_w(x_i))$, with $\ell$ the per-sample loss function (e.g., Cross-Entropy). This approach, however, has two main drawbacks. First, it does not properly approximate Stochastic Gradient Descent, as it ignores the effect of the gradient noise on the dynamics, which affects both training time and generalization.
Second, the differential equation involves the weights of the model, which live in a very high-dimensional space, making numerical solutions to the equation intractable. To address both problems, building on [13], we prove the following result in the Supplementary Material.

Proposition 1
In the limit of small learning rate $\eta$, the output on the training set of a linearized network $f^{\mathrm{lin}}_t$ trained with SGD evolves according to the following Stochastic Differential Equation (SDE):
$$df^{\mathrm{lin}}_t(X) = \underbrace{-\eta\, \Theta\, \nabla_{f^{\mathrm{lin}}_t(X)} L_t \, dt}_{\text{deterministic part}} + \underbrace{\frac{\eta}{\sqrt{|B|}}\, \nabla_w f^{\mathrm{lin}}_0(X)\, \Sigma^{1/2}\big(f^{\mathrm{lin}}_t(X)\big)\, dn}_{\text{stochastic part}}, \qquad (1)$$
where $X$ is the set of training images, $|B|$ the batch size and $dn$ is a $D$-dimensional Brownian motion. We have defined the Gram gradients matrix $\Theta$ [14, 27] (i.e., the empirical Neural Tangent Kernel matrix) and the covariance matrix $\Sigma$ of the gradients as follows:
$$\Theta := \nabla_w f_0(X)\, \nabla_w f_0(X)^T, \qquad (2)$$
$$\Sigma\big(f^{\mathrm{lin}}_t(X)\big) := \mathbb{E}\Big[\big(g_i \nabla_{f^{\mathrm{lin}}_t(x_i)} L\big) \otimes \big(g_i \nabla_{f^{\mathrm{lin}}_t(x_i)} L\big)\Big] - \mathbb{E}\big[g_i \nabla_{f^{\mathrm{lin}}_t(x_i)} L\big] \otimes \mathbb{E}\big[g_i \nabla_{f^{\mathrm{lin}}_t(x_i)} L\big], \qquad (3)$$
where $g_i \equiv \nabla_w f_0(x_i)$. Note that both $\Theta$ and $\Sigma$ only require gradients w.r.t. the parameters computed at initialization.

Figure 2: (Left) ODE vs. SDE approximation: accuracy vs. iterations for actual SGD and for the SDE and ODE approximations. The ODE approximation may not be well suited to describe the actual non-linear SGD dynamics (high learning rate regime). (Right) Fine-tuning runs with the same ELR have similar curves: train error vs. iterations for ELR in {0.001, 0.005, 0.010, 0.050, 0.100}. We fine-tune an ImageNet pre-trained network on MIT-67 with different combinations of learning rates and momentum coefficients. As long as the effective learning rate is the same, the loss curves are also similar.

The first term of eq. (1) is an ordinary differential equation (ODE) describing the deterministic part of the optimization, while the second, stochastic term accounts for the noise. In Figure 2 (left) we show the qualitatively different behaviour of the solution to the deterministic part of eq. (1) and of the complete SDE eq. (1). While several related results are known in the literature for the dynamics of the network in weight space [7], note that eq. (1) completely characterizes the training dynamics of the linearized model by looking at the evolution of the output $f^{\mathrm{lin}}_t(X)$ of the model on the training samples (an $N \times C$-dimensional vector) rather than at the evolution of the weights $w_t$ (a $D$-dimensional vector). When the number of data points is much smaller than the number of weights (which are in the order of millions for ResNets), this can result in a drastic dimensionality reduction, which allows easy estimation of the solution to eq. (1). Solving eq. (1) still comes with some challenges, particularly in computing $\Theta$ efficiently on large datasets and architectures. We tackle these in Section 4. Before that, we take a look at how different hyper-parameters and different pre-trainings affect the training time of a DNN on a given task.

From Proposition 1 we can gauge how hyper-parameters will affect the optimization process of the linearized model and, by proxy, of the original model it approximates. One thing that should be noted is that Proposition 1 assumes the network is trained with momentum $m = 0$. Using a non-zero momentum leads to a second-order differential equation in weight space that is not captured by Proposition 1. We can, however, introduce heuristics to handle the effect of momentum: Smith et al. [28] note that momentum acts on the stochastic part, shrinking it by a factor $\sqrt{1/(1-m)}$. Meanwhile, under the assumptions we used in Proposition 1 (small learning rate), we can show (see Supplementary Material) that the main effect of momentum on the deterministic part is to re-scale the learning rate by a factor $1/(1-m)$. Given these results, we define the effective learning rate (ELR) $\hat\eta = \eta/(1-m)$ and claim that, to a first approximation, we can simulate the effect of momentum by using $\hat\eta$ instead of $\eta$ in eq. (1). In particular, models with different learning rates and momentum coefficients will have similar (up to noise) dynamics, and hence training time, as long as the effective learning rate $\hat\eta$ remains the same. In Figure 2 (right) we show empirically that the same effective learning rate indeed implies a similar loss curve. That a similar effective learning rate gives similar test performance has also been observed in [21, 28].

Batch size.
The batch size appears only in the stochastic part of the equation; its main effect is to decrease the scale of the SDE noise term. In particular, when the batch size goes to infinity, $|B| \to \infty$, we recover the deterministic gradient flow also studied by [20]. Note that we need the batch size $|B|$ to go to infinity, rather than just being as large as the dataset, since we assumed random batch sampling with replacement. If we assume sampling without replacement, the stochasticity vanishes as soon as $|B| = N$ (see [7] for a more in-depth discussion).

Figure 3: Are gradients good descriptors to cluster data by semantics and training time? (a) Features vs. gradients clustering. (Right) t-SNE plot of the first five principal components of the gradients of each sample in a subset of CIFAR-10 with 3 classes. Colors correspond to the sample class. We observe that the first 5 principal components are enough to separate the data by class. By Proposition 2 this implies faster training time. (Left) In the same setting, t-SNE plot of the features using the first 5 components of PCA. We observe that gradients separate the classes better than the features. (b) Trajectory clustering: t-SNE on predicted trajectories. To see if gradients are good descriptors of both semantics and training time, we use gradients to predict linearized trajectories: we cluster the trajectories using t-SNE and color each point by (left) class and (right) training time. We observe that clusters split trajectories according to both labels (left) and training time (right). Interestingly, inside each class there are clusters of points that may converge at different speeds.

We now use the SDE in eq. (1) to analyze how the combination of different pre-trainings of the model (that is, different $w_0$'s) and different tasks affects the training time. In particular, we show that a necessary condition for fast convergence is that the gradients after pre-training cluster well with respect to the labels. We conduct this analysis for a binary classification task with $y_i = \pm 1$ (the extension to multi-class classification is straightforward), under the simplifying assumptions that we use the MSE loss and that we operate in the limit of large batch size (GD), so that only the deterministic part of eq. (1) remains. Under these assumptions, eq. (1) can be solved analytically and the loss of the linearized model at time $t$ can be written in closed form as (see Supplementary Material):
$$L_t = (Y - f_0(X))^T e^{-\eta \Theta t} (Y - f_0(X)). \qquad (4)$$
The following characterization can easily be obtained using an eigen-decomposition of the matrix $\Theta$.

Proposition 2
Let $S = \nabla_w f_0(X)^T \nabla_w f_0(X)$ be the second moment matrix of the gradients and let $S = U \Sigma U^T$ be the uncentered PCA of the gradients, where $\Sigma = \mathrm{diag}(\lambda_1, \dots, \lambda_n, 0, \dots, 0)$ is a $D \times D$ diagonal matrix, $n \le \min(N, D)$ is the rank of $S$ and the $\lambda_i$ are the eigenvalues sorted in descending order. Then we have
$$L_t = \sum_{k=1}^{D} e^{-\eta \lambda_k t}\, (\delta y \cdot v_k)^2, \qquad (5)$$
where $\sqrt{\lambda_k}\, v_k = (g_i \cdot u_k)_{i=1}^{N}$ is the $N$-dimensional vector containing the value of the $k$-th principal component of the gradients $g_i$, and $\delta y := Y - f_0(X)$.

Training speed and gradient clustering.
We can give the following intuitive interpretation: consider the gradient vector $g_i$ as a representation of the sample $x_i$. If the first principal components of $g_i$ are sufficient to separate the classes (i.e., to cluster them), then convergence is faster (see Figure 3). Conversely, if we need the higher components (associated with small $\lambda_k$) to separate the data, then convergence will be exponentially slower. Arora et al. [3] also use the eigen-decomposition of $\Theta$ to explain the slower convergence observed for a randomly initialized two-layer network trained with random labels. This is straightforward, since the projection of a random vector is spread uniformly over all eigenvectors rather than concentrated on the first few, leading to slower convergence. However, we note that the exponential dynamics predicted by [3] do not hold for more general networks trained from scratch [30] (see Section 6). In particular, eq. (5) mandates that the loss curve is always convex (it is a sum of convex functions), which may not be the case for deep networks trained from scratch.

Figure 4: (Left) Trajectory approximations: train loss vs. iterations for the actual fine-tuning of a DNN with GD, compared to the numerical solution of eq. (1) with the full kernel and with an approximated (reduced) $\Theta$. The approximated $\Theta$ can faithfully describe fine-tuning dynamics while being twice as fast to compute and 100 times smaller to store. (Center) Approximated vs. true kernel: relative difference in Frobenius norm $\|\Theta - \hat\Theta\|_F / \|\Theta\|_F$ between the real and approximated $\Theta$ as the dataset size $N$ varies (red), and their computational time in seconds (blue). (Right) Power law for different dataset sizes: eigen-spectrum of $\Theta$ computed on subsets of MIT-67 of increasing size ($N$ from 30 to 2100). Note the convergence to a common power law (i.e., a line in log-log scale).

In Proposition 2 we have shown a closed-form solution to the SDE in eq. (1) in the limit of large batch size and for the MSE loss. Unfortunately, in general eq. (1) does not have a closed-form expression when using the cross-entropy loss [20]. A numerical solution is, however, possible, enabled by the fact that we describe the network training in function space, which is much smaller than weight space for over-parametrized models. The main computational costs are creating the matrix $\Theta$ in eq. (1), which costs $O(DC^2N^2)$, and computing the noise in the stochastic term. Here we show how to reduce the cost of $\Theta$ to $O(D'C^2N^2)$ for $D' \ll D$ using a random projection approximation. Then, we propose a fast approximation for the stochastic part.
Finally, we describe how to reduce the cost in $N$ by using only a subset $N' < N$ of samples to predict training time.

Random projection.
To keep the notation uncluttered, here we assume w.l.o.g. $C = 1$. In this case the matrix $\Theta$ contains the $N^2$ pairwise dot-products of the gradients (each a $D$-dimensional vector) of the $N$ training samples (see eq. (2)). Since $D$ can be very large (in the order of millions), storing and multiplying all gradients can be expensive as $N$ grows. Hence, we look at a dimensionality reduction technique. The optimal dimensionality reduction that preserves the dot-product is obtained by projecting on the first principal components (via SVD), which, however, are themselves expensive to obtain. A simpler technique is to project the gradients on a set of $D'$ standard Gaussian random vectors: it is known that such random projections preserve (in expectation) pairwise products between vectors [5, 1], and hence allow us to reconstruct the Gram matrix while storing only $D'$-dimensional vectors, with $D' \ll D$. We further increase computational efficiency by using multinomial random vectors with entries in {-1, 0, +1}, as proposed in [1], which further reduces the computational cost by avoiding floating-point multiplications. In Figure 4 we show that the entries of $\Theta$ and its spectrum are well approximated using this method, while the computational time becomes much smaller.

Computing the noise.
The noise covariance matrix $\Sigma$ is a $D \times D$ matrix that changes over time. Both computing it at each step and storing it are prohibitive. Estimating $\Sigma$ correctly is important to describe the dynamics of SGD [8]; however, we claim that a simple approximation may suffice to describe the simpler dynamics in function space. We approximate $\nabla_w f^{\mathrm{lin}}_0(X)\, \Sigma^{1/2}$ by approximating $\Sigma$ with its diagonal (so that we only need to store a $D$-dimensional vector). Rather than computing the whole $\Sigma$ at each step, we estimate the value of the diagonal at the beginning of training. Then, by exploiting eq. (3), we see that the only change to $\Sigma$ is due to $\nabla_{f^{\mathrm{lin}}_t} L$, whose norm decreases over time. Therefore we use the easy-to-compute $\nabla_{f^{\mathrm{lin}}_t} L$ to re-scale our initial estimate of $\Sigma$.

Larger datasets.
In the MSE case, from eq. (4), knowing the eigenvalues $\lambda_k$ and the corresponding residual projections $p_k = (\delta y \cdot v_k)$ we can predict the whole training curve in closed form. Is it possible to predict $\lambda_k$ and $p_k$ using only a subset of the dataset? It is known [27] that the eigenvalues of the Gram matrix of Gaussian data follow a power-law distribution of the form $\lambda_k = c\, k^{-s}$. Moreover, by a standard concentration argument, one can prove that the eigenvalues should converge to a given limit as the number of datapoints increases. We verify that a similar power-law and convergence result also holds for real data (see Figure 4, right). Exploiting this result, we can estimate $c$ and $s$ from the spectrum computed on a subset of the data, and then predict the remaining eigenvalues. A similar argument holds for the projections $p_k$, which also follow a power law (albeit with slower convergence). We describe the complete estimation in the Supplementary Material.

We now empirically validate the accuracy of Proposition 1 in approximating the loss curve of an actual deep neural network fine-tuned on a large-scale dataset. We also validate the goodness of the numerical approximations described in Section 4. Due to the lack of a standard, well-established benchmark to test Training Time estimation algorithms, we developed one with the main goal of closely resembling common fine-tuning practice for a wide spectrum of different tasks.
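This sketch shows how a predicted loss curve can be obtained from Proposition 1 before any fine-tuning, in the simplest setting. It is a minimal Euler-Maruyama discretization of eq. (1) in function space, followed by a training-time readout. It rests on illustrative assumptions that are ours, not the paper's:
- a single output ($C = 1$);
- a mean-reduced MSE loss (the experiments use cross-entropy);
- one Euler step per SGD iteration;
- the exact $\Theta$ (no random projection);
- a diagonal $\Sigma$ estimated at $t = 0$ and rescaled by $\|\nabla_f L_t\| / \|\nabla_f L_0\|$, one possible reading of the rescaling described above;
- a synthetic Gaussian gradient matrix in place of a real network's gradients.

`simulate_function_space` and `training_time` are hypothetical names.

```python
import numpy as np


def simulate_function_space(G, y, f0, lr, batch_size, steps, rng=None):
    """Euler-Maruyama sketch of eq. (1) for a linearized model with C = 1.

    Illustrative assumptions (not from the paper): mean-reduced MSE loss
    L = ||f - y||^2 / (2N), one Euler step per SGD iteration (dt = 1), and a
    diagonal Sigma estimated once at t = 0, rescaled by ||grad_f L_t||.
    G is the (N, D) matrix of per-sample gradients at the pre-trained weights.
    batch_size=None drops the stochastic part (large-batch / GD limit).
    """
    rng = np.random.default_rng() if rng is None else rng
    N = len(y)
    theta = G @ G.T                               # empirical NTK, eq. (2)
    f = np.array(f0, dtype=float)
    r0 = f - y                                    # equals N * grad_f L at t = 0
    # Std of each weight coordinate of the per-sample gradients g_i * r_i,
    # i.e. the square root of the diagonal of Sigma at t = 0.
    sigma0 = (G * r0[:, None]).std(axis=0)
    norm0 = np.linalg.norm(r0)
    losses = np.empty(steps)
    for t in range(steps):
        r = f - y
        losses[t] = 0.5 * np.mean(r**2)
        f = f - lr * (theta @ r) / N              # deterministic part
        if batch_size is not None and norm0 > 0:
            scale = np.linalg.norm(r) / norm0     # heuristic rescaling of Sigma^(1/2)
            z = rng.standard_normal(G.shape[1])   # Brownian increment with dt = 1
            f = f + lr / np.sqrt(batch_size) * (G @ (scale * sigma0 * z))
    return losses


def training_time(losses, threshold):
    """First step whose loss is <= threshold (len(losses) if never reached)."""
    hits = np.flatnonzero(np.asarray(losses) <= threshold)
    return int(hits[0]) if hits.size else len(losses)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, D = 40, 2000
    G = rng.normal(size=(N, D)) / np.sqrt(D)      # synthetic stand-in for real gradients
    y = np.sign(rng.normal(size=N))               # binary labels y_i = +-1
    f0 = np.zeros(N)
    gd = simulate_function_space(G, y, f0, lr=5.0, batch_size=None, steps=300)
    sgd = simulate_function_space(G, y, f0, lr=5.0, batch_size=8, steps=300, rng=rng)
    print("TT (GD):", training_time(gd, 0.05), " TT (SGD):", training_time(sgd, 0.05))
```

With `batch_size=None` the loop reduces exactly to GD on the linearized model, whose residual evolves as $(I - \eta\Theta/N)^t\,\delta y$ under this loss convention. Shrinking `batch_size` increases the noise term, which is the regime where the diagonal-$\Sigma$ shortcut is least accurate.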
Experimental setup.
We define training time as the first time the (smoothed) loss is below a given threshold. However, since different datasets converge at different speeds, the same threshold can be too high for some datasets (it is hit immediately) and too low for others (it may take hundreds of epochs to be reached). To solve this, and to have cleaner readings, we define a 'normalized' threshold as follows: we fix the total number of fine-tuning steps $T$, and measure instead the first time the loss is within $\epsilon$ of its final value at time $T$. This measure takes into account the 'asymptotic' loss reached by the DNN within the computational budget (which may not be close to zero if the budget is low), and naturally adapts the threshold to the difficulty of the dataset. We compute both the real loss curve and the predicted training curve using Proposition 1 and compare the $\epsilon$-training-time measured on both. We report the absolute prediction error, that is $|t_{\text{predicted}} - t_{\text{real}}|$. For all the experiments we extract 5 random classes from each dataset (Table 1) and sample 150 images per class (or the maximum available for the specific dataset). We then fine-tune ResNet18/34 using either GD or SGD.

Accuracy of the prediction.
In Figure 1 we show TT estimate errors (for several values of $\epsilon$) under a wide range of conditions, covering different learning rates, batch sizes, datasets and optimization methods. For all the experiments we choose a multi-class classification problem with Cross-Entropy (CE) loss unless specified otherwise, and a fixed computational budget of $T = 150$ steps for both GD and SGD. We note that our estimates are consistently within a 13% (GD) and a 20% (SGD) relative error of the actual training time 95% of the time.

In Table 1 we describe the sensitivity of our estimates to different thresholds $\epsilon$, both when our assumptions do and do not hold (low and high learning rate regimes). Note that a larger threshold $\epsilon$ is hit during the initial convergence phase of the network, when a small number of iterations corresponds to a large change in the loss. Correspondingly, the hitting time can be measured more accurately and our errors are lower. A smaller $\epsilon$ depends more on the correct prediction of the slower asymptotic phase, for which the exact hitting time is more difficult to estimate.

TT error          ε = 1%         ε = 10%        ε = 40%
LR                low   high     low   high     low   high
Cars [17]         9     18       7     8        1     0
Surfaces [4]      6     13       6     7        6     3
Mit67 [25]        8     10       6     8        3     1
Aircrafts [23]    5     21       5     4        9     7
CUB200 [29]       6     6        5     8        1     1
CIFAR100 [18]     10    15       6     7        2     3
CIFAR10 [18]      9     14       8     9        3     3

Table 1: Training Time estimation error for CE loss using GD for T = 150 epochs at different thresholds ε. We compare TT estimates when the ODE assumptions do and do not hold: high LR {0.005} and small LR {0.001, 0.0001}.

Figure 5: Wall-clock time (in seconds) to compute the TT estimate vs. fine-tuning running time. We run the methods described in Section 4 both on GPU and CPU. Training is done on GPU.
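The ε-training-time used for Table 1 and Figure 1 can be read off any loss curve, measured or predicted. The minimal sketch below rests on two assumptions of ours: a trailing moving average as the smoothing, and ε interpreted as a fraction of the total loss decrease over the budget $T$. The text above does not pin down either choice. `smooth` and `eps_training_time` are hypothetical helper names, and the exponential curves in the demo are synthetic stand-ins for a real and a predicted loss curve.

```python
import numpy as np


def smooth(loss, window=5):
    """Trailing moving average; the window size is an assumption."""
    loss = np.asarray(loss, dtype=float)
    csum = np.concatenate([[0.0], np.cumsum(loss)])
    end = np.arange(1, len(loss) + 1)
    start = np.maximum(end - window, 0)
    return (csum[end] - csum[start]) / (end - start)


def eps_training_time(loss, eps, window=5):
    """First step at which the smoothed loss is within eps of its value at step T.

    Hypothetical convention: eps is a fraction of the total decrease
    loss[0] - loss[T-1] over the fixed budget of T steps.
    """
    s = smooth(loss, window)
    final = s[-1]
    threshold = final + eps * (s[0] - final)
    hits = np.flatnonzero(s <= threshold)
    return int(hits[0]) if hits.size else len(s) - 1


if __name__ == "__main__":
    T = 150
    t = np.arange(T)
    real = np.exp(-t / 30.0)        # synthetic "measured" loss curve
    predicted = np.exp(-t / 33.0)   # synthetic "predicted" loss curve
    for eps in (0.01, 0.10, 0.40):
        t_real = eps_training_time(real, eps)
        t_pred = eps_training_time(predicted, eps)
        print(f"eps={eps:.2f}: real={t_real} predicted={t_pred} "
              f"abs error={abs(t_pred - t_real)}")
```

Because the threshold is anchored to the loss reached at step $T$, the same ε adapts to slow and fast datasets alike. This is the motivation given above for the normalized threshold.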
Wall-clock run-time.
In Figure 5 we show the wall-clock runtime of our training time prediction method compared to the time to actually train the network for T steps. Our method is 30-40 times faster. Moreover, it can be run completely on CPU without a drastic drop in performance. This allows one to cheaply estimate TT and allocate/manage resources even without access to a GPU.

Figure 6: Average and 95% confidence intervals of the TT estimate error (per dataset: Surfaces, CIFAR100, Aircrafts, Mit67, CUB200, CIFAR10, Cars) for: (Left) GD using different learning rates (0.0001, 0.001, 0.005); (Center) SGD using different batch sizes; (Right) SGD using different dataset sizes. The average is taken w.r.t. random classes with different numbers of samples: {10, 50, 125}.
Effect of dataset distance.
We note that the average error for Surfaces (Figure 6) is uniformly higher than for the other datasets. This may be due to the texture classification task being quite different from ImageNet, on which the network is pre-trained. In this case we can expect the linearization assumption to be partially violated, since the features must adjust more during fine-tuning.
Effect of hyper-parameters on prediction accuracy.
We derived Proposition 1 under several assumptions, most importantly a small learning rate and $w_t$ close to $w_0$. In Figure 6 (left) we show that increasing the learning rate indeed decreases the accuracy of our prediction, although the accuracy remains good even at larger learning rates. Fine-tuning on a larger dataset makes the weights move farther away from the initialization $w_0$. In Figure 6 (right) we show that this slightly increases the prediction error. Finally, we observe in Figure 6 (center) that using a smaller batch size, which makes the stochastic part of Proposition 1 larger, also slightly increases the error. This can be ascribed to the approximation of the noise term (Section 4). On the other hand, in Figure 2 (right) we see that the effect of momentum on a fine-tuned network is very well captured by the effective learning rate (Section 3.1), as long as the learning rate is reasonably small, which is the case for fine-tuning. Hence the SDE approximation is robust to different values of the momentum. In general, we note that even when our assumptions are not fully met, training time can still be approximated with only a slightly higher error. This suggests that point-wise proximity of the training trajectories of the linear and real models is not necessary, as long as their behavior (decay rate) is similar (see also Supplementary Material).

We have shown that we can predict, within a 13-20% error, the time that it will take for a pre-trained network to reach a given loss, in only a small fraction of the time that it would require to actually train the model. We do this by studying the training dynamics of a linearized version of the model using the SDE in eq. (1), which, being formulated in the smaller function space rather than in parameter space, can be solved numerically.
We have also studied how training time depends on pre-training and on hyper-parameters (Section 3.1), and how to make the computation feasible for larger datasets and architectures (Section 4).

While we do not necessarily expect a linear approximation around a random initialization to hold during the training of a real (non-wide) network, we exploit the fact that, when using a pre-trained network, the weights are more likely to remain close to initialization [24], which improves the quality of the approximation. However, in the Supplementary Material we show that, even when using a pre-trained network, the trajectories of the weights of the linearized model and of the real model can differ substantially. On the other hand, we also show that the linearized model correctly predicts the outputs (not the weights) of the real model throughout training, which is enough to compute the loss. We hypothesize that this is the reason why eq. (1) can accurately predict the training time using a linear approximation.

The procedure described so far can be considered an open-loop procedure: since we estimate training time before any fine-tuning step is performed, we gain no feedback from the actual training. How to perform training-time prediction during the actual training, and how to use training feedback (e.g., gradient updates) to improve the prediction in real time, is an interesting future direction of research.

References
[1] Dimitris Achlioptas. Database-friendly random projections: Johnson-Lindenstrauss with binary coins. J. Comput. Syst. Sci., 66(4):671–687, June 2003.
[2] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. arXiv preprint arXiv:1811.03962, 2018.
[3] Sanjeev Arora, Simon Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, pages 322–332, 2019.
[4] Sean Bell, Paul Upchurch, Noah Snavely, and Kavita Bala. Material recognition in the wild with the materials in context database. Computer Vision and Pattern Recognition (CVPR), 2015.
[5] Ella Bingham and Heikki Mannila. Random projection in dimensionality reduction: applications to image and text data. In Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining, pages 245–250, 2001.
[6] Alon Brutzkus, Amir Globerson, Eran Malach, and Shai Shalev-Shwartz. SGD learns over-parameterized networks that provably generalize on linearly separable data. CoRR, abs/1710.10174, 2017.
[7] Pratik Chaudhari and Stefano Soatto. Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. CoRR, abs/1710.11029, 2017.
[8] Pratik Chaudhari and Stefano Soatto. Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. In International Conference on Learning Representations, 2018.
[9] Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 1675–1685, Long Beach, California, USA, 09–15 Jun 2019. PMLR.
[10] Simon S Du, Jason D Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. arXiv preprint arXiv:1811.03804, 2018.
[11] Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. arXiv preprint arXiv:1810.02054, 2018.
[12] Micah Goldblum, Jonas Geiping, Avi Schwarzschild, Michael Moeller, and Tom Goldstein. Truth or backpropaganda? An empirical investigation of deep learning theory, 2019.
[13] Soufiane Hayou, Arnaud Doucet, and Judith Rousseau. Mean-field behaviour of neural tangent kernel for deep neural networks. arXiv preprint arXiv:1905.13654, 2019.
[14] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in neural information processing systems, pages 8571–8580, 2018.
[15] Daniel Justus, John Brennan, Stephen Bonner, and Andrew Stephen McGough. Predicting the computational cost of deep learning models. CoRR, abs/1811.11880, 2018.
[16] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. CoRR, abs/1609.04836, 2016.
[17] Jonathan Krause, Michael Stark, Jia Deng, and Li Fei-Fei. 3D object representations for fine-grained categorization. Sydney, Australia, 2013.
[18] A. Krizhevsky. Learning multiple layers of features from tiny images. Master's thesis, Computer Science Department, University of Toronto, 2009.
[19] Jaehoon Lee, Yasaman Bahri, Roman Novak, Samuel S Schoenholz, Jeffrey Pennington, and Jascha Sohl-Dickstein. Deep neural networks as Gaussian processes. arXiv preprint arXiv:1711.00165, 2017.
[20] Jaehoon Lee, Lechao Xiao, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, and Jeffrey Pennington. Wide neural networks of any depth evolve as linear models under gradient descent, 2019.
[21] Hao Li, Pratik Chaudhari, Hao Yang, Michael Lam, Avinash Ravichandran, Rahul Bhotika, and Stefano Soatto. Rethinking the hyperparameters for fine-tuning, 2020.
[22] Yuanzhi Li and Yang Yuan. Convergence analysis of two-layer neural networks with ReLU activation. CoRR, abs/1705.09886, 2017.
[23] S. Maji, J. Kannala, E. Rahtu, M. Blaschko, and A. Vedaldi. Fine-grained visual classification of aircraft. Technical report, 2013.
[24] Fangzhou Mu, Yingyu Liang, and Yin Li. Gradients as features for deep representation learning. In International Conference on Learning Representations, 2020.
[25] Ariadna Quattoni and Antonio Torralba. Recognizing indoor scenes. In CVPR, pages 413–420. IEEE Computer Society, 2009.
[26] Andrew M Saxe, James L McClelland, and Surya Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013.
[27] J. Shawe-Taylor, C. K. I. Williams, N. Cristianini, and J. Kandola. On the eigenspectrum of the Gram matrix and the generalization error of kernel-PCA. IEEE Transactions on Information Theory, 51(7):2510–2522, 2005.
[28] Samuel L. Smith and Quoc V. Le. A Bayesian perspective on generalization and stochastic gradient descent. ArXiv, abs/1710.06451, 2018.
[29] P. Welinder, S. Branson, T. Mita, C. Wah, F. Schroff, S. Belongie, and P. Perona. Caltech-UCSD Birds 200. Technical Report CNS-TR-2010-001, California Institute of Technology, 2010.
[30] Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. OpenReview.net, 2017.
[31] Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Stochastic gradient descent optimizes over-parameterized deep ReLU networks. CoRR, abs/1811.08888, 2018.

Predicting Training Time Without Training: Supplementary Material
In the Supplementary Material we give the pseudo-code for the training-time prediction algorithm (Appendix A) together with implementation details, and show additional results (Appendix C), including the prediction of training time using only a subset of the samples and a comparison of real and predicted loss curves in a variety of conditions. Finally, we give proofs of all statements.
A Algorithm
Algorithm 1: Estimate the training time on a given target dataset and hyper-parameters.
Data: number of steps $T$ to simulate, threshold $\epsilon$ to determine convergence, pre-trained weights $w_0$ of the model, a target dataset with images $X = \{x_i\}_{i=1}^N$ and labels $Y = \{y_i\}_{i=1}^N$, batch size $B$, learning rate $\eta$, momentum $m \in [0, 1)$.
Result: an estimate $\hat{T}_\epsilon$ of the number of steps necessary to converge within $\epsilon$ of the final value, $T_\epsilon := \min\{t : |\mathcal{L}_t - \mathcal{L}_T| < \epsilon\}$.
Initialization: compute the initial network predictions $f_0(X)$, estimate $\Theta$ using random projections (Section 4), and compute the ELR $\tilde{\eta} = \eta/(1-m)$ to use in eq. (1) instead of $\eta$;
if $B = N$ then
    get $f^{lin}_t(X)$ by solving the ODE in eq. (1) (deterministic part only) for $T$ steps;
else
    get $f^{lin}_t(X)$ by solving the SDE in eq. (1) for $T$ steps (see the approximation in Section 4);
end if
using $f^{lin}_t(X)$ and $Y$, compute the linearized loss $\mathcal{L}^{lin}_t$ for all $t \in \{0, \dots, T\}$;
return $\hat{T}_\epsilon := \min\{t : |\mathcal{L}^{lin}_t - \mathcal{L}^{lin}_T| < \epsilon\}$;

We can also base the training-time estimate on the accuracy of the model: we modify the algorithm above in a straightforward way and use the predictions $f^{lin}_t(X)$ to compute the error instead of the loss (e.g., Figure 10).

We now briefly describe some implementation details regarding the numerical solution of the ODE and the SDE. Both can be solved with standard algorithms: in the ODE case we used LSODA (the default integrator in scipy.integrate.odeint), and in the SDE case we used the Euler-Maruyama algorithm for Itô equations.

We observe that removing batch normalization (i.e., preventing its statistics from being updated) and removing data augmentation improve the linearization approximation, both for GD and for SGD. Interestingly, data augmentation only marginally alters the spectrum of the Gram matrix $\Theta$ and has little impact on the linearization approximation compared to batch normalization. [12] observed similar effects but, differently from us, carried out their analysis on randomly initialized ResNets.

B Target datasets
Dataset                    Number of images  Classes  Mean samples per class  Imbalance factor
cifar10 [18]               50000             10       5000                    1
cifar100 [18]              50000             100      500                     1
cub200 [29]                5994              200      29.97                   1.03
fgvc-aircrafts [23]        6667              100      66.67                   1.02
mit67 [25]                 5360              67       80                      1.08
opensurfacesminc2500 [4]   48875             23       2125                    1.03
stanfordcars [17]          8144              196      41.6                    2.83

Table 2: Target datasets.
C Additional Experiments
Prediction of training time using a subset of samples.
In Section 4 we suggest that, in the case of MSE loss, it is possible to predict the training time on a large dataset using a smaller subset of samples (we discuss the details in Appendix D). In Figure 7 we show the result of predicting the loss curve on a dataset of $N = 4000$ samples using a subset of $N_0 = 1000$ samples. Similarly, in Figure 11 (left) we show the more difficult example of predicting the loss curve on $N = 1000$ samples using a very small subset of $N_0 = 100$ samples. In both cases we correctly predict that training on a larger dataset is slower; in particular, we correctly predict the asymptotic convergence phase. Note that in the case $N_0 = 100$ the prediction is less accurate; this is in part due to the eigenspectrum of $\Theta$ still being far from the limiting behavior it reaches for a large number of samples (see Appendix D).

Comparison of predicted and real error curve.
In Figure 8 we compare the error curve predicted by our method with the actual training error of the model as a function of the number of optimization steps. The model is trained on a subset of 2 classes of CIFAR-10 with 150 samples. We run the comparison for both gradient descent (left) and SGD (right), using learning rate $\eta = 0.…$, momentum $m = 0$ and (in the case of SGD) batch size 100. In both cases we observe that the predicted curve is reasonably close to the actual curve, more so at the beginning of training (which is expected, since the linear approximation is more likely to hold there). We also perform an ablation study to see the effect of different approximations of the SGD noise in the SDE in eq. (1). In Figure 8 (center) we estimate the variance of the SGD noise at the beginning of training, and then assume it is constant when solving the SDE. Notice that this predicts the wrong asymptotic behavior: in particular, the predicted error does not converge to zero as SGD does. In Figure 8 (right) we rescale the noise as we suggest in Section 4: once the noise is rescaled, the SDE is able to predict the right asymptotic behavior of SGD.

Prediction accuracy in weight space and function space.
In Section 3 and Section 6 we argue that using a differential equation to predict the dynamics in function space rather than in weight space is not only faster (in the over-parametrized case) but also more accurate. In Figure 9 we show empirically that solving the corresponding ODE in weight space leads to a substantially larger prediction error.
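To make the function-space computation concrete, the following is a minimal Python/NumPy sketch of the full-batch ($B = N$) branch of Algorithm 1. It rests on assumptions that are ours, not the text's: the MSE loss $\mathcal{L} = \frac{1}{2}\|f - y\|^2$, a precomputed symmetric Gram matrix $\Theta$, and time measured in optimization steps. Because the MSE case is a linear ODE, the sketch solves it in closed form through an eigendecomposition of $\Theta$ instead of calling LSODA; the function names and the random stand-in Jacobian are hypothetical. Note that it only touches the $CN \times CN$ matrix $\Theta$, never the $D$ weights.

```python
import numpy as np


def linearized_loss_curve(theta, f0, y, lr, momentum, num_steps):
    """Deterministic (full-batch) loss curve of the linearized model.

    Assumes the MSE loss L(f) = 0.5 * ||f - y||^2, so that the function-space
    gradient flow reads df/dt = -lr_eff * Theta @ (f - y), where
    lr_eff = lr / (1 - momentum) is the effective learning rate.
    Time t is measured in optimization steps.
    """
    lr_eff = lr / (1.0 - momentum)
    # Theta is a symmetric PSD Gram matrix: Theta = V diag(lam) V^T.
    lam, V = np.linalg.eigh(theta)
    lam = np.clip(lam, 0.0, None)           # drop tiny negative round-off
    p = V.T @ (f0 - y)                      # initial residual in the eigenbasis
    t = np.arange(num_steps + 1)[:, None]   # column of times 0..num_steps
    # Mode k of the residual decays as exp(-lr_eff * lam_k * t); the loss squares it.
    return 0.5 * np.sum(p ** 2 * np.exp(-2.0 * lr_eff * lam * t), axis=1)


def training_time(losses, eps):
    """First step t with |L_t - L_T| < eps, where L_T is the last simulated loss."""
    close = np.abs(losses - losses[-1]) < eps
    return int(np.argmax(close))            # index of the first True entry


# Usage with a random stand-in Jacobian (50 outputs, 200 parameters).
rng = np.random.default_rng(0)
J = rng.normal(size=(50, 200)) / np.sqrt(200)
theta = J @ J.T                             # Gram matrix of the output gradients
f0 = rng.normal(size=50)                    # initial predictions f_0(X)
y = np.zeros(50)                            # targets Y
curve = linearized_loss_curve(theta, f0, y, lr=0.01, momentum=0.9, num_steps=500)
print("estimated steps to converge:", training_time(curve, eps=1e-3))
```

In this sketch, replacing `lr=0.01, momentum=0.9` with `lr=0.1, momentum=0.0` leaves the curve unchanged, since both give the same effective learning rate.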
Effective learning rate.
In Section 3.1 we note that, as long as the effective learning rate $\tilde{\eta} = \eta/(1-m)$ remains constant, runs with different learning rates $\eta$ and momenta $m$ have similar learning curves. We give a formal derivation in Appendix E. In Figure 12 we show additional experiments, similar to Figure 2, on several other datasets that further confirm this point.

Point-wise similarity of predicted and observed loss curve.
In some cases, we observe that the predicted and observed loss curves can differ. This is especially the case when using the cross-entropy loss (Figure 10). We hypothesize that this may be due to an improper prediction of the dynamics when the softmax output saturates, as the dynamics become less linear [20]. However, the training-error curve (which only depends on the relative order of the outputs) remains relatively correct. We should also note that the prediction $\hat{T}_\epsilon$ of the $\epsilon$-training-time can be accurate even if the curves are not point-wise close. The $\epsilon$-training-time is the first time after which the loss or the error is within an $\epsilon$ threshold of its final value. Hence, as long as the real and predicted loss curves have a similar asymptotic slope, the prediction will be correct, as we indeed verify in Figure 10 (bottom).
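The claim that only the asymptotic slope matters can be checked on synthetic curves. The Python/NumPy sketch below uses made-up exponential curves (not data from our experiments) and a hypothetical helper `training_time` implementing $T_\epsilon = \min\{t : |\mathcal{L}_t - \mathcal{L}_T| < \epsilon\}$: two curves that differ by a constant offset give the same $\epsilon$-training-time, while a curve with a slower decay gives a larger one.

```python
import numpy as np


def training_time(losses, eps):
    """epsilon-training-time: first step t with |L_t - L_T| < eps (L_T = last value)."""
    losses = np.asarray(losses, dtype=float)
    return int(np.argmax(np.abs(losses - losses[-1]) < eps))


t = np.arange(1001)
# Hypothetical curves (not measurements): an "observed" curve, a "predicted" curve
# with the same decay rate but a different plateau, and a curve that decays slower.
observed = 2.0 * np.exp(-0.01 * t) + 0.30
predicted = 2.0 * np.exp(-0.01 * t) + 0.05
slower = 2.0 * np.exp(-0.005 * t) + 0.30

# observed and predicted share the same T_eps; slower needs more steps.
for name, curve in [("observed", observed), ("predicted", predicted), ("slower", slower)]:
    print(name, training_time(curve, eps=1e-2))
```

The offset cancels in $|\mathcal{L}_t - \mathcal{L}_T|$, which is why point-wise disagreement in the plateau does not affect the estimate.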
Figure 7:
Training-time prediction using a subset of the data. (Left)
Using the method described in Appendix D, we predict (green) the loss curve on a large dataset of $N = 4000$ samples (orange) using a subset of $N_0 = 1000$ samples (blue). In Figure 11 we show a similar result using a much smaller subset of $N_0 = 100$ samples. (Right) Corresponding estimated training time on the larger dataset at different thresholds $\epsilon$, compared to the real training time on the larger dataset.

Figure 8: (Left)
Comparison of the real error curve on CIFAR10 using gradient descent and the predicted curve. (Center)
Same as before, but this time we train using SGD and compare it with the prediction using the technique described in Section 4 to approximate the covariance of the SGD noise that appears in the SDE in eq. (1). (Right)
Same as (center), but using constant noise instead of rescaling the noise using the value of the loss function as described in Section 4. Note that in this case we do not capture the right asymptotic behavior of SGD.
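The constant-versus-rescaled ablation can be reproduced qualitatively with a short Euler-Maruyama simulation. This Python/NumPy sketch rests on hypothetical choices that are not the approximation of Section 4: MSE loss, function-space noise with covariance proportional to $\Theta$, and "rescaling" implemented as $\sigma_t = \sigma_0\sqrt{\mathcal{L}_t/\mathcal{L}_0}$; the function name and the random stand-in Jacobian are ours. With constant noise the simulated loss settles on a noise floor, whereas with loss-rescaled noise it keeps decaying toward zero, as SGD does.

```python
import numpy as np


def euler_maruyama_linearized(theta, f0, y, lr_eff, noise0, num_steps,
                              rescale=True, seed=0):
    """Euler-Maruyama integration (dt = 1 optimization step) of the toy SDE

        df = -lr_eff * Theta (f - y) dt + sigma_t * Theta^{1/2} dn,

    with sigma_t = noise0 (constant) or noise0 * sqrt(L_t / L_0) (rescaled).
    The noise model is a hypothetical stand-in, for illustration only.
    """
    rng = np.random.default_rng(seed)
    lam, V = np.linalg.eigh(theta)
    sqrt_theta = (V * np.sqrt(np.clip(lam, 0.0, None))) @ V.T   # Theta^{1/2}
    f = np.array(f0, dtype=float)
    loss0 = 0.5 * np.sum((f - y) ** 2)
    losses = [loss0]
    for _ in range(num_steps):
        sigma = noise0 * np.sqrt(losses[-1] / loss0) if rescale else noise0
        dn = rng.normal(size=f.shape)            # Brownian increment over dt = 1
        f = f - lr_eff * theta @ (f - y) + sigma * (sqrt_theta @ dn)
        losses.append(0.5 * np.sum((f - y) ** 2))
    return np.array(losses)


# Usage with a random stand-in Jacobian (50 outputs, 200 parameters).
rng = np.random.default_rng(1)
J = rng.normal(size=(50, 200)) / np.sqrt(200)
theta = J @ J.T
f0, y = rng.normal(size=50), np.zeros(50)
const = euler_maruyama_linearized(theta, f0, y, 0.1, 0.1, 2000, rescale=False)
scaled = euler_maruyama_linearized(theta, f0, y, 0.1, 0.1, 2000, rescale=True)
print(f"final loss, constant noise: {const[-1]:.3g}; rescaled noise: {scaled[-1]:.3g}")
```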
D Prediction of training time on larger datasets
In Section 4 we suggest that, in the case of MSE loss, it is possible to predict the training time on a large dataset using a subset of the samples. To do so, we leverage the fact that the eigenvalues of $\Theta$ follow a power law which, for large enough sizes, is independent of the size of the dataset (see Figure 4, right). More precisely, from Proposition 2 we know that, given the eigenvalues $\lambda_k$ of $\Theta$ and the projections $p_k = \delta y \cdot v_k$, it is possible to predict the loss curve (up to a constant factor that depends on the normalization of the MSE loss) using
$$\mathcal{L}_t = \sum_k p_k^2\, e^{-2\eta\lambda_k t}.$$
Let $\Theta'$ be the Gram matrix of the gradients computed on the small subset of $N_0$ samples, and let $\Theta$ be the Gram matrix of the whole dataset of size $N$. Using the fact that, as we increase the number of samples, the eigenvalues (once normalized by the dataset size) converge to a fixed limit (Figure 4, right), we estimate the eigenvalues $\lambda_k$ of $\Theta$ as follows: we fit the coefficients $s$ and $c$ of a power law $\lambda_k = c\,k^{-s}$ to the eigenvalues of $\Theta'$, and use the same coefficients to predict the eigenvalues of $\Theta$. However, we notice that the coefficient $s$ (the slope of the power law) estimated using a small subset of the data is often smaller than the slope observed on larger datasets (note in Figure 4 (right) that the curves for smaller datasets are flatter). We found that using the following corrected power law increases the precision of the prediction:
$$\hat{\lambda}_k = c\,k^{-s + \alpha\left(\frac{N_0}{N} - 1\right)}.$$
Empirically, we determined $\alpha \in [0.…, 0.…]$ to give a good fit over different combinations of $N_0$ and $N$. In Figure 11 (center) we compare the eigenspectrum of $\Theta$ predicted from $\Theta'$ with the actual eigenspectrum of $\Theta$. The projections $p_k$ follow a similar power law, albeit a noisier one (see Figure 11, right), so directly fitting the data may give an incorrect result. However, notice that in this case we can exploit an
Figure 9:
Comparison of prediction accuracy in weight space vs. function space.
We compare the result of using the deterministic part of eq. (1) to predict the weights $w_t$ at time $t$ and the outputs $f_t(X)$ of the network under GD. The relative error in predicting the outputs is much smaller than the relative error in predicting the weights at all times. This, together with the computational advantage, motivates our decision to use eq. (1) to predict the behavior in function space.
Figure 10:
Training time prediction is accurate even if loss curve prediction is not. (Top row)
Loss curve and error curve prediction on MIT-67 (left) and CIFAR-10 (right). (Bottom row)
Predicted time to reach a given threshold (orange) vs. real training time (blue). We note that on some datasets our loss-curve prediction differs from the real curve near convergence. However, since our training-time definition measures the time to reach the asymptotic value (which is what is most useful in practice) rather than the time to reach an absolute threshold, this does not affect the accuracy of the prediction (see Appendix C).
Figure 11:
Training time prediction using a subset of the data. (Left)
We predict the loss curve on a large dataset of $N = 1000$ samples using a subset of $N_0 = 100$ samples on CIFAR10 (similar results hold for the other datasets presented so far). (Center) Eigenspectrum of $\Theta$ computed using $N_0 = 100$ samples (orange) and $N = 1000$ samples (green), and spectrum predicted using our method (blue). (Right) Value of the projections $p_k$ of $\delta y$ on the eigenvectors of $\Theta$, computed at $N_0 = 100$ (orange) and $N = 1000$ (blue). Note that, while they approximately follow a power law on average, they are much noisier than the eigenvalues. In green we show the trend predicted using our method.

additional constraint, namely that $\sum_k p_k^2 = \|\delta y\|^2$ ($\|\delta y\|$ is a known quantity, since it only depends on the labels and on the initial model predictions on the large dataset). Let $p_k = \delta y \cdot v_k$ and $p'_k = \delta y \cdot v'_k$, where $v_k$ and $v'_k$ are the eigenvectors of $\Theta$ and $\Theta'$ respectively. Fix a small $k_0$ (in our experiments, $k_0 = 100$). By convergence laws [27], we have that $p'_k \simeq p_k$ when $k < k_0$. The remaining tail of $p_k$ for $k \geq k_0$ must then follow a power law and also be such that $\sum_k p_k^2 = \|\delta y\|^2$. This uniquely identifies the coefficients of the power law. Hence, we use the following prediction rule for $p_k$:
$$\hat{p}_k = \begin{cases} p'_k & \text{if } k < k_0, \\ a\,k^{-b} & \text{if } k \geq k_0, \end{cases}$$
where $a$ and $b$ are such that $\hat{p}_{k_0} = p'_{k_0}$ and $\sum_k \hat{p}_k^2 = \|\delta y\|^2$. In Figure 11 (left), we use the approximated $\hat{\lambda}_k$ and $\hat{p}_k$ to predict the loss curve on a dataset of $N = 1000$ samples using a smaller subset of $N_0 = 100$ samples. Notice that we correctly predict that the convergence is slower on the larger dataset.
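The two extrapolation rules above, for $\hat{\lambda}_k$ and $\hat{p}_k$, can be sketched in Python/NumPy as follows. The sketch assumes a single output ($C = 1$, so $\Theta$ has one eigenvalue per sample), eigenvalues sorted in decreasing order and normalized by the dataset size, a least-squares fit in log-log space for the power law, and bisection to solve the two conditions on $a$ and $b$; these choices and the function names are ours, and $\alpha$ is left as a parameter (the value used in the usage lines is an arbitrary placeholder).

```python
import numpy as np


def fit_power_law(values):
    """Least-squares fit of values_k ~ c * k^(-s) in log-log space, k = 1, 2, ..."""
    k = np.arange(1, len(values) + 1)
    slope, intercept = np.polyfit(np.log(k), np.log(values), 1)
    return np.exp(intercept), -slope                      # (c, s)


def extrapolate_eigenvalues(small_eigs, n_small, n_large, alpha):
    """Corrected power law: lambda_hat_k = c * k^(-s + alpha * (n_small / n_large - 1)).

    small_eigs: normalized eigenvalues of the subset Gram matrix, in decreasing order.
    Returns n_large predicted (normalized) eigenvalues for the full dataset.
    """
    c, s = fit_power_law(small_eigs)
    k = np.arange(1, n_large + 1)
    return c * k ** (-s + alpha * (n_small / n_large - 1.0))


def extrapolate_projections(p_small, k0, num_large, sq_norm, iters=200):
    """p_hat_k = p'_k for k < k0 and a * k^(-b) for k >= k0 (1-based k), where
    a, b enforce continuity at k0 and sum_k p_hat_k^2 = sq_norm (= ||delta y||^2)."""
    head = np.asarray(p_small[:k0 - 1], dtype=float)      # entries k = 1 .. k0-1
    anchor = abs(float(p_small[k0 - 1]))                   # |p'_{k0}|
    tail_k = np.arange(k0, num_large + 1)
    remaining = sq_norm - np.sum(head ** 2)               # mass left for the tail

    def tail_mass(b):
        # a * k^(-b) with a = anchor * k0^b, i.e. anchor * (k0 / k)^b
        return np.sum((anchor * (k0 / tail_k) ** b) ** 2)

    lo, hi = 0.0, 50.0                 # tail_mass decreases in b: bisect on [0, 50]
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if tail_mass(mid) > remaining:
            lo = mid                   # tail too heavy: decay faster
        else:
            hi = mid
    b = 0.5 * (lo + hi)
    return np.concatenate([head, anchor * (k0 / tail_k) ** b]), b


# Synthetic check: projections that exactly follow k^(-0.8) on the large dataset.
k = np.arange(1, 1001)
true_p = k ** -0.8
p_hat, b = extrapolate_projections(true_p[:200], k0=100, num_large=1000,
                                   sq_norm=np.sum(true_p ** 2))
print("fitted tail exponent:", round(float(b), 3))    # should recover 0.8

# Synthetic subset spectrum (N_0 = 100) extrapolated to N = 1000; alpha is arbitrary here.
small_eigs = 3.0 * np.arange(1, 101) ** -1.2
lam_hat = extrapolate_eigenvalues(small_eigs, n_small=100, n_large=1000, alpha=0.1)
```

The resulting $\hat{\lambda}_k$ and $\hat{p}_k$ can then be plugged into the loss decomposition above to obtain the predicted loss curve on the full dataset.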
Moreover, while the training loss on the smaller dataset quickly reaches zero, we correctly estimate the much slower asymptotic phase on the larger dataset. Increasing both $N_0$ and $N$ increases the accuracy of the estimate, since the eigenspectrum of $\Theta'$ is closer to convergence: in Figure 7 we show the same experiment as in Figure 11 with $N_0 = 1000$ and $N = 4000$. Note the increase in accuracy of the predicted curve.

Figure 12: Additional experiments on the effective learning rate.
We show additional plots of the error curves obtained on different datasets using different values of the effective learning rate $\tilde{\eta} = \eta/(1-m)$, where $\eta$ is the learning rate and $m$ is the momentum. Each line is the observed error curve of a model trained with a different learning rate $\eta$ and momentum $m$. Lines with the same color have the same ELR $\tilde{\eta}$, but each has a different $\eta$ and $m$. As we note in Section 3.1, as long as $\tilde{\eta}$ remains the same, training dynamics with different hyper-parameters have similar error curves.

E Effective learning rate
We now show that a momentum term has the effect of increasing the effective learning rate in the deterministic part of eq. (1). A similar treatment of the momentum term can also be found in [28, Appendix D]. Consider the update rule of SGD with momentum:
$$a_{t+1} = m\,a_t + g_{t+1}, \qquad w_{t+1} = w_t - \eta\,a_{t+1}.$$
If $\eta$ is small, the weights $w_t$ change slowly and we can consider $g_t$ to be approximately constant over short time periods, that is, $g_{t+1} = g$. Under these assumptions, the gradient accumulator $a_t$ satisfies the recursive equation $a_{t+1} = m\,a_t + g$, which (assuming $a_0 = 0$, as is common in most implementations) is solved by
$$a_t = \frac{1 - m^t}{1 - m}\,g.$$
In particular, $a_t$ converges exponentially fast to the asymptotic value $a^* = g/(1-m)$. Substituting this asymptotic value in the weight update equation above gives
$$w_{t+1} = w_t - \eta\,a^* = w_t - \frac{\eta}{1-m}\,g = w_t - \tilde{\eta}\,g,$$
that is, once $a_t$ reaches its asymptotic value, the weights are updated with a higher effective learning rate $\tilde{\eta} = \eta/(1-m)$. Note that this approximation remains valid as long as the gradient $g_t$ does not change much in the time it takes $a_t$ to reach its asymptotic value. This happens whenever the momentum $m$ is small (since $a_t$ converges faster) or when $\eta$ is small (since $g_t$ changes more slowly). For larger momentum and learning rate, the effective learning rate may not properly capture the effect of momentum.

F Proof of theorems
F.1 Proposition 1: SDE in function space for linearized networks trained with SGD
We now prove Proposition 1 and show how we can approximate the SGD evolution in function space rather than in parameter space. We follow the standard method used in [13] to derive a general SDE for a DNN, and then specialize it to the case of linearized deep networks. Our notation follows [20]: we define $f_{\theta_t}(X) = \mathrm{vec}([f_t(x)]_{x \in X}) \in \mathbb{R}^{CN}$ as the stacked vector of model output logits for all examples, where $C$ is the number of classes and $N$ the number of samples in the training set.

To describe the SGD dynamics in function space, we start by deriving the SDE in parameter space. To do so, we first write the discrete update of SGD, as done in [13]:
$$\theta_{t+1} = \theta_t - \eta\,\nabla_\theta \mathcal{L}_B(\theta_t) \tag{6}$$
where $\mathcal{L}_B(\theta_t) = \mathcal{L}(f_{\theta_t}(X_B), Y_B)$ is the average loss on a mini-batch $B$ (for simplicity, we assume that $B$ is a set of indices sampled with replacement).

The mini-batch gradient $\nabla_\theta \mathcal{L}_B(\theta_t)$ is an unbiased estimator of the full gradient; in particular, the following holds:
$$\mathbb{E}[\nabla_\theta \mathcal{L}_B(\theta_t)] = \nabla_\theta \mathcal{L}(\theta_t), \qquad \mathrm{cov}[\nabla_\theta \mathcal{L}_B(\theta_t)] = \frac{\Sigma(\theta_t)}{|B|} \tag{7}$$
where we define the covariance of the gradients as
$$\Sigma(\theta_t) := \mathbb{E}\big[(g_i \nabla_{f_t(x_i)}\mathcal{L}) \otimes (g_i \nabla_{f_t(x_i)}\mathcal{L})\big] - \mathbb{E}\big[g_i \nabla_{f_t(x_i)}\mathcal{L}\big] \otimes \mathbb{E}\big[g_i \nabla_{f_t(x_i)}\mathcal{L}\big]$$
and $g_i := \nabla_w f(x_i)$. The first term in the covariance is the second-order moment matrix, while the second term is the outer product of the average gradient.

Following standard approximation arguments (see [7] and references therein), in the limit of a small learning rate $\eta$ we can approximate the discrete stochastic equation (6) with the SDE
$$d\theta_t = -\eta\,\nabla_\theta \mathcal{L}(\theta_t)\,dt + \frac{\eta}{\sqrt{|B|}}\,\Sigma(\theta_t)^{1/2}\,dn \tag{8}$$
where $n(t)$ is a Brownian motion.

Given this result, we now describe how to derive the SDE for the output $f_t(X)$ of the network on the training set $X$.
Using Itô's lemma (see [13] and references therein), given a random variable $\theta$ that evolves according to an SDE, we can obtain a corresponding SDE that describes the evolution of a function of $\theta$. Applying the lemma to $f_\theta(X)$, we obtain
$$df_t(X) = \Big[-\eta\,\Theta_t \nabla_{f_t}\mathcal{L}(f_t(X), Y) + \tfrac{1}{2}\,\mathrm{vec}(A)\Big]dt + \frac{\eta}{\sqrt{|B|}}\,\nabla_\theta f(X)\,\Sigma(\theta_t)^{1/2}\,dn \tag{9}$$
where $\nabla_\theta f(X) \in \mathbb{R}^{CN \times D}$ is the Jacobian matrix and $D$ is the number of parameters. Note that $A$ is an $N \times C$ matrix which, denoting by $f^{(j)}_\theta(x)$ the $j$-th output of the model on a sample $x$, is given by
$$A_{ij} = \mathrm{tr}\big[\Sigma(\theta_t)\,\nabla^2_\theta f^{(j)}_\theta(x_i)\big].$$
Since in our case the model is linearized, $f_\theta(x)$ is a linear function of $\theta$; hence $\nabla^2_\theta f^{(j)}(x) = 0$ and $A = 0$. This leaves us with the SDE
$$df_t(X) = -\eta\,\Theta_t \nabla_{f_t}\mathcal{L}\,dt + \frac{\eta}{\sqrt{|B|}}\,\nabla_\theta f(X)\,\Sigma(\theta_t)^{1/2}\,dn \tag{10}$$
as we wanted.

F.2 Proposition 2: Loss decomposition