How Neural Networks Extrapolate: From Feedforward to Graph Neural Networks
Keyulu Xu†, Mozhi Zhang‡, Jingling Li§, Simon S. Du¶, Ken-ichi Kawarabayashi||, Stefanie Jegelka**

Abstract

We study how neural networks trained by gradient descent extrapolate, i.e., what they learn outside the support of the training distribution. Previous works report mixed empirical results when extrapolating with neural networks: while multilayer perceptrons (MLPs) do not extrapolate well in certain simple tasks, Graph Neural Networks (GNNs), structured networks with MLP modules, have shown some success in more complex tasks. Working towards a theoretical explanation, we identify conditions under which MLPs and GNNs extrapolate well. First, we quantify the observation that ReLU MLPs quickly converge to linear functions along any direction from the origin, which implies that ReLU MLPs do not extrapolate most non-linear functions. But, they provably learn a linear target function when the training distribution is sufficiently "diverse". Second, in connection to analyzing the successes and limitations of GNNs, these results suggest a hypothesis for which we provide theoretical and empirical evidence: the success of GNNs in extrapolating algorithmic tasks to new data (e.g., larger graphs or edge weights) relies on encoding task-specific non-linearities in the architecture or features. Our theoretical analysis builds on a connection of overparameterized networks to the neural tangent kernel. Empirically, our theory holds across different training settings.

† Massachusetts Institute of Technology. Email: [email protected]
‡ University of Maryland. Email: [email protected]
§ University of Maryland. Email: [email protected]
¶ University of Washington. Email: [email protected]
|| National Institute of Informatics. Email: [email protected]
** Massachusetts Institute of Technology. Email: [email protected]

Contents

B.1 Proof of Theorem 3 . . . . . . 18
B.2 Proof of Lemma 4 . . . . . . 21
B.3 Proof of Theorem 5 . . . . . . 26
B.4 Proof of Corollary 8 . . . . . . 29
B.5 Proof of Theorem 9 . . . . . . 29
B.6 Proof of Lemma 10 . . . . . . 32
B.7 Proof of Lemma 11 . . . . . . 33
C Experimental Details 34
C.1 Learning Simple Non-Linear Functions . . . . . . 36
C.2 R-squared for Out-of-distribution Directions . . . . . . 38
C.3 Learning Linear Functions . . . . . . 38
C.4 MLPs with cosine, quadratic, and tanh Activation . . . . . . 39
C.5 Max Degree . . . . . . 40
C.6 Shortest Path . . . . . . 41
C.7 N-Body Problem . . . . . . 42
D Visualization and Additional Experimental Results 45
D.1 Visualization Results . . . . . . 45
D.2 Experimental Results . . . . . . 53
1 Introduction
Humans extrapolate well in many tasks. For example, we can apply arithmetic to arbitrarily large numbers. One may wonder whether a neural network can do the same and generalize to examples arbitrarily far from the training data [Santoro et al., 2018]. Curiously, previous works report mixed extrapolation results with neural networks. Early works demonstrate that feedforward neural networks, a.k.a. multilayer perceptrons (MLPs), fail to extrapolate well when learning simple polynomial functions [Barnard and Wessels, 1992, Haley and Soloway, 1992]. However, recent works show Graph Neural Networks (GNNs) [Scarselli et al., 2009], a class of structured networks with MLP building blocks, can generalize to graphs much larger than training graphs in challenging algorithmic tasks, such as predicting the time evolution of physical systems [Battaglia et al., 2016], learning graph algorithms [Velickovic et al., 2020], and solving mathematical equations [Lample and Charton, 2020].

To explain this puzzle, we formally study how neural networks trained by gradient descent (GD) extrapolate, i.e., what they learn outside the support of the training distribution. We say a neural network extrapolates well if it learns a task outside the training distribution. At first glance, it may seem that neural networks can behave arbitrarily outside the training distribution, since they have high capacity [Zhang et al., 2017] and are universal approximators [Cybenko, 1989, Funahashi, 1989, Hornik et al., 1989, Kurková, 1992]. However, neural networks are constrained by gradient descent training [Hardt et al., 2016, Soudry et al., 2018]. In our analysis, we explicitly consider such implicit bias through the analogy of the training dynamics of over-parameterized neural networks and kernel regression via the neural tangent kernel (NTK) [Jacot et al., 2018].

We begin with MLPs, the simplest neural networks and building blocks of more complex architectures such as GNNs. First, we show that the predictions of over-parameterized MLPs with ReLU activation trained by GD converge to linear functions along any direction from the origin. We prove a convergence rate for two-layer networks (Theorem 3) and empirically observe that convergence often occurs close to (but outside) the training data (Fig. 1), which suggests ReLU MLPs cannot extrapolate well for most non-linear tasks. We emphasize that our results do not follow from the fact that ReLU MLPs have finitely many linear regions [Arora et al., 2018, Hanin and Rolnick, 2019, Hein et al., 2019]. While having finitely many linear regions implies ReLU MLPs eventually become linear, it does not say whether MLPs will learn the correct target function close to the training distribution. In contrast, our results are non-asymptotic and quantify what kind of functions MLPs will learn close to the training distribution. Second, we identify a condition when MLPs extrapolate well: the task is linear and the geometry of the training distribution satisfies a condition (Theorem 5). To our knowledge, our results are the first extrapolation results of this kind for feedforward neural networks.

Next, we relate our insights into MLPs to GNNs, to explain why GNNs extrapolate well in some algorithmic tasks. Prior works report successful extrapolation for tasks that can be solved by dynamic programming (DP) [Bellman, 1966], which has a computation structure similar to GNNs' [Xu et al., 2020]. The DP updates can be decomposed into non-linear and linear steps.
Hence, we hypothesize that GNNs trained by GD can extrapolate well in a DP task if we encode appropriate non-linearity in the architecture and input representation (Fig. 2). Importantly, encoding non-linearity may be unnecessary for GNNs to interpolate, because the MLP modules can easily learn many non-linear functions inside the training distribution [Cybenko, 1989, Hornik et al., 1989, Xu et al., 2020], but encoding non-linearity is crucial for GNNs to extrapolate correctly. We prove this hypothesis for a simplified case using
Graph NTK [Du et al., 2019b]. Empirically, we validate the hypothesis on three DP tasks: max degree, shortest paths, and the n-body problem. We show that GNNs with appropriate architecture, input representation, and training distribution can predict well on graphs with unseen sizes, structures, edge weights, and node features. Our theory explains the empirical success in previous works and suggests their limitations: successful extrapolation relies on encoding task-specific non-linearity, which requires domain knowledge or extensive model search. In Section 5, we also discuss relations of our results to other out-of-distribution settings.

Figure 1: How ReLU MLPs extrapolate. We train MLPs to learn non-linear functions (grey) and plot their predictions both within (blue) and outside (black) the training distribution. MLPs converge quickly to linear functions outside the training data range along directions from the origin (Theorem 3). Hence, MLPs do not extrapolate well in most non-linear tasks. But, with appropriate training data, MLPs can extrapolate globally linear target functions well (Theorem 5).

Figure 2: How GNNs extrapolate. (a) Network architecture; (b) input representation. Since MLPs can extrapolate well when learning linear functions, we hypothesize that GNNs can extrapolate well in dynamic programming (DP) tasks if we encode appropriate non-linearity in the architecture (left) and/or input representation (right; through domain knowledge or representation learning). The encoded non-linearities may not be necessary for interpolation, as they can be approximated by MLP modules, but they help extrapolation. We support the hypothesis theoretically (Theorem 9) and empirically (Fig. 6).

In summary, we study how MLPs and GNNs extrapolate. First, ReLU MLPs trained by GD converge to linear functions along directions from the origin with a rate of $O(1/\epsilon)$. Second, to explain why GNNs extrapolate well in some algorithmic tasks, we prove that ReLU MLPs can extrapolate well in linear tasks, leading to a hypothesis: GNNs can extrapolate well when appropriate non-linearity is encoded into the architecture and features. We prove this hypothesis for a simplified case and provide empirical support for more general settings. All our claims are supported with experiments.

Early works show example tasks where MLPs do not extrapolate well, e.g., learning simple polynomials [Barnard and Wessels, 1992, Haley and Soloway, 1992]. We instead show a general pattern of how ReLU MLPs extrapolate and identify conditions for MLPs to extrapolate well. More recent works study the implicit biases induced on MLPs by gradient descent, for both the "NTK" and "adaptive" regimes [Bietti and Mairal, 2019, Chizat and Bach, 2018, Li et al., 2019, Song et al., 2018]. Related to our results, some works show MLP predictions converge to "simple" piecewise linear functions, e.g., with few linear regions [Hanin and Rolnick, 2019, Maennel et al., 2018, Savarese et al., 2019, Williams et al., 2019]. Our work differs in that none of these works explicitly studies extrapolation, and some focus only on one-dimensional inputs. Recent works also show that in high-dimensional settings of the NTK regime, an MLP is asymptotically at most a linear predictor in certain scaling limits [Ba et al., 2020, Ghorbani et al., 2019]. We study a different setting (extrapolation), and our analysis is non-asymptotic in nature and does not rely on random matrix theory.

Prior works explore GNN extrapolation by testing on larger graphs [Battaglia et al., 2018, Santoro et al., 2018, Saxton et al., 2019, Velickovic et al., 2020]. We are the first to theoretically study GNN extrapolation, and we complete the notion of extrapolation to include unseen features and structures.
2 Preliminaries

We begin by introducing our setting. Let $\mathcal{X}$ be the domain of interest, here, vectors or graphs. The task is to learn an underlying function $g: \mathcal{X} \to \mathbb{R}$ with a training set $\{(x_i, y_i)\}_{i=1}^n \subset \mathcal{D}$, where $y_i = g(x_i)$ and $\mathcal{D}$ is the support of the training distribution.

Previous works have extensively studied in-distribution generalization, where the training and the test distributions are identical [Valiant, 1984, Vapnik, 2013]; i.e., $\mathcal{D} = \mathcal{X}$. In contrast, extrapolation addresses predictions on a domain $\mathcal{X}$ that is larger than the support of the training distribution $\mathcal{D}$. We will say that a model extrapolates well if it has small extrapolation error, the maximum test error outside the training support $\mathcal{D}$.

Definition 1. (Extrapolation error). Suppose $f: \mathcal{X} \to \mathbb{R}$ is a model trained on $\{(x_i, y_i)\}_{i=1}^n \subset \mathcal{D}$. We define the extrapolation error of $f$ on $\mathcal{X}$ as $\|f - g\|_{\infty, \mathcal{X} \setminus \mathcal{D}} = \sup\{|f(x) - g(x)| : x \in \mathcal{X} \setminus \mathcal{D}\}$.

We focus on neural networks trained by gradient descent (GD) or its variants with mean squared loss. We study two neural network architectures: MLPs and GNNs.

Graph Neural Networks.
GNNs are structured networks operating on graphs with MLP modules [Battaglia et al., 2018, Xu et al., 2019]. The input is a graph $G = (V, E)$. Each node $u \in V$ has a feature vector $x_u$, and each edge $(u, v) \in E$ has a feature vector $w_{(u,v)}$. GNNs iteratively compute a representation for each node. Initially, the node representations are the node features: $h_u^{(0)} = x_u$. In iteration $k = 1..K$, a GNN updates the node representations $h_u^{(k)}$ by aggregating the neighboring nodes' representations with MLP modules [Gilmer et al., 2017, Xu et al., 2018]. We can optionally compute a graph representation $h_G$ by aggregating the final node representations with another MLP. Formally,
$$h_u^{(k)} = \sum_{v \in \mathcal{N}(u)} \mathrm{MLP}^{(k)}\left(h_u^{(k-1)}, h_v^{(k-1)}, w_{(v,u)}\right), \qquad h_G = \mathrm{MLP}^{(K+1)}\left(\sum_{u \in G} h_u^{(K)}\right). \quad (1)$$
The final output is the graph representation $h_G$ or the final node representations $h_u^{(K)}$, depending on the task. We refer to the neighbor aggregation step for $h_u^{(k)}$ as aggregation and the pooling step in $h_G$ as readout. Previous works typically use sum-aggregation and sum-readout [Battaglia et al., 2018]. Our results indicate why replacing them may help extrapolation (Section 4). A minimal code sketch of this update follows.
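The following is a minimal sketch of the sum-aggregation and sum-readout in equation (1), assuming small toy dimensions; the module names and shapes here are illustrative assumptions, not the implementation used in our experiments.

```python
# A minimal sketch of the sum-aggregation GNN in equation (1).
import torch
import torch.nn as nn

class GNNLayer(nn.Module):
    def __init__(self, hidden_dim: int, edge_dim: int):
        super().__init__()
        # MLP^(k) takes (h_u, h_v, w_(v,u)) concatenated.
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, h, edge_index, edge_attr):
        # h: [num_nodes, hidden_dim]; edge_index: [2, num_edges] with rows (v, u);
        # edge_attr: [num_edges, edge_dim].
        v, u = edge_index
        messages = self.mlp(torch.cat([h[u], h[v], edge_attr], dim=-1))
        # Sum-aggregation: h_u^(k) = sum over neighbors v of MLP(...).
        out = torch.zeros_like(h)
        out.index_add_(0, u, messages)
        return out

def sum_readout(h):
    # Sum-readout: h_G = MLP^(K+1)(sum_u h_u^(K)); the outer MLP is omitted here.
    return h.sum(dim=0)

# Toy usage: a graph with 3 nodes and 2 directed edges.
h = torch.randn(3, 8)
edge_index = torch.tensor([[0, 2], [1, 1]])  # edges (0 -> 1) and (2 -> 1)
edge_attr = torch.randn(2, 4)
layer = GNNLayer(hidden_dim=8, edge_dim=4)
print(sum_readout(layer(h, edge_index, edge_attr)).shape)  # torch.Size([8])
```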
3 How Feedforward Neural Networks Extrapolate

MLPs are the simplest neural networks and the building blocks of more complex networks such as GNNs, so we first study how MLPs trained by GD extrapolate. In this paper, we assume that MLPs have ReLU activation functions. Section 3.3 contains preliminary results for other activations.

3.1 Linear Extrapolation Behavior of ReLU MLPs

By architecture, ReLU networks learn piecewise linear functions, but what do these regions look like outside the support of the training data? Fig. 1 illustrates examples of how ReLU networks extrapolate when trained on various nonlinear functions. These examples suggest that outside the training support, the predictions quickly become linear along directions from the origin. We empirically verify this pattern via linear regression on MLPs' predictions (Appendix C.2). Outside the training data range, along any direction from the origin, the coefficient of determination ($R^2$) is always greater than 0.99; i.e., MLPs "linearize" almost immediately outside the training data range.

Figure 3: Conditions for MLPs to extrapolate well when learning linear target functions. We train MLPs to learn 2D linear functions (grey) with different training distributions (blue) and plot out-of-distribution predictions (black). Following Theorem 5, MLPs extrapolate well when the training distribution (blue) has support in all directions (first panel), but not otherwise: in the two middle panels, some dimensions of the training data are constrained to be positive (red arrows); in the last panel, one dimension of the training data is a fixed constant.

Figure 4: Distribution of mean absolute percentage error (MAPE) for extrapolation. We train ReLU MLPs with various hyperparameters (depth, width, learning rate, batch size) and compute MAPE on test examples (Appendix C). We plot distributions of test errors outside the training support, from many trials with different training/test distributions and hyperparameters. (a) Extrapolation for learning different target functions; (b) different training distributions for learning linear target functions: "all" covers all directions, "fix1" has one dimension fixed to a constant, and "neg-$d$" has $d$ dimensions constrained to negative values. Results align with our theory: MLPs generally do not extrapolate well, unless the target function is linear along each direction (Fig. 4a). For linear target functions, MLPs extrapolate well if the training distribution covers all directions (Fig. 4b and 3).

We formalize this observation using the implicit biases of neural networks trained by GD via the neural tangent kernel (NTK): optimization trajectories of overparameterized networks trained by GD are equivalent to those of kernel regression with a specific neural tangent kernel, under a set of assumptions called the "NTK regime" [Jacot et al., 2018]. We provide an informal definition here; for further details, we refer the readers to Jacot et al. [2018] and Appendix A.

Definition 2. (Informal) A neural network trained in the NTK regime is infinitely wide, randomly initialized with certain scaling, and trained by GD with infinitesimally small steps and squared loss.

Previous works analyze optimization and in-distribution generalization of overparameterized neural networks with the NTK [Allen-Zhu et al., 2019a,b, Arora et al., 2019a,b, Cao and Gu, 2019, Du et al., 2019c,a, Jacot et al., 2018, Lee et al., 2019, Li and Liang, 2018]. We instead analyze extrapolation.

Theorem 3 formalizes our observation from Fig. 1: outside the training data range, along any direction $tv$ from the origin, the prediction of a two-layer ReLU MLP quickly converges to a linear function with rate $O(1/t)$. The linear coefficients $\beta_v$ and the constant terms in the convergence rate depend on the training data and direction $v$. The proof is in Appendix B.1.
Theorem 3. Suppose we train a two-layer ReLU MLP $f: \mathbb{R}^d \to \mathbb{R}$ with squared loss in the NTK regime. For any direction $v \in \mathbb{R}^d$, let $x_0 = tv$. As $t \to \infty$, $f(x_0 + hv) - f(x_0) \to \beta_v \cdot h$ for any $h > 0$, where $\beta_v$ is a constant linear coefficient. Moreover, given $\epsilon > 0$, for $t = O\!\left(\frac{1}{\epsilon}\right)$, we have $\left| \frac{f(x_0 + hv) - f(x_0)}{h} - \beta_v \right| < \epsilon$.
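The following is a small empirical check in the spirit of Theorem 3 and the $R^2$ experiments of Appendix C.2; the architecture, optimizer, and data ranges are assumptions for illustration, not the exact settings of our experiments.

```python
# Train a ReLU MLP on x^2 over [-1, 1], then fit a linear regression to its
# predictions outside the training range; R^2 near 1 means the MLP linearizes.
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
x_train = torch.linspace(-1, 1, 256).unsqueeze(1)
y_train = x_train ** 2  # a non-linear target

mlp = nn.Sequential(nn.Linear(1, 256), nn.ReLU(), nn.Linear(256, 1))
opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
for _ in range(2000):
    opt.zero_grad()
    loss = ((mlp(x_train) - y_train) ** 2).mean()
    loss.backward()
    opt.step()

# Predictions along the positive direction, outside the training range.
x_out = torch.linspace(2.0, 10.0, 200).unsqueeze(1)
with torch.no_grad():
    y_out = mlp(x_out).squeeze(1).numpy()

# R^2 of a linear fit: close to 1 means the predictions are (almost) linear.
X = np.stack([x_out.squeeze(1).numpy(), np.ones(200)], axis=1)
coef, _, _, _ = np.linalg.lstsq(X, y_out, rcond=None)
residual = y_out - X @ coef
r2 = 1 - residual.var() / y_out.var()
print(f"slope={coef[0]:.3f}, R^2={r2:.5f}")  # R^2 typically > 0.99
```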
Previous works show that ReLU MLPs have finitely many linear regions [Arora et al., 2018, Hanin and Rolnick, 2019], which implies that their predictions eventually become linear. In contrast, Theorem 3 is a more fine-grained analysis of how MLPs extrapolate, and it provides a convergence rate. While Theorem 3 assumes two-layer networks in the NTK regime, experiments confirm that the linear extrapolation behavior happens across networks with different depths, widths, learning rates, and batch sizes (Appendix C.1 and C.2). Our proof technique potentially also extends to deeper networks.

Theorem 3 also suggests which target functions a ReLU MLP may be able to match outside the training data: only functions that are almost-linear along the directions away from the origin. Indeed, the results in Fig. 4a (details in Appendix C.1) show that, outside the training data, the predictions do not match target functions such as $x^\top A x$ (quadratic), $\sum_{i=1}^d \cos(2\pi \cdot x^{(i)})$ (cos), and $\sum_{i=1}^d \sqrt{x^{(i)}}$ (sqrt), where $x^{(i)}$ is the $i$-th dimension of the input vector $x$. In contrast, with suitable hyperparameters, MLPs extrapolate the L1 norm correctly (Fig. 4a), which satisfies the directional linearity condition.

Fig. 4a provides one more positive result: MLPs extrapolate linear target functions well, across many different hyperparameters. While learning linear functions may seem very limited at first, in Section 4 this insight will help explain extrapolation properties of GNNs in non-linear practical tasks. Before that, we first theoretically analyze when MLPs extrapolate well.

3.2 When ReLU MLPs Provably Extrapolate Well

Fig. 4a shows that MLPs can extrapolate well when the target function is linear. However, this is not always true. In this section, we show that successful extrapolation depends on the geometry of the training data. Intuitively, the training distribution must be "diverse" enough for correct extrapolation.

We provide two conditions that relate the geometry of the training data to extrapolation. Lemma 4 states that overparameterized MLPs can learn a linear target function with only $2d$ examples.

Lemma 4.
Suppose the target function is $g(x) = \beta^\top x$ for some $\beta \in \mathbb{R}^d$. Suppose the training set $\{x_i\}_{i=1}^n$ contains an orthogonal basis $\{\hat{x}_i\}_{i=1}^d$ and its opposite vectors $\{-\hat{x}_i\}_{i=1}^d$. If we train a two-layer ReLU MLP $f$ on $\{(x_i, y_i)\}_{i=1}^n$ with squared loss in the NTK regime, then $f(x) = \beta^\top x$ for all $x \in \mathbb{R}^d$.

Lemma 4 is mainly of theoretical interest, as the $2d$ examples need to be carefully chosen. Theorem 5 builds on Lemma 4 and identifies a more practical condition for successful extrapolation: if the support of the training distribution covers all directions (e.g., a hypercube that covers the origin), MLPs in the NTK regime converge to a linear target function with sufficient training data.

Theorem 5.
Suppose the target function is $g(x) = \beta^\top x$ for some $\beta \in \mathbb{R}^d$. Suppose the training data $\{x_i\}_{i=1}^n$ is sampled from a distribution whose support $\mathcal{D}$ contains a connected subset $S$, where for any non-zero $w \in \mathbb{R}^d$, there exists $k > 0$ so that $kw \in S$. If we train a two-layer ReLU MLP $f: \mathbb{R}^d \to \mathbb{R}$ on $\{(x_i, y_i)\}_{i=1}^n$ with squared loss in the NTK regime, then $f(x) \xrightarrow{p} \beta^\top x$ as $n \to \infty$.
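The following is a toy illustration of the data-geometry condition in Theorem 5 (in the spirit of Lemma 4): the same ReLU MLP is trained on a linear target, once with support covering all directions (a hypercube around the origin) and once restricted to one orthant. All sizes and training settings are assumptions for illustration.

```python
# Compare extrapolation error under "all directions" vs. restricted support.
import torch
import torch.nn as nn

def fit_and_test(x_train, beta, seed=0):
    torch.manual_seed(seed)
    y_train = x_train @ beta  # linear target
    mlp = nn.Sequential(nn.Linear(2, 256), nn.ReLU(), nn.Linear(256, 1))
    opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    for _ in range(3000):
        opt.zero_grad()
        ((mlp(x_train).squeeze(1) - y_train) ** 2).mean().backward()
        opt.step()
    x_test = 10 * torch.randn(1000, 2)  # far outside the training support
    with torch.no_grad():
        err = (mlp(x_test).squeeze(1) - x_test @ beta).abs().mean().item()
    return err

beta = torch.tensor([1.0, -2.0])
all_dirs = 2 * torch.rand(512, 2) - 1  # hypercube covering all directions ("all")
restricted = torch.rand(512, 2)        # positive orthant only (restricted)
print("all directions:", fit_and_test(all_dirs, beta))
print("restricted:    ", fit_and_test(restricted, beta))
```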
Experiments: geometry of training data affects extrapolation.
The condition in Theorem 5 formalizes the intuition that the training distribution must be "diverse" for successful extrapolation, i.e., $\mathcal{D}$ must include all directions. Empirically, the extrapolation error is indeed small when the condition of Theorem 5 is satisfied ("all" in Fig. 4b). In contrast, the extrapolation error is much larger when the training examples are restricted to only some directions (Fig. 4b and Fig. 3).

Theorem 5 also suggests why spurious correlations hurt extrapolation, complementing the causality arguments from previous works [Arjovsky et al., 2019, Peters et al., 2016, Rojas-Carulla et al., 2018]. When the training data has spurious correlations, some combinations of features are missing; e.g., camels might only appear in deserts in an image collection. Therefore, the condition for Theorem 5 no longer holds, and the model may extrapolate incorrectly.

Theorem 5 is analogous to an identifiability condition for linear models, but stricter. We can uniquely identify a linear function if the training data has full (feature) rank. MLPs are more expressive, so identifying the linear target function requires additional constraints.

In summary, we analyze how MLPs extrapolate and provide two insights: (1) MLPs cannot extrapolate most non-linear tasks, because they quickly converge to directionally linear functions (Theorem 3); and (2) MLPs can extrapolate well when the target function is linear, provided the training distribution is "diverse" (Theorem 5). In the next section, these results will help us understand how more complex networks extrapolate, specifically, GNNs for non-linear algorithmic tasks.

3.3 MLPs with Other Activation Functions

Before moving on to GNNs, we complete the picture of MLPs with experiments on other activation functions: tanh $\sigma(x) = \tanh(x)$, cosine $\sigma(x) = \cos(x)$ [Lapedes and Farber, 1987, McCaughan, 1997, Sopena and Alquezar, 1994], and quadratic $\sigma(x) = x^2$ [Du and Lee, 2018, Livni et al., 2014]. Details are in Appendix C.4. MLPs extrapolate well when the activation and target function are similar; e.g., tanh activation extrapolates well when learning tanh, but not other functions (Figure 5). Moreover, each activation function has different limitations. When learning tanh with tanh activations, the training data range has to be sufficiently wide for successful extrapolation. When learning a quadratic function with quadratic activations, only two-layer networks extrapolate well; more layers lead to higher-order polynomials. Cosine activations are hard to optimize and cannot fit high-dimensional cosine functions well (even on the training set), so we only use one- or two-dimensional cosine target functions in Figure 5. We leave a theoretical analysis to future work.

Figure 5: Distribution of MAPE for other activations. (a) tanh activation; (b) cosine activation; (c) quadratic activation. MLPs can extrapolate well when the activation is similar to the target function. When learning a quadratic function with quadratic activations, two-layer models ("quad-2") extrapolate well, but four-layer models ("quad-4") do not. Details are in Appendix C.4.
Figure 6: Extrapolation for algorithmic tasks. Each column indicates the task and the mean absolute percentage error (MAPE). Encoding appropriate non-linearity in the architecture or representation is less helpful for interpolation, but significantly improves extrapolation. Left: In max degree and shortest path, GNNs that appropriately encode max/min extrapolate well, but GNNs with sum-pooling do not. Right: With improved input representation, GNNs extrapolate better for the n-body problem.

(a) Importance of architecture (MAPE):
                     max degree                 shortest path
                     extrapolate  interpolate   extrapolate  interpolate
    sum pooling      70.6         6.5           43.8         6.1
    max/min pooling  0.0          0.0           0.0          0.0

(b) Importance of representation (n-body problem, MAPE):
                       extrapolate dist  extrapolate mass  interpolate
    original features  11.0              6.3               1.2
    improved features  1.5               1.1               0.7

Figure 7: Importance of the training graph structure. Rows indicate the graph structure covered by the training set and the extrapolation error (MAPE). In max degree, GNNs with max-readout extrapolate well if the max/min degrees of the training graphs are not restricted (Theorem 9). In shortest path, the extrapolation errors of min-GNNs follow a U-shape in the sparsity of the training graphs. More results may be found in Appendix D.2.

(a) Max degree with max-pooling GNN (MAPE by training graph structure):
    path 94.5, 4-regular 12.5, ladder 11.0, cycle 6.4, expander 0.1, complete 0.1, tree 0.0, general 0.0

(b) Shortest path with min-pooling GNN (MAPE by training graph structure):
    complete 8.9, expander 2.4, general 0.0, 4-regular 0.4, ladder 11.6, cycle 13.8, tree 19.3, path 33.6
4 How Graph Neural Networks Extrapolate

Above, we saw that extrapolation in non-linear tasks is hard for MLPs (Theorem 3). Despite this limitation, GNNs have been shown to extrapolate well in some non-linear algorithmic tasks, such as intuitive physics [Battaglia et al., 2016, Sanchez-Gonzalez et al., 2018], graph algorithms [Battaglia et al., 2018, Velickovic et al., 2020], and symbolic mathematics [Lample and Charton, 2020]. To address this discrepancy, we build on our MLP results and study how GNNs trained by GD extrapolate.
4.1 Hypothesis: Linear Algorithmic Alignment Helps Extrapolation

We begin with an example: training GNNs to solve the shortest path problem. For this task, prior works observe that a modified GNN architecture with min-aggregation can generalize to graphs larger than those in the training set [Battaglia et al., 2018, Velickovic et al., 2020]:
$$h_u^{(k)} = \min_{v \in \mathcal{N}(u)} \mathrm{MLP}^{(k)}\left(h_u^{(k-1)}, h_v^{(k-1)}, w_{(v,u)}\right). \quad (2)$$
We first provide an intuitive explanation (Fig. 2a). Shortest path can be solved by the Bellman-Ford (BF) algorithm [Bellman, 1958] with the following update:
$$d[k][u] = \min_{v \in \mathcal{N}(u)} d[k-1][v] + w(v, u), \quad (3)$$
where $w(v, u)$ is the weight of edge $(v, u)$, and $d[k][u]$ is the shortest distance to node $u$ within $k$ steps. The two equations are similar: GNNs can simulate the BF algorithm if the MLP modules learn a linear function $d[k-1][v] + w(v, u)$. Since MLPs can extrapolate well in linear tasks (Theorem 5), this "alignment" might explain why min-aggregation GNNs can extrapolate well in this task.

For comparison, we can reason about why we would not expect GNNs with the more commonly used sum-aggregation (equation 1) to extrapolate well in this task. With sum-aggregation, the MLP modules need to learn a non-linear function to simulate the BF algorithm, but Theorem 3 suggests that they will not extrapolate for most nonlinearities outside the training support.

We can extend the above intuition to other algorithmic tasks. Many target tasks where GNNs extrapolate well can be solved by dynamic programming (DP) [Bellman, 1966], an algorithmic paradigm with a recursive structure similar to GNNs' (equation 1) [Xu et al., 2020]; a small code sketch contrasting the two updates follows.
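The following is a minimal sketch contrasting the Bellman-Ford update in equation (3) with the min-aggregation GNN layer in equation (2); the MLP module and its dimensions are illustrative assumptions.

```python
# Bellman-Ford step vs. a min-aggregation GNN layer.
import torch
import torch.nn as nn

def bellman_ford_step(d, edges):
    # d: dict node -> current distance; edges: list of (v, u, w(v, u)).
    # d[k][u] = min over in-neighbors v of d[k-1][v] + w(v, u).
    new_d = dict(d)
    for v, u, w in edges:
        new_d[u] = min(new_d[u], d[v] + w)
    return new_d

class MinAggregationLayer(nn.Module):
    # h_u^(k) = min over v in N(u) of MLP^(k)(h_u^(k-1), h_v^(k-1), w_(v,u)).
    # If the MLP learns the linear function d[v] + w(v, u), this layer
    # simulates one Bellman-Ford step.
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + 1, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, h, edge_index, edge_weight):
        v, u = edge_index
        msg = self.mlp(torch.cat([h[u], h[v], edge_weight.unsqueeze(1)], dim=-1))
        out = torch.full_like(h, float("inf"))
        # Min-aggregation via an elementwise reduction over target nodes.
        for i in range(edge_index.shape[1]):
            out[u[i]] = torch.minimum(out[u[i]], msg[i])
        return out

# One Bellman-Ford step on a 3-node graph from source node 0.
d = {0: 0.0, 1: float("inf"), 2: float("inf")}
edges = [(0, 1, 2.0), (0, 2, 5.0), (1, 2, 1.0)]
print(bellman_ford_step(d, edges))  # {0: 0.0, 1: 2.0, 2: 5.0}
```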
Definition 6. Dynamic programming (DP) is a recursive procedure with updates
$$\text{Answer}[k][s] = \text{DP-Update}\left(\left\{\text{Answer}[k-1][s']\right\},\ s' = 1, \ldots, n\right), \quad (4)$$
where $\text{Answer}[k][s]$ is the solution to a sub-problem indexed by iteration $k$ and state $s$, and DP-Update is a task-specific update function that solves the sub-problem based on the previous iteration.

Building on the extrapolation behavior of MLPs, we hypothesize that: given a DP task, if we can encode appropriate non-linearity in the model architecture and input representations so that the MLP modules only need to learn a linear step, then GNNs can extrapolate well.

Hypothesis 7. (Linear algorithmic alignment). Let $f: \mathcal{X} \to \mathbb{R}$ be an algorithm and $\mathcal{N}$ a neural network with $m$ MLP modules. Suppose there exist $m$ linear functions $\{g_i\}_{i=1}^m$ so that by replacing $\mathcal{N}$'s MLP modules with the $g_i$'s, $\mathcal{N}$ simulates $f$. Given $\epsilon > 0$, there exists $\{(x_i, f(x_i))\}_{i=1}^n \subset \mathcal{D} \subsetneq \mathcal{X}$ so that $\mathcal{N}$ trained on $\{(x_i, f(x_i))\}_{i=1}^n$ by GD with squared loss learns $\hat{f}$ with $\|\hat{f} - f\| < \epsilon$.

Our hypothesis builds on the algorithmic alignment framework of Xu et al. [2020], which suggests that GNNs can interpolate well if the MLP modules are "aligned" to easy-to-learn (possibly non-linear) functions. Successful extrapolation is harder: MLP modules need to align with linear functions.

To satisfy the linear algorithmic alignment assumption, we can encode appropriate non-linear operations in either the architecture or the input representation (Fig. 2). Previous works that show successful extrapolation indeed use specialized architectures [Velickovic et al., 2020] or input representations [Lample and Charton, 2020], and other works find that the commonly-used sum-based GNNs do not extrapolate well [Santoro et al., 2018, Saxton et al., 2019].

The shortest path example shows one way of encoding non-linearity in the architecture. Previous works also encode log-and-exp transforms in the architecture to help extrapolate multiplication in arithmetic tasks [Trask et al., 2018, Madsen and Johansen, 2020]. Neural symbolic programs improve extrapolation by executing symbolic operations encoded in a library [Johnson et al., 2017, Mao et al., 2019, Yi et al., 2018].

For some tasks, it may be easier to change the input representation (Fig. 2b). Sometimes, we can decompose the target function $f$ as $f = g \circ h$ into an embedding $h$ and a "simpler" target function $g$ that our model can extrapolate well. If we can identify $h$ from domain knowledge, then the model only needs to learn $g$ [Corso et al., 2020, Lample and Charton, 2020, Zhang et al., 2019]. Alternatively, $h$ may be obtained via representation learning with unlabeled out-of-distribution data from $\mathcal{X} \setminus \mathcal{D}$ [Chen et al., 2020, Devlin et al., 2019, Hu et al., 2020, Peters et al., 2018], which might explain why pre-trained representations such as BERT can improve out-of-distribution robustness [Hendrycks et al., 2020].

Linear algorithmic alignment explains successful extrapolation in the literature and suggests that extrapolation is hard in general: encoding appropriate non-linearity often requires domain expertise and/or model search. Next, we provide theoretical and empirical support for the linear algorithmic alignment hypothesis. While we focus on GNNs, our insights apply to other networks too.

4.2 Theoretical and Empirical Support
We validate our hypothesis on three DP tasks: max degree, shortest path, and the n-body problem (Fig. 6). We prove the hypothesis for max degree and highlight the role of graph structure in extrapolation.

Theoretical analysis.
We start with a simple yet fundamental task: learning the max degree of a graph, a special case of DP with one iteration. As a corollary of Theorem 3, the commonly used sum-based GNN (equation 1) cannot extrapolate well (proof in Appendix B.4).
Corollary 8.
GNNs with sum-aggregation and sum-readout do not extrapolate well in Max Degree.
To achieve linear algorithmic alignment, we can encode the only non-linearity, the max function, in the readout. Theorem 9 confirms that a GNN with max-readout can extrapolate well in this task.
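Before stating the theorem, the following is a sketch of a one-layer GNN with max-readout for max degree; the module shapes are illustrative assumptions, not the exact architecture analyzed in Theorem 9.

```python
# Encoding the max non-linearity in the readout of a one-layer GNN.
import torch
import torch.nn as nn

class MaxReadoutGNN(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.msg = nn.Linear(2 * in_dim, hidden_dim)  # one aggregation layer
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, x, adj):
        # x: [n, in_dim] node features; adj: [n, n] dense adjacency matrix.
        # With identical node features, sum-aggregation exposes each node's
        # degree, so the max-readout only needs a linear module on top of it.
        neigh = adj @ x  # sum of neighbor features
        h = torch.relu(self.msg(torch.cat([x, neigh], dim=-1)))
        return self.out(h).max(dim=0).values  # max-readout over nodes

# Toy usage: identical node features, as assumed in Theorem 9.
adj = torch.tensor([[0., 1., 1.], [1., 0., 0.], [1., 0., 0.]])
x = torch.ones(3, 4)
print(MaxReadoutGNN(in_dim=4, hidden_dim=8)(x, adj).shape)  # torch.Size([1])
```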
Theorem 9.
Assume all nodes have the same feature. Let $g$ and $g'$ be the max and min degree functions, respectively. Let $\{(G_i, g(G_i))\}_{i=1}^n$ be the training set. If $\{(g(G_i), g'(G_i), g(G_i) \cdot N_i^{\max}, g'(G_i) \cdot N_i^{\min})\}_{i=1}^n$ spans $\mathbb{R}^4$, where $N_i^{\max}$ and $N_i^{\min}$ are the numbers of nodes that have max and min degree on $G_i$, then a one-layer max-readout GNN trained on $\{(G_i, g(G_i))\}_{i=1}^n$ with squared loss in the NTK regime learns $g$.

Theorem 9 does not follow immediately from Theorem 5, because MLP modules in GNNs only receive indirect supervision. We analyze the
Graph NTK [Du et al., 2019b] to prove Theorem 9 in Appendix B.5. While Theorem 9 assumes identical node features, we empirically observe similar results for both identical and non-identical features (Fig. 16 in the Appendix).
Interpretation of conditions.
The condition in Theorem 9 is analogous to that in Theorem 5. Both theorems require diverse training data, measured by graph structure in Theorem 9 or directions in Theorem 5. In Theorem 9, the condition is violated if all training graphs have the same max or min node degrees, e.g., when the training data are from one of the following families: path, regular graphs with degree $C$ ($C$-regular), cycle, and ladder.
We validate our theoretical analysis with two DP tasks: max degree and shortest path (details in Appendix C.5 and C.6). While previous works only test on graphs with different sizes [Battaglia et al., 2018, Velickovic et al., 2020], we also test on graphs with unseen structure, edge weights, and node features. The results support our theory. For max degree, GNNs with max-readout are better than GNNs with sum-readout (Fig. 6a), confirming Corollary 8 and Theorem 9. For shortest path, GNNs with min-readout and min-aggregation are better than GNNs with sum-readout (Fig. 6a).

Experiments confirm the importance of training graph structure (Fig. 7). Interestingly, the two tasks favor different graph structures. For max degree, as Theorem 9 predicts, GNNs extrapolate well when trained on trees, complete graphs, expanders, and general graphs, and extrapolation errors are higher when trained on 4-regular, cycle, or ladder graphs. For shortest path, extrapolation errors follow a U-shaped curve as we change the sparsity of training graphs (Fig. 7b and Fig. 18 in the Appendix). Intuitively, models trained on sparse or dense graphs are more likely to learn degenerate solutions.
Experiments: representations that help extrapolation.
Finally, we show that a good input representation also helps extrapolation. We study the n-body problem [Battaglia et al., 2016, Watters et al., 2017] (Appendix C.7): predicting the time evolution of $n$ objects in a gravitational system. Following previous work, the input is a complete graph where the nodes are the objects [Battaglia et al., 2016]. The node feature for $u$ is the concatenation of the object's mass $m_u$, position $x_u^{(t)}$, and velocity $v_u^{(t)}$ at time $t$. The edge features are set to zero. We train GNNs to predict the velocity of each object $u$ at time $t + 1$. The true velocity $f(G; u)$ for object $u$ is approximately
$$f(G; u) \approx v_u^t + a_u^t \cdot dt, \qquad a_u^t = C \cdot \sum_{v \neq u} \frac{m_v}{\|x_u^t - x_v^t\|^3} \cdot \left(x_v^t - x_u^t\right), \quad (5)$$
where $C$ is a constant. To learn $f$, the MLP modules need to learn a non-linear function. Therefore, GNNs do not extrapolate well to unseen masses or distances ("original features" in Fig. 6b). To extrapolate well in this task, we use an improved representation $h(G)$ to encode non-linearity. At time $t$, we transform the edge features of $(u, v)$ from zero to $w_{(u,v)}^{(t)} = m_v \cdot \left(x_v^{(t)} - x_u^{(t)}\right) / \|x_u^{(t)} - x_v^{(t)}\|^3$. The new edge features do not add information, but the MLP modules now only need to learn linear functions, which helps extrapolation ("improved features" in Fig. 6b).
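The following is a sketch of the improved edge-feature transform described above; the array shapes are assumptions for illustration.

```python
# The n-body edge-feature transform that linearizes the MLP modules' task.
import numpy as np

def improved_edge_features(mass, pos):
    # mass: [n]; pos: [n, 3] positions x_u^(t).
    # w_(u,v)^(t) = m_v * (x_v - x_u) / ||x_u - x_v||^3, so the MLP modules
    # only need a linear map (a sum of edge features times a constant) to
    # recover the acceleration in equation (5).
    n = len(mass)
    w = np.zeros((n, n, 3))
    for u in range(n):
        for v in range(n):
            if u == v:
                continue
            diff = pos[v] - pos[u]
            w[u, v] = mass[v] * diff / np.linalg.norm(diff) ** 3
    return w

mass = np.array([1.0, 2.0, 0.5])
pos = np.random.randn(3, 3)
w = improved_edge_features(mass, pos)
# The acceleration of node u (up to the constant C) is a sum over its edges:
a = w.sum(axis=1)  # [n, 3]
print(a.shape)
```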
5 Connections to Other Out-of-Distribution Settings

We discuss several related settings. Intuitively, from the viewpoint of our results above, previous methods may improve extrapolation by (1) learning useful non-linearities beyond the training data range and (2) mapping relevant test data to the training data range.

Domain adaptation studies generalization to a specific target domain [Ben-David et al., 2010, Blitzer et al., 2008, Mansour et al., 2009]. Typical strategies adjust the training process: for instance, use unlabeled samples from the target domain to align the target and source distributions [Ganin et al., 2016, Long et al., 2015, Zhao et al., 2018]. Using target domain data during training may induce useful non-linearities on the target data. Moreover, we might avoid extrapolation by matching the target and source distributions, though the correctness of the learned mapping depends on the label distribution [Zhao et al., 2019].
Self-supervised learning on a large amount of unlabeled data can learn useful non-linearities beyond the labeled training data range [Devlin et al., 2019, Peters et al., 2018, Chen et al., 2020]. Hence, our results suggest an explanation of why pre-trained representations such as BERT improve out-of-distribution performance [Hendrycks et al., 2020]. In addition, self-supervised learning could map semantically similar data to similar representations, so some out-of-domain examples might fall inside the training distribution after the mapping.
Invariant models take a causality perspective and learn stable features that respect specific invariances across multiple training distributions [Arjovsky et al., 2019, Muandet et al., 2013, Rojas-Carulla et al., 2018]. If the model learns these invariances, this may essentially increase the training data range, since variations in the invariant features may be ignored by the model.
Distributional robustness considers small adversarial perturbations of the data distribution, and ensures that the model performs well under these [Goh and Sim, 2010, Sagawa et al., 2020, Sinha et al., 2018, Staib and Jegelka, 2019]. We instead look at more global perturbations. Still, one would expect that modifications that help extrapolation in general also improve robustness to local perturbations.
6 Conclusion

This paper is an initial step towards formally understanding how neural networks trained by gradient descent extrapolate. We identify conditions under which MLPs and GNNs extrapolate as desired. We also suggest an explanation of how GNNs have been able to extrapolate well in complex algorithmic tasks: encoding appropriate non-linearity in the architecture and features can help extrapolation. Our results and hypothesis agree with empirical results, in this paper and in the literature.

Acknowledgments
We thank Ruosong Wang, Tianle Cai, Han Zhao, Yuichi Yoshida, Takuya Konishi, Toru Lin, Weihua Hu, Matt J. Staib, Tianyi Yang, and Dingli (Leo) Yu for insightful discussions. This research was supported by NSF CAREER award 1553284, NSF III 1900933, and a Chevron-MIT Energy Fellowship. This research was also supported by JST ERATO JPMJER1201 and JSPS Kakenhi JP18H05291. MZ was supported by ODNI, IARPA, via the BETTER Program contract 2019-19051600005. The views, opinions, and/or findings contained in this article are those of the author and should not be interpreted as representing the official views or policies, either expressed or implied, of the Defense Advanced Research Projects Agency, the Department of Defense, ODNI, IARPA, or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for governmental purposes notwithstanding any copyright annotation therein.
References
Zeyuan Allen-Zhu, Yuanzhi Li, and Yingyu Liang. Learning and generalization in overparameterized neural networks, going beyond two layers. In Advances in Neural Information Processing Systems, pages 6155–6166, 2019a.
Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pages 242–252, 2019b.
Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
Raman Arora, Amitabh Basu, Poorya Mianjy, and Anirbit Mukherjee. Understanding deep neural networks with rectified linear units. In International Conference on Learning Representations, 2018.
Sanjeev Arora, Simon Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, pages 322–332, 2019a.
Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, Russ R Salakhutdinov, and Ruosong Wang. On exact computation with an infinitely wide neural net. In Advances in Neural Information Processing Systems, pages 8139–8148, 2019b.
Sanjeev Arora, Simon S. Du, Zhiyuan Li, Ruslan Salakhutdinov, Ruosong Wang, and Dingli Yu. Harnessing the power of infinitely wide deep nets on small-data tasks. In International Conference on Learning Representations, 2020.
Jimmy Ba, Murat Erdogdu, Taiji Suzuki, Denny Wu, and Tianzong Zhang. Generalization of two-layer neural networks: An asymptotic viewpoint. In International Conference on Learning Representations, 2020.
Etienne Barnard and LFA Wessels. Extrapolation and interpolation in neural network classifiers. IEEE Control Systems Magazine, 12(5):50–53, 1992.
Peter Battaglia, Razvan Pascanu, Matthew Lai, Danilo Jimenez Rezende, et al. Interaction networks for learning about objects, relations and physics. In Advances in Neural Information Processing Systems, pages 4502–4510, 2016.
Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, et al. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.
Richard Bellman. On a routing problem. Quarterly of Applied Mathematics, 16(1):87–90, 1958.
Richard Bellman. Dynamic programming. Science, 153(3731):34–37, 1966.
Shai Ben-David, John Blitzer, Koby Crammer, Alex Kulesza, Fernando Pereira, and Jennifer Wortman Vaughan. A theory of learning from different domains. Machine Learning, 79(1-2):151–175, 2010.
Alberto Bietti and Julien Mairal. On the inductive bias of neural tangent kernels. In Advances in Neural Information Processing Systems, pages 12873–12884, 2019.
John Blitzer, Koby Crammer, Alex Kulesza, Fernando Pereira, and Jennifer Wortman. Learning bounds for domain adaptation. In Advances in Neural Information Processing Systems, pages 129–136, 2008.
Yuan Cao and Quanquan Gu. Generalization bounds of stochastic gradient descent for wide and deep neural networks. In Advances in Neural Information Processing Systems, pages 10835–10845, 2019.
Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In International Conference on Machine Learning, 2020.
Lenaic Chizat and Francis Bach. A note on lazy training in supervised differentiable programming. arXiv preprint arXiv:1812.07956, 8, 2018.
Lenaic Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. In Advances in Neural Information Processing Systems, pages 2933–2943, 2019.
Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Liò, and Petar Veličković. Principal neighbourhood aggregation for graph nets. Advances in Neural Information Processing Systems, 2020.
G. Cybenko. Approximation by superpositions of a sigmoidal function. Mathematics of Control, Signals and Systems, 2(4):303–314, 1989.
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171–4186, 2019.
Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning, pages 1675–1685, 2019a.
Simon S. Du and Jason D. Lee. On the power of over-parametrization in neural networks with quadratic activation. In International Conference on Machine Learning, 2018.
Simon S Du, Kangcheng Hou, Russ R Salakhutdinov, Barnabas Poczos, Ruosong Wang, and Keyulu Xu. Graph neural tangent kernel: Fusing graph neural networks with graph kernels. In Advances in Neural Information Processing Systems, pages 5724–5734, 2019b.
Simon S. Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. In International Conference on Learning Representations, 2019c.
K. Funahashi. On the approximate realization of continuous mappings by neural networks. Neural Networks, 2(3):183–192, 1989.
Yaroslav Ganin, Evgeniya Ustinova, Hana Ajakan, Pascal Germain, Hugo Larochelle, François Laviolette, Mario Marchand, and Victor Lempitsky. Domain-adversarial training of neural networks. The Journal of Machine Learning Research, 17(1):2096–2030, 2016.
Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Linearized two-layers neural networks in high dimension. arXiv preprint arXiv:1904.12191, 2019.
Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural message passing for quantum chemistry. In International Conference on Machine Learning, pages 1263–1272, 2017.
Joel Goh and Melvyn Sim. Distributionally robust optimization and its tractable approximations. Operations Research, 58(4-part-1):902–917, 2010.
Pamela J Haley and Donald Soloway. Extrapolation limitations of multilayer feedforward neural networks. In International Joint Conference on Neural Networks, volume 4, pages 25–30. IEEE, 1992.
Boris Hanin and David Rolnick. Complexity of linear regions in deep networks. In International Conference on Machine Learning, pages 2596–2604, 2019.
Moritz Hardt, Ben Recht, and Yoram Singer. Train faster, generalize better: Stability of stochastic gradient descent. In International Conference on Machine Learning, pages 1225–1234, 2016.
Matthias Hein, Maksym Andriushchenko, and Julian Bitterwolf. Why ReLU networks yield high-confidence predictions far away from the training data and how to mitigate the problem. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 41–50, 2019.
Dan Hendrycks, Xiaoyuan Liu, Eric Wallace, Adam Dziedzic, Rishabh Krishnan, and Dawn Song. Pretrained transformers improve out-of-distribution robustness. In Association for Computational Linguistics, 2020.
Kurt Hornik, Maxwell Stinchcombe, and Halbert White. Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366, 1989.
Weihua Hu, Bowen Liu, Joseph Gomes, Marinka Zitnik, Percy Liang, Vijay Pande, and Jure Leskovec. Strategies for pre-training graph neural networks. In International Conference on Learning Representations, 2020.
Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pages 8571–8580, 2018.
Justin Johnson, Bharath Hariharan, Laurens van der Maaten, Judy Hoffman, Li Fei-Fei, C Lawrence Zitnick, and Ross Girshick. Inferring and executing programs for visual reasoning. In Proceedings of the IEEE International Conference on Computer Vision, pages 2989–2998, 2017.
V. Kurková. Kolmogorov's theorem and multilayer neural networks. Neural Networks, 5(3):501–506, 1992.
Guillaume Lample and François Charton. Deep learning for symbolic mathematics. In International Conference on Learning Representations, 2020.
Alan Lapedes and Robert Farber. Nonlinear signal processing using neural networks: Prediction and system modelling. Technical report, 1987.
Jaehoon Lee, Lechao Xiao, Samuel Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, and Jeffrey Pennington. Wide neural networks of any depth evolve as linear models under gradient descent. In Advances in Neural Information Processing Systems, pages 8570–8581, 2019.
Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via stochastic gradient descent on structured data. In Advances in Neural Information Processing Systems, pages 8157–8166, 2018.
Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization effect of initial large learning rate in training neural networks. In Advances in Neural Information Processing Systems, pages 11669–11680, 2019.
Roi Livni, Shai Shalev-Shwartz, and Ohad Shamir. On the computational efficiency of training neural networks. In Advances in Neural Information Processing Systems, pages 855–863, 2014.
Mingsheng Long, Yue Cao, Jianmin Wang, and Michael Jordan. Learning transferable features with deep adaptation networks. In International Conference on Machine Learning, pages 97–105. PMLR, 2015.
Andreas Madsen and Alexander Rosenberg Johansen. Neural arithmetic units. In International Conference on Learning Representations, 2020.
Hartmut Maennel, Olivier Bousquet, and Sylvain Gelly. Gradient descent quantizes ReLU network features. arXiv e-prints, art. arXiv:1803.08367, March 2018.
Yishay Mansour, Mehryar Mohri, and Afshin Rostamizadeh. Domain adaptation: Learning bounds and algorithms. In Conference on Learning Theory, 2009.
Jiayuan Mao, Chuang Gan, Pushmeet Kohli, Joshua B. Tenenbaum, and Jiajun Wu. The neuro-symbolic concept learner: Interpreting scenes, words, and sentences from natural supervision. In International Conference on Learning Representations, 2019.
David B McCaughan. On the properties of periodic perceptrons. In International Conference on Neural Networks, 1997.
Krikamol Muandet, David Balduzzi, and Bernhard Schölkopf. Domain generalization via invariant feature representation. In International Conference on Machine Learning, pages 10–18, 2013.
Roman Novak, Lechao Xiao, Jiri Hron, Jaehoon Lee, Alexander A. Alemi, Jascha Sohl-Dickstein, and Samuel S. Schoenholz. Neural tangents: Fast and easy infinite neural networks in Python. In International Conference on Learning Representations, 2020.
Jonas Peters, Peter Bühlmann, and Nicolai Meinshausen. Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):947–1012, 2016.
Matthew Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. Deep contextualized word representations. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pages 2227–2237, 2018.
Mateo Rojas-Carulla, Bernhard Schölkopf, Richard Turner, and Jonas Peters. Invariant models for causal transfer learning. The Journal of Machine Learning Research, 19(1):1309–1342, 2018.
Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, and Percy Liang. Distributionally robust neural networks. In International Conference on Learning Representations, 2020.
Alvaro Sanchez-Gonzalez, Nicolas Heess, Jost Tobias Springenberg, Josh Merel, Martin Riedmiller, Raia Hadsell, and Peter Battaglia. Graph networks as learnable physics engines for inference and control. In International Conference on Machine Learning, pages 4467–4476, 2018.
Adam Santoro, Felix Hill, David Barrett, Ari Morcos, and Timothy Lillicrap. Measuring abstract reasoning in neural networks. In International Conference on Machine Learning, pages 4477–4486, 2018.
Pedro Savarese, Itay Evron, Daniel Soudry, and Nathan Srebro. How do infinite width bounded norm networks look in function space? In Conference on Learning Theory (COLT), 2019.
David Saxton, Edward Grefenstette, Felix Hill, and Pushmeet Kohli. Analysing mathematical reasoning abilities of neural models. In International Conference on Learning Representations, 2019.
Franco Scarselli, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009.
Aman Sinha, Hongseok Namkoong, and John Duchi. Certifying some distributional robustness with principled adversarial training. In International Conference on Learning Representations, 2018.
Mei Song, Andrea Montanari, and P Nguyen. A mean field view of the landscape of two-layers neural networks. Proceedings of the National Academy of Sciences, 115:E7665–E7671, 2018.
JM Sopena and R Alquezar. Improvement of learning in recurrent networks by substituting the sigmoid activation function. In International Conference on Artificial Neural Networks, pages 417–420. Springer, 1994.
Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan Srebro. The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research, 19(1):2822–2878, 2018.
Matthew Staib and Stefanie Jegelka. Distributionally robust optimization and generalization in kernel methods. In Advances in Neural Information Processing Systems, pages 9134–9144, 2019.
Andrew Trask, Felix Hill, Scott E Reed, Jack Rae, Chris Dyer, and Phil Blunsom. Neural arithmetic logic units. In Advances in Neural Information Processing Systems, pages 8035–8044, 2018.
Leslie G Valiant. A theory of the learnable. In Proceedings of the Sixteenth Annual ACM Symposium on Theory of Computing, pages 436–445. ACM, 1984.
Vladimir Vapnik. The Nature of Statistical Learning Theory. Springer Science & Business Media, 2013.
Petar Velickovic, Rex Ying, Matilde Padovano, Raia Hadsell, and Charles Blundell. Neural execution of graph algorithms. In International Conference on Learning Representations, 2020.
Nicholas Watters, Daniel Zoran, Theophane Weber, Peter Battaglia, Razvan Pascanu, and Andrea Tacchetti. Visual interaction networks: Learning a physics simulator from video. In Advances in Neural Information Processing Systems, pages 4539–4547, 2017.
Francis Williams, Matthew Trager, Daniele Panozzo, Claudio Silva, Denis Zorin, and Joan Bruna. Gradient dynamics of shallow univariate ReLU networks. In Advances in Neural Information Processing Systems, pages 8376–8385, 2019.
Keyulu Xu, Chengtao Li, Yonglong Tian, Tomohiro Sonobe, Ken-ichi Kawarabayashi, and Stefanie Jegelka. Representation learning on graphs with jumping knowledge networks. In International Conference on Machine Learning, pages 5453–5462, 2018.
Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? In International Conference on Learning Representations, 2019.
Keyulu Xu, Jingling Li, Mozhi Zhang, Simon S. Du, Ken-ichi Kawarabayashi, and Stefanie Jegelka. What can neural networks reason about? In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=rJxbJeHFPS.
Kexin Yi, Jiajun Wu, Chuang Gan, Antonio Torralba, Pushmeet Kohli, and Josh Tenenbaum. Neural-symbolic VQA: Disentangling reasoning from vision and language understanding. In Advances in Neural Information Processing Systems, pages 1031–1042, 2018.
Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. In International Conference on Learning Representations, 2017.
Mozhi Zhang, Keyulu Xu, Ken-ichi Kawarabayashi, Stefanie Jegelka, and Jordan Boyd-Graber. Are girls neko or shōjo? Cross-lingual alignment of non-isomorphic embeddings with iterative normalization. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 3180–3189, 2019.
Han Zhao, Shanghang Zhang, Guanhang Wu, José MF Moura, Joao P Costeira, and Geoffrey J Gordon. Adversarial multiple source domain adaptation. In Advances in Neural Information Processing Systems, pages 8559–8570, 2018.
Han Zhao, Remi Tachet Des Combes, Kun Zhang, and Geoffrey Gordon. On learning invariant representations for domain adaptation. In International Conference on Machine Learning, pages 7523–7532, 2019.
A Theoretical Background
In this section, we introduce theoretical background on the neural tangent kernel (NTK), which draws an equivalence between the training dynamics of infinitely-wide (or ultra-wide) neural networks and that of kernel regression with respect to the neural tangent kernel.

Consider a general neural network $f(\theta, x): \mathcal{X} \to \mathbb{R}$, where $\theta \in \mathbb{R}^m$ are the parameters of the network and $x \in \mathcal{X}$ is the input. Suppose we train the neural network by minimizing the squared loss over the training data, $\ell(\theta) = \sum_{i=1}^n (f(\theta, x_i) - y_i)^2$, by gradient descent with infinitesimally small learning rate, i.e., $\frac{d\theta(t)}{dt} = -\nabla \ell(\theta(t))$. Let $u(t) = (f(\theta(t), x_i))_{i=1}^n$ be the network outputs. $u(t)$ follows the dynamics
$$\frac{du(t)}{dt} = -H(t)\,(u(t) - y),$$
where $H(t)$ is an $n \times n$ matrix whose $(i,j)$-th entry is
$$H(t)_{ij} = \left\langle \frac{\partial f(\theta(t), x_i)}{\partial \theta}, \frac{\partial f(\theta(t), x_j)}{\partial \theta} \right\rangle.$$
A line of works shows that for sufficiently wide networks, $H(t)$ stays almost constant during training, i.e., $H(t) = H(0)$ in the limit [Arora et al., 2019a,b, Allen-Zhu et al., 2019a, Du et al., 2019c,a, Li and Liang, 2018, Jacot et al., 2018]. Suppose network parameters are randomly initialized; as the network width goes to infinity, $H(0)$ converges to a fixed matrix, the neural tangent kernel (NTK) [Jacot et al., 2018]:
$$\mathrm{NTK}(x, x') = \mathbb{E}_{\theta \sim \mathcal{W}} \left\langle \frac{\partial f(\theta, x)}{\partial \theta}, \frac{\partial f(\theta, x')}{\partial \theta} \right\rangle, \quad (6)$$
where $\mathcal{W}$ is Gaussian.

Therefore, the training dynamics of sufficiently wide neural networks in this regime is equivalent to that of kernel regression with respect to the NTK. This implies that the function learned by a neural network given a training set, denoted by $f_{\mathrm{NTK}}(x)$, can be precisely characterized, and is equivalent to the following kernel regression solution:
$$f_{\mathrm{NTK}}(x) = \left(\mathrm{NTK}(x, x_1), \ldots, \mathrm{NTK}(x, x_n)\right) \cdot \mathrm{NTK}_{\mathrm{train}}^{-1} Y, \quad (7)$$
where $\mathrm{NTK}_{\mathrm{train}}$ is the $n \times n$ kernel matrix for the training data, $\mathrm{NTK}(x, x_i)$ is the kernel value between test data $x$ and training data $x_i$, and $Y$ is the vector of training labels.

We can in fact exactly calculate the neural tangent kernel matrix. Exact formulas of the NTK have been derived for multi-layer perceptrons (MLPs), a.k.a. fully-connected networks [Jacot et al., 2018], convolutional networks [Arora et al., 2019b], and Graph Neural Networks (GNNs) [Du et al., 2019b].

Our theory builds upon this equivalence of network learning and kernel regression to more precisely characterize the function learned by a sufficiently-wide neural network given a training set. In particular, the difference between the learned function and the true function over the domain $\mathcal{X}$ determines the extrapolation error.

However, in general it is non-trivial to compute or analyze the functional form of what a neural network learns using equation 7, because the kernel regression solution using the neural tangent kernel only gives point-wise evaluation. Thus, we instead analyze the function learned by a network in the NTK's induced feature space, because representations in the feature space give a functional form.

Lemma 10 makes this connection more precise: the solution to the kernel regression using the neural tangent kernel, which also equals over-parameterized network learning, is equivalent to a min-norm solution among functions in the NTK's induced feature space that fit all training data. Here the min-norm refers to the RKHS norm.
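Before stating Lemma 10, the following is a minimal numerical sketch of the matrix $H(0)$ for a small finite-width network, computed from parameter gradients with automatic differentiation; the architecture and sizes are illustrative assumptions.

```python
# Empirical NTK matrix H(0) from parameter gradients.
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(2, 512), nn.ReLU(), nn.Linear(512, 1))
x = torch.randn(4, 2)  # n = 4 training inputs

def param_grad(xi):
    # Gradient of the scalar output f(theta, x_i) w.r.t. all parameters,
    # flattened into one vector.
    net.zero_grad()
    net(xi.unsqueeze(0)).squeeze().backward()
    return torch.cat([p.grad.flatten() for p in net.parameters()])

grads = torch.stack([param_grad(x[i]) for i in range(4)])
H = grads @ grads.T  # H_ij = <df/dtheta(x_i), df/dtheta(x_j)>
print(H.shape)  # torch.Size([4, 4]); for wide nets, H(t) stays near H(0)
```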
The solution tokernel regression equation 7 is equivalent to f NTK ( x ) = φ ( x ) (cid:62) β NTK , where β NTK is min β (cid:107) β (cid:107) s.t. φ ( x i ) (cid:62) β = y i , for i = 1 , ..., n. We prove Lemma 10 in Appendix B.6. To analyze the learned functions as the min-norm solution infeature space, we also need the explicit formula of an induced feature map of the corresponding neuraltangent kernel.Next, we give a NTK feature space for MLPs with ReLU activation. It follows easily from the kernelformula described in Jacot et al. [2018], Arora et al. [2019b], Bietti and Mairal [2019].
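To make the equivalence in Lemma 10 concrete, the following sketch (our illustration, not the paper's code) approximates the two-layer ReLU NTK feature map of Lemma 11 with finitely many sampled directions $w$, and checks numerically that the kernel regression solution of equation 7 coincides with the min-norm solution in feature space. All names and the choice of target are ours; note that `np.linalg.lstsq` returns the minimum-norm solution for under-determined systems.

```python
# A minimal numerical sketch (ours, not the paper's code) of Lemma 10, using a
# finite Monte Carlo approximation of the ReLU NTK feature map from Lemma 11.
import numpy as np

rng = np.random.default_rng(0)
d, n, m = 2, 10, 20000  # input dim, training points, sampled directions w

def ntk_features(X, W):
    """Finite approximation of phi(x) = (x * I(w^T x >= 0), w^T x * I(w^T x >= 0), ...)."""
    act = (X @ W.T >= 0).astype(float)            # (n, m): indicator I(w^T x >= 0)
    f1 = (X[:, :, None] * act[:, None, :])        # (n, d, m): x * I(...)
    f2 = (X @ W.T) * act                          # (n, m):   w^T x * I(...)
    return np.concatenate([f1.reshape(len(X), -1), f2], axis=1) / np.sqrt(m)

X = rng.normal(size=(n, d))
y = X @ np.array([1.5, -0.5])                     # an arbitrary target for illustration
W = rng.normal(size=(m, d))                       # w ~ N(0, I)

Phi = ntk_features(X, W)
K = Phi @ Phi.T                                   # Monte Carlo NTK gram matrix

x_test = rng.normal(size=(1, d))
phi_test = ntk_features(x_test, W)

# Kernel regression prediction, equation 7.
pred_kernel = (phi_test @ Phi.T) @ np.linalg.solve(K, y)
# Min-norm solution in feature space (lstsq returns the minimum-norm solution).
beta, *_ = np.linalg.lstsq(Phi, y, rcond=None)
pred_minnorm = phi_test @ beta

print(pred_kernel, pred_minnorm)  # the two predictions agree up to numerical error
```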
Next, we give an NTK feature space for MLPs with ReLU activation. It follows readily from the kernel formulas derived in Jacot et al. [2018], Arora et al. [2019b], Bietti and Mairal [2019].

Lemma 11. An infinite-dimensional feature map $\phi(x)$ induced by the neural tangent kernel of a two-layer multi-layer perceptron with ReLU activation function is
\[ \phi(x) = c \big( x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; w^{(k)\top} x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; \ldots \big), \tag{8} \]
where $w^{(k)} \sim \mathcal{N}(0, I)$, with $k$ going to infinity, $c$ is a constant, and $\mathbb{I}$ is the indicator function.

We prove Lemma 11 in Appendix B.7. The feature maps for other architectures, e.g., Graph Neural Networks (GNNs), can be derived similarly. We analyze the Graph Neural Tangent Kernel (GNTK) for a simple GNN architecture in Theorem 9.

We then use Lemmas 10 and 11 to characterize the properties of functions learned by an over-parameterized neural network: we precisely characterize the learned function in the NTK regime by solving the constrained optimization problem that defines the min-norm function in NTK feature space, subject to fitting the training data.

However, several challenges remain in analyzing this min-norm solution in NTK feature space. For example, provable extrapolation (exact or asymptotic) is not achieved under most training distributions. Understanding the desirable conditions requires insight into the geometry of the training data distribution and how it interacts with the solution learned by the network. Our refined analysis shows that in $\mathbb{R}^d$ we must consider the directions covered by the training data; for graphs, we must in addition consider the graph structure of the training data. We refer readers to the detailed proofs for the intuition behind these data conditions. Moreover, since the NTK corresponds to infinitely wide networks, the feature space is infinite-dimensional, which poses further non-trivial technical challenges.

Since different theorems come with their own challenges and techniques, we refer interested readers to the respective proofs for details. In Lemma 4 (proof in Appendix B.2), Theorem 5 (proof in Appendix B.3), and Theorem 3 (proof in Appendix B.1), we analyze over-parameterized MLPs. The proof of Corollary 8 is in Appendix B.4. In Theorem 9, we analyze Graph Neural Networks (proof in Appendix B.5).

B Proofs of All Theorems and Lemmas
B.1 Proof of Theorem 3
To show that the neural network output $f(x)$ converges to a linear function along all directions $v$, we analyze the function learned by the neural network on the training set $\{(x_i, y_i)\}_{i=1}^n$ by studying its functional representation in the network's neural tangent kernel RKHS.

Recall from Section A that in the NTK regime, i.e., when networks are infinitely wide, randomly initialized, and trained by gradient descent with an infinitesimally small learning rate, the learning dynamics of the neural network are equivalent to those of kernel regression with respect to its neural tangent kernel. For any $x \in \mathbb{R}^d$, the network output is given by
\[ f(x) = \big( \langle \phi(x), \phi(x_1) \rangle, \ldots, \langle \phi(x), \phi(x_n) \rangle \big) \cdot \mathrm{NTK}_{\mathrm{train}}^{-1} Y, \]
where $\mathrm{NTK}_{\mathrm{train}}$ is the $n \times n$ kernel matrix for the training data, $\langle \phi(x), \phi(x_i) \rangle$ is the kernel value between test point $x$ and training point $x_i$, and $Y$ is the vector of training labels. By Lemma 10, the kernel regression solution is also equivalent to the min-norm solution in the NTK RKHS that fits all training data,
\[ f(x) = \phi(x)^\top \beta_{\mathrm{NTK}}, \tag{9} \]
where the representation coefficient $\beta_{\mathrm{NTK}}$ is
\[ \min_{\beta} \|\beta\|_2 \quad \text{s.t.} \quad \phi(x_i)^\top \beta = y_i, \; i = 1, \ldots, n. \]
The feature map $\phi(x)$ for a two-layer MLP with ReLU activation is given by Lemma 11,
\[ \phi(x) = c' \big( x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; w^{(k)\top} x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; \ldots \big), \tag{10} \]
where $w^{(k)} \sim \mathcal{N}(0, I)$ with $k$ going to infinity, $c'$ is a constant, and $\mathbb{I}$ is the indicator function. Without loss of generality, we assume the bias term to be $1$. For simplicity of notation, we append the bias term to each data point $x$, i.e., $\hat{x} = [x \mid 1]$ [Bietti and Mairal, 2019], and assume the constant term is $1$.

Given any direction $v$ on the unit sphere, the network outputs for the out-of-distribution data points $x_1 = t v$ and $x_2 = x_1 + h v = (1 + \lambda) x_1$, where we introduce the notation $x_1$ and $\lambda = h/t$ for convenience, are given by equation 9 and equation 10:
\begin{align*} f(\hat{x}_1) &= \beta_{\mathrm{NTK}}^\top \big( \hat{x}_1 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_1 \geq 0),\; w^{(k)\top} \hat{x}_1 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_1 \geq 0),\; \ldots \big), \\ f(\hat{x}_2) &= \beta_{\mathrm{NTK}}^\top \big( \hat{x}_2 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_2 \geq 0),\; w^{(k)\top} \hat{x}_2 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_2 \geq 0),\; \ldots \big), \end{align*}
where $\hat{x}_1 = [x_1 \mid 1]$ and $\hat{x}_2 = [(1+\lambda) x_1 \mid 1]$. It follows that
\[ f(\hat{x}_2) - f(\hat{x}_1) = \beta_{\mathrm{NTK}}^\top \big( \hat{x}_2 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_2 \geq 0) - \hat{x}_1 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_1 \geq 0), \tag{11} \]
\[ \qquad\qquad w^{(k)\top} \hat{x}_2 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_2 \geq 0) - w^{(k)\top} \hat{x}_1 \cdot \mathbb{I}(w^{(k)\top} \hat{x}_1 \geq 0),\; \ldots \big) \tag{12} \]
By re-arranging the terms, we obtain the following equivalent form for the first kind of entry:
\begin{align} &\hat{x}_2 \cdot \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \hat{x}_1 \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{13} \\ &= \hat{x}_2 \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) + \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) - \hat{x}_1 \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{14} \\ &= \hat{x}_2 \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) + (\hat{x}_2 - \hat{x}_1) \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{15} \\ &= [x_2 \mid 1] \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) + [h v \mid 0] \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{16} \end{align}
Similarly, for the second kind of entry,
\begin{align} &w^\top \hat{x}_2 \cdot \mathbb{I}(w^\top \hat{x}_2 \geq 0) - w^\top \hat{x}_1 \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{17} \\ &= w^\top \hat{x}_2 \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) + \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) - w^\top \hat{x}_1 \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{18} \\ &= w^\top \hat{x}_2 \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) + w^\top (\hat{x}_2 - \hat{x}_1) \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{19} \\ &= w^\top [x_2 \mid 1] \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) + w^\top [h v \mid 0] \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \tag{20} \end{align}
Let us denote the part of $\beta_{\mathrm{NTK}}$ corresponding to each $w$ by $\beta_w$. Moreover, let us denote the part corresponding to equation 16 by $\beta_w^1$ and the part corresponding to equation 20 by $\beta_w^2$. Then we have
\begin{align} &\frac{f(\hat{x}_2) - f(\hat{x}_1)}{h} \tag{21} \\ &= \int \beta_w^{1\top} [x_2/h \mid 1/h] \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) \, dP(w) \tag{22} \\ &\quad + \int \beta_w^{1\top} [v \mid 0] \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \, dP(w) \tag{23} \\ &\quad + \int \beta_w^2 \cdot w^\top [x_2/h \mid 1/h] \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) \, dP(w) \tag{24} \\ &\quad + \int \beta_w^2 \cdot w^\top [v \mid 0] \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \, dP(w) \tag{25} \end{align}
Note that all $\beta_w$ are finite constants that depend on the training data. Next, we show that as $t \to \infty$, each of the terms above converges, with $t = O(1/\epsilon)$, to some constant coefficient $\beta_v$ that depends on the training data and the direction $v$. Let us first consider equation 23. We have
\begin{align} \int \mathbb{I}(w^\top \hat{x}_1 \geq 0) \, dP(w) &= \int \mathbb{I}(w^\top [x_1 \mid 1] \geq 0) \, dP(w) \tag{26} \\ &= \int \mathbb{I}(w^\top [x_1/t \mid 1/t] \geq 0) \, dP(w) \tag{27} \\ &\longrightarrow \int \mathbb{I}(w^\top [v \mid 0] \geq 0) \, dP(w) \quad \text{as } t \to \infty \tag{28} \end{align}
Because the $\beta_w^1$ are finite constants, it follows that
\[ \int \beta_w^{1\top} [v \mid 0] \cdot \mathbb{I}(w^\top \hat{x}_1 \geq 0) \, dP(w) \to \int \beta_w^{1\top} [v \mid 0] \cdot \mathbb{I}(w^\top [v \mid 0] \geq 0) \, dP(w), \tag{29} \]
where the right-hand side is a constant that depends on the training data and the direction $v$. Next, we show the convergence rate for equation 29. Given error $\epsilon > 0$, because the $\beta_w^{1\top} [v \mid 0]$ are finite constants, it suffices to bound the following by $C \cdot \epsilon$ for some constant $C$:
\begin{align} &\Big| \int \mathbb{I}(w^\top \hat{x}_1 \geq 0) - \mathbb{I}(w^\top [v \mid 0] \geq 0) \, dP(w) \Big| \tag{30} \\ &= \Big| \int \mathbb{I}(w^\top [x_1 \mid 1] \geq 0) - \mathbb{I}(w^\top [x_1 \mid 0] \geq 0) \, dP(w) \Big| \tag{31} \end{align}
Observe that the two terms in equation 31 represent the volumes of half-balls orthogonal to the vectors $[x_1 \mid 1]$ and $[x_1 \mid 0]$, respectively. Hence, equation 31 is the volume of the non-overlapping part of the two half-balls, which is created by rotating by an angle $\theta$ along the last coordinate. By symmetry, equation 31 is linear in $\theta$. Moreover, the angle satisfies $\theta = \arctan(C/t)$ for some constant $C$. Hence, it follows that
\begin{align} \Big| \int \mathbb{I}(w^\top [x_1 \mid 1] \geq 0) - \mathbb{I}(w^\top [x_1 \mid 0] \geq 0) \, dP(w) \Big| &= C_1 \cdot \arctan(C_2/t) \tag{32} \\ &\leq C_1 \cdot C_2 / t \tag{33} \\ &= O(1/t) \tag{34} \end{align}
In the last inequality, we used the fact that $\arctan x < x$ for $x > 0$. Hence, $O(1/t) < \epsilon$ implies $t = O(1/\epsilon)$, as desired. Next, we consider equation 22:
\[ \int \beta_w^{1\top} [x_2/h \mid 1/h] \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) \, dP(w) \tag{35} \]
Let us first analyze the convergence of the indicator part:
\begin{align} &\Big| \int \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \, dP(w) \Big| \tag{36} \\ &= \Big| \int \mathbb{I}(w^\top [(1+\lambda) x_1 \mid 1] \geq 0) - \mathbb{I}(w^\top [x_1 \mid 1] \geq 0) \, dP(w) \Big| \tag{37} \\ &= \Big| \int \mathbb{I}\Big(w^\top \Big[x_1 \;\Big|\; \frac{1}{1+\lambda}\Big] \geq 0\Big) - \mathbb{I}(w^\top [x_1 \mid 1] \geq 0) \, dP(w) \Big| \to 0 \tag{38} \end{align}
The convergence to $0$ follows from equation 32. Now we consider the convergence rate. The angle $\theta$ here is at most $1 - \frac{1}{1+\lambda}$ times that in equation 32. Hence, the rate is
\[ \Big(1 - \frac{1}{1+\lambda}\Big) \cdot O\Big(\frac{1}{t}\Big) = \frac{\lambda}{1+\lambda} \cdot O\Big(\frac{1}{t}\Big) = \frac{h/t}{1 + h/t} \cdot O\Big(\frac{1}{t}\Big) = O\Big(\frac{h}{(h+t)\,t}\Big) \tag{39} \]
We now return to equation 22, which simplifies to
\[ \int \beta_w^{1\top} \Big[ v + \frac{t\,v}{h} \;\Big|\; \frac{1}{h} \Big] \cdot \big( \mathbb{I}(w^\top \hat{x}_2 \geq 0) - \mathbb{I}(w^\top \hat{x}_1 \geq 0) \big) \, dP(w) \tag{40} \]
We compare the rate of growth of the first factor with the rate of decrease of the indicator part:
\begin{align} \frac{t}{h} \cdot \frac{h}{(h+t)\,t} &= \frac{1}{h+t} \to 0 \quad \text{as } t \to \infty \tag{41} \\ \frac{1}{h} \cdot \frac{h}{(h+t)\,t} &= \frac{1}{(h+t)\,t} \to 0 \quad \text{as } t \to \infty \tag{42} \end{align}
Hence, the indicators decrease faster, and it follows that equation 22 converges to $0$ at rate $O(\epsilon)$. Moreover, we can bound $w$ with standard concentration techniques. The proofs for equation 24 and equation 25 then follow similarly. This completes the proof.
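As a finite-width illustration of Theorem 3 (ours, under the same random-feature approximation of the Lemma 11 feature map as in the sketch above, now with an appended bias term), the following sketch fits a min-norm solution to a quadratic target and measures how linear the predictions are along a random ray outside the training range. All names and constants are our own choices.

```python
# A minimal sketch (ours) of Theorem 3 at finite width: the min-norm fit in a
# truncated ReLU NTK feature space becomes essentially linear along a ray
# outside the training support, even though the target is non-linear.
import numpy as np

rng = np.random.default_rng(1)
d, n, m = 2, 200, 5000

def aug(X):
    return np.hstack([X, np.ones((len(X), 1))])    # append bias: x_hat = [x | 1]

def features(Xh, W):
    # phi(x_hat) = (x_hat * I(w^T x_hat >= 0), w^T x_hat * I(w^T x_hat >= 0), ...)
    act = (Xh @ W.T >= 0).astype(float)
    f1 = (Xh[:, :, None] * act[:, None, :]).reshape(len(Xh), -1)
    f2 = (Xh @ W.T) * act
    return np.concatenate([f1, f2], axis=1) / np.sqrt(m)

W = rng.normal(size=(m, d + 1))
X = rng.uniform(-1, 1, size=(n, d))                # training support: [-1, 1]^2
y = np.sum(X**2, axis=1)                           # non-linear (quadratic) target

beta, *_ = np.linalg.lstsq(features(aug(X), W), y, rcond=None)  # min-norm fit

v = rng.normal(size=d); v /= np.linalg.norm(v)     # a random direction
t = np.linspace(2.0, 10.0, 50)                     # ray well outside training range
preds = features(aug(t[:, None] * v[None, :]), W) @ beta

coef = np.polyfit(t, preds, 1)                     # linear fit along the ray
r2 = 1 - (preds - np.polyval(coef, t)).var() / preds.var()
print(f"R^2 along the ray: {r2:.6f}")              # typically very close to 1
```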
B.2 Proof of Lemma 4

Overview of proof.
To prove exact extrapolation given the conditions on the training data, we analyze the function learned by the neural network in functional form. The network's learned function can be precisely characterized by the solution, in the network's neural tangent kernel feature space, that has minimum RKHS norm among functions fitting all training data; i.e., it corresponds to the optimum of a constrained optimization problem. We show that, given the conditions on the training data, the global optimum of this constrained optimization problem is precisely the underlying true function.
Setup and preparation. Let $X = \{x_1, \ldots, x_n\}$ and $Y = \{y_1, \ldots, y_n\}$ denote the training inputs and their labels. Let $\beta_g \in \mathbb{R}^d$ denote the true parameters/weights of the underlying linear function $g$, i.e., $g(x) = \beta_g^\top x$ for all $x \in \mathbb{R}^d$.

Recall from Section A that in the NTK regime, where networks are infinitely wide, randomly initialized, and trained by gradient descent with an infinitesimally small learning rate, the learning dynamics of a neural network are equivalent to those of kernel regression with respect to its neural tangent kernel. Moreover, Lemma 10 tells us that this kernel regression solution can be expressed in functional form in the neural tangent kernel's feature space. That is, the function learned by the neural network (in the NTK regime) can be precisely characterized as $f(x) = \phi(x)^\top \beta_{\mathrm{NTK}}$, where the representation coefficient $\beta_{\mathrm{NTK}}$ is
\[ \min_{\beta} \|\beta\|_2 \tag{43} \]
\[ \text{s.t.} \quad \phi(x_i)^\top \beta = y_i, \; i = 1, \ldots, n. \tag{44} \]
An infinite-dimensional feature map $\phi(x)$ for a two-layer ReLU network is described in Lemma 11,
\[ \phi(x) = c' \big( x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; w^{(k)\top} x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; \ldots \big), \]
where $w^{(k)} \sim \mathcal{N}(0, I)$ with $k$ going to infinity, $c'$ is a constant, and $\mathbb{I}$ is the indicator function. That is, there are infinitely many directions $w$ with Gaussian density, and each direction comes with two features. Without loss of generality, we assume the scaling constant to be $1$.
Constrained optimization in NTK feature space. The representation, or weight, of the network's learned function in the neural tangent kernel feature space, $\beta_{\mathrm{NTK}}$, consists of weight vectors for each feature $x \cdot \mathbb{I}(w^{(k)\top} x \geq 0) \in \mathbb{R}^d$ and $w^{(k)\top} x \cdot \mathbb{I}(w^{(k)\top} x \geq 0) \in \mathbb{R}$. For simplicity of notation, we write $w$ for a particular $w^{(k)}$, dropping the index $(k)$, which does not matter for our purposes. For any $w \in \mathbb{R}^d$, we denote by $\hat{\beta}_w = (\hat{\beta}_w^{(1)}, \ldots, \hat{\beta}_w^{(d)}) \in \mathbb{R}^d$ the weight vector corresponding to $x \cdot \mathbb{I}(w^\top x \geq 0)$, and by $\hat{\beta}'_w \in \mathbb{R}$ the weight for $w^\top x \cdot \mathbb{I}(w^\top x \geq 0)$.

Observe that for any $w \in \mathbb{R}^d$, any other vector in the same direction activates the same set of $x_i \in \mathbb{R}^d$: if $w^\top x_i \geq 0$, then $(k \cdot w)^\top x_i \geq 0$ for any $k > 0$. Hence, we can overload our notation to combine the effect of the weights of all $w$ in the same direction. This enables simpler notation and allows us to change the distribution of $w$ in the NTK features from a Gaussian to the uniform distribution on the unit sphere.

More precisely, we overload the notation $\beta_w$ and $\beta'_w$ to denote the combined effect of all weights $(\hat{\beta}_{kw}^{(1)}, \ldots, \hat{\beta}_{kw}^{(d)}) \in \mathbb{R}^d$ and $\hat{\beta}'_{kw} \in \mathbb{R}$ for all $kw$ with $k > 0$ in the same direction as $w$. That is, for each $w \sim \mathrm{Uni}(\text{unit sphere})$, we define $\beta_w^{(j)}$ as the total effect of the weights in the same direction,
\[ \beta_w^{(j)} = \int \hat{\beta}_u^{(j)} \, \mathbb{I}\Big( \frac{w^\top u}{\|w\| \cdot \|u\|} = 1 \Big) \, dP(u), \quad j \in [d], \tag{45} \]
where $u \sim \mathcal{N}(0, I)$. Note that to ensure $\beta_w$ is a well-defined number, we can work in the polar representation and integrate with respect to the angle; then $\beta_w$ is well-defined. For simplicity of exposition, we use plain integral notation. Similarly, we define $\beta'_w$ by overloading notation:
\[ \beta'_w = \int \hat{\beta}'_u \, \mathbb{I}\Big( \frac{w^\top u}{\|w\| \cdot \|u\|} = 1 \Big) \cdot \frac{\|u\|}{\|w\|} \, dP(u) \tag{46} \]
Here, equation 46 has an extra factor $\frac{\|u\|}{\|w\|}$ compared to equation 45, because the NTK feature that equation 46 corresponds to, $w^\top x \cdot \mathbb{I}(w^\top x \geq 0)$, has an extra $w^\top$ term, so we must account for the scaling. This abstraction lets us make claims about the high-level parameters $\beta_w$ and $\beta'_w$ only, which we will show to be sufficient to determine the learned function.

We can then formulate the constrained optimization problem whose solution gives the functional form of the network's learned function. We rewrite the min-norm solution in equation 43 as
\[ \min_{\beta} \int \big(\beta_w^{(1)}\big)^2 + \big(\beta_w^{(2)}\big)^2 + \ldots + \big(\beta_w^{(d)}\big)^2 + \big(\beta'_w\big)^2 \, dP(w) \tag{47} \]
\[ \text{s.t.} \quad \int_{w^\top x_i \geq 0} \beta_w^\top x_i + \beta'_w \cdot w^\top x_i \, dP(w) = \beta_g^\top x_i \quad \forall i \in [n], \tag{48} \]
where the density of $w$ is now uniform on the unit sphere of $\mathbb{R}^d$. Observe that since $w$ follows a uniform distribution, the probability density $P(w)$ is a constant. This means every $x_i$ is activated by half of the $w$ on the unit sphere, which implies we can write the right-hand side of equation 48 in the same integral form as the left-hand side. This allows us to further simplify equation 48 as
\[ \int_{w^\top x_i \geq 0} \big( \beta_w^\top + \beta'_w \cdot w^\top - 2 \cdot \beta_g^\top \big) x_i \, dP(w) = 0 \quad \forall i \in [n], \tag{49} \]
where equation 49 follows from the following steps of simplification:
\begin{align*} &\int_{w^\top x_i \geq 0} \beta_w^{(1)} x_i^{(1)} + \ldots + \beta_w^{(d)} x_i^{(d)} + \beta'_w \cdot w^\top x_i \, dP(w) = \beta_g^{(1)} x_i^{(1)} + \ldots + \beta_g^{(d)} x_i^{(d)} \quad \forall i \in [n] \\ \Longleftrightarrow\; &\int_{w^\top x_i \geq 0} \beta_w^{(1)} x_i^{(1)} + \ldots + \beta_w^{(d)} x_i^{(d)} + \beta'_w \cdot w^\top x_i \, dP(w) \\ &\qquad = \frac{1}{\int_{w^\top x_i \geq 0} dP(w)} \cdot \int_{w^\top x_i \geq 0} dP(w) \cdot \big( \beta_g^{(1)} x_i^{(1)} + \ldots + \beta_g^{(d)} x_i^{(d)} \big) \quad \forall i \in [n] \\ \Longleftrightarrow\; &\int_{w^\top x_i \geq 0} \beta_w^{(1)} x_i^{(1)} + \ldots + \beta_w^{(d)} x_i^{(d)} + \beta'_w \cdot w^\top x_i \, dP(w) = 2 \cdot \int_{w^\top x_i \geq 0} \beta_g^{(1)} x_i^{(1)} + \ldots + \beta_g^{(d)} x_i^{(d)} \, dP(w) \quad \forall i \in [n] \\ \Longleftrightarrow\; &\int_{w^\top x_i \geq 0} \big( \beta_w^\top + \beta'_w \cdot w^\top - 2 \cdot \beta_g^\top \big) x_i \, dP(w) = 0 \quad \forall i \in [n]. \end{align*}
Claim 12. Without loss of generality, assume the scaling factor $c$ in the NTK feature map $\phi(x)$ is $1$. Then the global optimum of the constrained optimization problem equation 47 subject to equation 49, i.e.,
\[ \min_{\beta} \int \big(\beta_w^{(1)}\big)^2 + \big(\beta_w^{(2)}\big)^2 + \ldots + \big(\beta_w^{(d)}\big)^2 + \big(\beta'_w\big)^2 \, dP(w) \tag{50} \]
\[ \text{s.t.} \quad \int_{w^\top x_i \geq 0} \big( \beta_w^\top + \beta'_w \cdot w^\top - 2 \cdot \beta_g^\top \big) x_i \, dP(w) = 0 \quad \forall i \in [n], \tag{51} \]
satisfies $\beta_w + \beta'_w \cdot w = 2 \beta_g$ for all $w$.

This claim implies exact extrapolation, i.e., $f_{\mathrm{NTK}}(x) = g(x)$: if the claim holds, then for any $x \in \mathbb{R}^d$,
\begin{align*} f_{\mathrm{NTK}}(x) &= \int_{w^\top x \geq 0} \beta_w^\top x + \beta'_w \cdot w^\top x \, dP(w) \\ &= \int_{w^\top x \geq 0} 2 \cdot \beta_g^\top x \, dP(w) \\ &= 2 \int_{w^\top x \geq 0} dP(w) \cdot \beta_g^\top x = 2 \cdot \frac{1}{2} \cdot \beta_g^\top x = g(x). \end{align*}
Thus, it remains to prove Claim 12. To compute the optimum of the constrained optimization problem equation 50, we consider Lagrange multipliers. It is clear that the objective equation 50 is convex. Moreover, the constraint equation 51 is affine. Hence, by the KKT conditions, a solution that satisfies the Lagrange condition is the global optimum. The Lagrangian is
\[ \mathcal{L}(\beta, \lambda) = \int \big(\beta_w^{(1)}\big)^2 + \big(\beta_w^{(2)}\big)^2 + \ldots + \big(\beta_w^{(d)}\big)^2 + \big(\beta'_w\big)^2 \, dP(w) \tag{52} \]
\[ \qquad - \sum_{i=1}^n \lambda_i \cdot \int_{w^\top x_i \geq 0} \big( \beta_w^\top + \beta'_w \cdot w^\top - 2 \cdot \beta_g^\top \big) x_i \, dP(w) \tag{53} \]
Setting the partial derivative of $\mathcal{L}(\beta, \lambda)$ with respect to each variable to zero gives
\[ \frac{\partial \mathcal{L}}{\partial \beta_w^{(k)}} = 2 \beta_w^{(k)} P(w) + \sum_{i=1}^n \lambda_i \cdot x_i^{(k)} \cdot \mathbb{I}(w^\top x_i \geq 0) = 0 \tag{54} \]
\[ \frac{\partial \mathcal{L}}{\partial \beta'_w} = 2 \beta'_w P(w) + \sum_{i=1}^n \lambda_i \cdot w^\top x_i \cdot \mathbb{I}(w^\top x_i \geq 0) = 0 \tag{55} \]
\[ \frac{\partial \mathcal{L}}{\partial \lambda_i} = \int_{w^\top x_i \geq 0} \big( \beta_w^\top + \beta'_w \cdot w^\top - 2 \cdot \beta_g^\top \big) x_i \, dP(w) = 0 \tag{56} \]
It is clear that the solution in Claim 12 immediately satisfies equation 56. Hence, it remains to show there exists a set of $\lambda_i$, $i \in [n]$, satisfying equation 54 and equation 55. We can simplify equation 54 as
\[ \beta_w^{(k)} = c \cdot \sum_{i=1}^n \lambda_i \cdot x_i^{(k)} \cdot \mathbb{I}(w^\top x_i \geq 0), \tag{57} \]
where $c$ is a constant. Similarly, we can simplify equation 55 as
\[ \beta'_w = c \cdot \sum_{i=1}^n \lambda_i \cdot w^\top x_i \cdot \mathbb{I}(w^\top x_i \geq 0) \tag{58} \]
Observe that combining equation 57 and equation 58 implies that the constraint equation 58 can be further simplified as
\[ \beta'_w = w^\top \beta_w \tag{59} \]
It remains to show that, given the condition on the training data, there exists a set of $\lambda_i$ such that equation 57 and equation 59 are satisfied.

Global optimum via the geometry of the training data. Recall that we assume our training data $\{(x_i, y_i)\}_{i=1}^n$ satisfy: for any $w \in \mathbb{R}^d$, there exist $d$ linearly independent $\{x_i^w\}_{i=1}^d \subset X$, where $X = \{x_i\}_{i=1}^n$, such that $w^\top x_i^w \geq 0$ and $-x_i^w \in X$ for $i = 1..d$, e.g., an orthogonal basis of $\mathbb{R}^d$ and their opposite vectors. We will show that under this data regime:

(a) For any particular $w$, there indeed exists a set of $\lambda_i$ that satisfies the constraints equation 57 and equation 59 for this particular $w$.
(b) For any $w_1$ and $w_2$ that activate exactly the same set of $\{x_i\}$, the same set of $\lambda_i$ satisfies the constraints equation 57 and equation 59 of both $w_1$ and $w_2$.
(c) Whenever we rotate a $w_1$ to a $w_2$ such that the set of activated $x_i$ changes, we can still find $\lambda_i$ that satisfy the constraints of both $w_1$ and $w_2$.

Combining (a), (b), and (c) implies that there exists a set of $\lambda$ satisfying the constraints for all $w$. Hence, it remains to show these three claims.

We first prove Claim (a). For each $w$, we must find a set of $\lambda_i$ such that the following hold:
\begin{align*} \beta_w^{(k)} &= c \cdot \sum_{i=1}^n \lambda_i \cdot x_i^{(k)} \cdot \mathbb{I}(w^\top x_i \geq 0), \\ \beta'_w &= w^\top \beta_w, \\ \beta_w + \beta'_w \cdot w &= 2 \beta_g. \end{align*}
Here, $\beta_g$ is fixed, and $w$ is a vector on the unit sphere. It is easy to see that $\beta_w$ is then determined by $\beta_g$ and $w$, and a solution indeed exists (we solve a consistent linear system). Hence, we are left with a linear system of $d$ equations,
\[ \beta_w^{(k)} = c \cdot \sum_{i=1}^n \lambda_i \cdot x_i^{(k)} \cdot \mathbb{I}(w^\top x_i \geq 0) \quad \forall k \in [d], \]
whose free variables are the $\lambda_i$ for which $w$ activates $x_i$, i.e., $w^\top x_i \geq 0$. Because the training data $\{(x_i, y_i)\}_{i=1}^n$ satisfy that for any $w$ there exist at least $d$ linearly independent $x_i$ that activate $w$, for any $w$ we have at least $d$ free variables. It follows that solutions $\lambda_i$ to the linear system must exist. This proves Claim (a).

Next, we show (b): for any $w_1$ and $w_2$ that activate exactly the same set of $\{x_i\}$, the same set of $\lambda_i$ satisfies the constraints equation 57 and equation 59 of both $w_1$ and $w_2$. Because $w_1$ and $w_2$ are activated by the same set of $x_i$,
\[ \beta_{w_1} = c \cdot \sum_{i=1}^n \lambda_i \cdot x_i \cdot \mathbb{I}(w_1^\top x_i \geq 0) = c \cdot \sum_{i=1}^n \lambda_i \cdot x_i \cdot \mathbb{I}(w_2^\top x_i \geq 0) = \beta_{w_2}. \]
Since the $\lambda_i$ already satisfy constraint equation 57 for $w_1$, they also satisfy it for $w_2$. Thus, it remains to show that $\beta_{w_1} + \beta'_{w_1} \cdot w_1 = \beta_{w_2} + \beta'_{w_2} \cdot w_2$, assuming $\beta_{w_1} = \beta_{w_2}$, $\beta'_{w_1} = w_1^\top \beta_{w_1}$, and $\beta'_{w_2} = w_2^\top \beta_{w_2}$. This indeed holds because
\begin{align*} \beta_{w_1} + \beta'_{w_1} \cdot w_1 = \beta_{w_2} + \beta'_{w_2} \cdot w_2 &\Longleftrightarrow \beta'_{w_1} \cdot w_1^\top = \beta'_{w_2} \cdot w_2^\top \\ &\Longleftrightarrow w_1^\top \beta_{w_1} w_1^\top = w_2^\top \beta_{w_2} w_2^\top \\ &\Longleftrightarrow w_1^\top w_1 \beta_{w_1}^\top = w_2^\top w_2 \beta_{w_2}^\top \\ &\Longleftrightarrow 1 \cdot \beta_{w_1}^\top = 1 \cdot \beta_{w_2}^\top \\ &\Longleftrightarrow \beta_{w_1} = \beta_{w_2}, \end{align*}
where we used the fact that $w_1$ and $w_2$ are vectors on the unit sphere. This proves Claim (b).

Finally, we show (c): whenever we rotate a $w_1$ to a $w_2$ such that the set of activated $x_i$ changes, we can still find $\lambda_i$ that satisfy the constraints of both $w_1$ and $w_2$. Suppose we rotate $w_1$ to $w_2$ so that $w_2$ loses activation with $x_1, x_2, \ldots, x_p$, which are among the linearly independent $x_i$ activated by $w_1$ whose opposite vectors $-x_i$ are also in the training set (without loss of generality). Then $w_2$ must now be activated by $-x_1, -x_2, \ldots, -x_p$: if $w_2^\top x_i < 0$, we must have $w_2^\top (-x_i) > 0$.

Recall that in the proof of Claim (a), we only needed the $\lambda_i$ of the linearly independent $x_i$ (and their opposites) used as the free variables to solve the linear system of $d$ equations; we can set $\lambda$ to $0$ for the other $x_i$ while still satisfying the linear system. Then, suppose there exist $\lambda_i$ satisfying
\[ \beta_{w_1}^{(k)} = c \cdot \sum_{i=1}^d \lambda_i \cdot x_i^{(k)}, \]
where the $x_i$ are the linearly independent vectors that activate $w_1$ and whose opposite vectors are in the training set, as we proved in (a). Then we can satisfy the constraint for $\beta_{w_2}$,
\[ \beta_{w_2}^{(k)} = c \cdot \Big( \sum_{i=1}^p \hat{\lambda}_i \cdot (-x_i)^{(k)} + \sum_{i=p+1}^d \lambda_i \cdot x_i^{(k)} \Big), \]
by setting $\hat{\lambda}_i = -\lambda_i$ for $i = 1 \ldots p$. Indeed, this gives
\[ \beta_{w_2}^{(k)} = c \cdot \Big( \sum_{i=1}^p (-\lambda_i) \cdot (-x_i)^{(k)} + \sum_{i=p+1}^d \lambda_i \cdot x_i^{(k)} \Big) = c \cdot \sum_{i=1}^d \lambda_i \cdot x_i^{(k)}. \]
Thus, we can also find $\lambda_i$ that satisfy the constraint for $\beta_{w_2}$. Here, we do not consider the case where $w$ is parallel to some $x_i$, because such $w$ have measure zero. Note that we can apply this argument iteratively, because flipping the sign always works and does not create any inconsistency.

Moreover, the constraint for $\beta'_{w_2}$ is satisfied by an argument similar to that in the proof of Claim (b). This follows from the fact that our construction makes $\beta_{w_2} = \beta_{w_1}$; we can then follow the same argument as in (b) to show $\beta_{w_1} + \beta'_{w_1} \cdot w_1 = \beta_{w_2} + \beta'_{w_2} \cdot w_2$. This completes the proof of Claim (c).

In summary, combining Claims (a), (b), and (c) shows that Claim 12 holds. That is, given our training data, the global optimum of the constrained optimization problem, i.e., the min-norm solution among functions that fit the training data, satisfies $\beta_w + \beta'_w \cdot w = 2 \beta_g$. We also showed that this claim implies exact extrapolation: the network's learned function $f(x)$ equals the true underlying function $g(x)$ for all $x \in \mathbb{R}^d$. This completes the proof.
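The following sketch (ours, reusing the random-feature construction from the earlier sketches) numerically illustrates Lemma 4: when the training set contains an orthogonal basis and its opposite vectors, so that all directions are covered, the min-norm fit in the ReLU NTK feature space recovers a linear target far outside the training range, up to Monte Carlo error that shrinks as the number of sampled directions grows.

```python
# A minimal finite-width sketch (ours) of Lemma 4: with training data that
# contains an orthogonal basis and its opposite vectors, the min-norm fit in
# the ReLU NTK feature space recovers a linear target far outside the data.
import numpy as np

rng = np.random.default_rng(2)
d, m = 3, 20000

def features(X, W):
    act = (X @ W.T >= 0).astype(float)
    f1 = (X[:, :, None] * act[:, None, :]).reshape(len(X), -1)
    f2 = (X @ W.T) * act
    return np.concatenate([f1, f2], axis=1) / np.sqrt(m)

W = rng.normal(size=(m, d))
beta_g = np.array([2.0, -1.0, 0.5])               # true linear g(x) = beta_g^T x

# Training set: orthogonal basis, its opposites, plus a few random points.
X = np.concatenate([np.eye(d), -np.eye(d), rng.uniform(-1, 1, size=(20, d))])
y = X @ beta_g

beta, *_ = np.linalg.lstsq(features(X, W), y, rcond=None)   # min-norm solution

X_far = 50.0 * rng.normal(size=(5, d))            # far outside the training range
err = features(X_far, W) @ beta - X_far @ beta_g
print(np.abs(err / (X_far @ beta_g)))             # relative errors: small, and
                                                  # shrinking as m grows
```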
B.3 Proof of Theorem 5

The proof of asymptotic convergence to extrapolation builds on our proof of exact extrapolation, i.e., Lemma 4. The idea is that if the training data distribution has support in all directions, then as the number of samples $n \to \infty$, the training set converges asymptotically to an (imaginary) training set that satisfies the condition for exact extrapolation. Since the neural tangent kernels of nearby training data are also close, the predictions, i.e., the learned function, converge to a function that achieves perfect extrapolation, that is, the true underlying function.

Asymptotic convergence of data sets. We first show that the training data converge to a data set satisfying the exact extrapolation condition in Lemma 4. Suppose the training data $\{x_i\}_{i=1}^n$ are sampled from a distribution whose support contains a connected set $S^*$ that intersects all directions, i.e., for any non-zero $w \in \mathbb{R}^d$, there exists $k > 0$ such that $k w \in S^*$.

Let $\mathcal{S}$ denote the set of datasets that satisfy the exact extrapolation condition in Lemma 4. Given a general dataset $X$ and a dataset $S \in \mathcal{S}$ of the same size $n$, let $\sigma(X, S)$ denote a matching of their data points, i.e., $\sigma$ outputs a sequence of pairs
\[ \sigma(X, S)_i = (x_i, s_i) \text{ for } i \in [n], \quad \text{s.t. } X = \{x_i\}_{i=1}^n, \; S = \{s_i\}_{i=1}^n. \]
Let $\ell: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$ be the $\ell_2$ distance on a pair of points. We then define the distance between datasets, $d(X, S)$, as the minimum sum of $\ell_2$ distances of their data points over all possible matchings:
\[ d(X, S) = \begin{cases} \min_\sigma \sum_{i=1}^n \ell(\sigma(X, S)_i) & |X| = |S| = n, \\ \infty & |X| \neq |S|. \end{cases} \]
We can then define a "closest distance to a perfect dataset" function $D^*$, which maps a dataset $X$ to the minimum distance of $X$ to any dataset in $\mathcal{S}$:
\[ D^*(X) = \min_{S \in \mathcal{S}} d(X, S). \]
It is easy to see that for any dataset $X = \{x_i\}_{i=1}^n$, $D^*(X)$ can be bounded by the minimum of $D^*$ over sub-datasets of $X$ of size $2d$:
\[ D^*(\{x_i\}_{i=1}^n) \leq \min_{k=1}^{\lfloor n/2d \rfloor} D^*\big( \{x_j\}_{j=(k-1) \cdot 2d + 1}^{k \cdot 2d} \big) \tag{60} \]
This is because for any $S_0 \in \mathcal{S}$ and any $S_0 \subseteq S'$, we must have $S' \in \mathcal{S}$: a dataset satisfies the exact extrapolation condition as long as it contains certain key points, so adding more data does not hurt. That is, for any $X_1 \subseteq X_2$, we always have $D^*(X_2) \leq D^*(X_1)$.

Now let $X_n$ denote a random dataset of size $n$ where each $x_i \in X_n$ is sampled from the training distribution. Recall that our training data $\{x_i\}_{i=1}^n$ are sampled from a distribution whose support contains a connected set $S^*$ that intersects all directions. It follows that for a random dataset $X_{2d}$ of size $2d$, the probability that $D^*(X_{2d}) > \epsilon$ is less than $1$ for any $\epsilon > 0$.

Indeed, there must exist $S = \{s_i\}_{i=1}^{2d} \in \mathcal{S}$ of size $2d$, e.g., an orthogonal basis and their opposite vectors. Observe that if we scale any $s_i$ by $k > 0$, the resulting dataset is still in $\mathcal{S}$ by the definition of $\mathcal{S}$. We denote the set of datasets obtained by scaling elements of $S$ by $\mathcal{S}_1$. It follows that
\begin{align*} P(D^*(X_{2d}) > \epsilon) &= P\Big( \min_{S \in \mathcal{S}} d(X_{2d}, S) > \epsilon \Big) \\ &\leq P\Big( \min_{S \in \mathcal{S}_1} d(X_{2d}, S) > \epsilon \Big) \\ &= P\Big( \min_{S \in \mathcal{S}_1} \min_\sigma \sum_{i=1}^{2d} \ell(\sigma(X_{2d}, S)_i) > \epsilon \Big) \\ &= 1 - P\Big( \min_{S \in \mathcal{S}_1} \min_\sigma \sum_{i=1}^{2d} \ell(\sigma(X_{2d}, S)_i) \leq \epsilon \Big) \\ &\leq 1 - P\Big( \min_{S \in \mathcal{S}_1} \min_\sigma \max_{i=1}^{2d} \ell(\sigma(X_{2d}, S)_i) \leq \frac{\epsilon}{2d} \Big) \\ &\leq \delta < 1, \end{align*}
where we denote the resulting bound on $P(D^*(X_{2d}) > \epsilon)$ by $\delta < 1$, and the last step follows from
\[ P\Big( \min_{S \in \mathcal{S}_1} \min_\sigma \max_{i=1}^{2d} \ell(\sigma(X_{2d}, S)_i) \leq \frac{\epsilon}{2d} \Big) > 0, \]
which in turn follows from the fact that for any $s_i \in S$, by the assumption on the training distribution, we can always find $k > 0$ such that $k s_i \in S^*$, a connected set in the support of the training distribution. By the connectivity of the support $S^*$, $k s_i$ cannot be an isolated point in $S^*$, so for any $\epsilon > 0$ we must have
\[ \int_{\|x - k s_i\| \leq \epsilon, \; x \in S^*} f_X(x) \, dx > 0. \]
Hence, we can now apply equation 60 to bound $D^*(X_n)$. Given any $\epsilon > 0$, we have
\begin{align*} P(D^*(X_n) > \epsilon) &= 1 - P(D^*(X_n) \leq \epsilon) \\ &\leq 1 - P\Big( \min_{k=1}^{\lfloor n/2d \rfloor} D^*\big( \{x_j\}_{j=(k-1) \cdot 2d + 1}^{k \cdot 2d} \big) \leq \epsilon \Big) \\ &= 1 - \Big( 1 - \prod_{k=1}^{\lfloor n/2d \rfloor} P\Big( D^*\big( \{x_j\}_{j=(k-1) \cdot 2d + 1}^{k \cdot 2d} \big) > \epsilon \Big) \Big) \\ &= \prod_{k=1}^{\lfloor n/2d \rfloor} P\Big( D^*\big( \{x_j\}_{j=(k-1) \cdot 2d + 1}^{k \cdot 2d} \big) > \epsilon \Big) \\ &\leq \delta^{\lfloor n/2d \rfloor}, \end{align*}
where $\delta < 1$. This implies $D^*(X_n) \xrightarrow{p} 0$, i.e.,
\[ \lim_{n \to \infty} P(D^*(X_n) > \epsilon) = 0 \quad \forall \epsilon > 0 \tag{61} \]
Equation 61 says that as the number of training samples $n \to \infty$, our training set converges in probability to a dataset that satisfies the requirement for exact extrapolation.

Asymptotic convergence of predictions. Let $\mathrm{NTK}(x, x'): \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$ denote the neural tangent kernel of a two-layer ReLU MLP. It is easy to see that if $x \to x^*$, then $\mathrm{NTK}(x, \cdot) \to \mathrm{NTK}(x^*, \cdot)$ (Arora et al. [2019b]). Let $\mathrm{NTK}_{\mathrm{train}}$ denote the $n \times n$ kernel matrix for the training data.

We have shown that our training set converges to a perfect dataset satisfying the conditions of exact extrapolation. Moreover, our training set differs from a perfect dataset in only a finite number (not increasing with $n$) of $x_i$: a perfect dataset need only contain a finite number of key points, and the other points can be replaced by any other points while the whole still remains a perfect dataset. Thus, we have $\mathrm{NTK}_{\mathrm{train}} \to N^*$, where $N^*$ is the $n \times n$ NTK matrix of some perfect dataset. Because the neural tangent kernel is positive definite, we have $\mathrm{NTK}_{\mathrm{train}}^{-1} \to N^{*-1}$. Recall that for any $x \in \mathbb{R}^d$, the NTK prediction is
\[ f_{\mathrm{NTK}}(x) = \big( \mathrm{NTK}(x, x_1), \ldots, \mathrm{NTK}(x, x_n) \big) \cdot \mathrm{NTK}_{\mathrm{train}}^{-1} Y, \]
where $\mathrm{NTK}_{\mathrm{train}}$ is the $n \times n$ kernel matrix for the training data, $\mathrm{NTK}(x, x_i)$ is the kernel value between test point $x$ and training point $x_i$, and $Y$ is the vector of training labels. Similarly, we have $(\mathrm{NTK}(x, x_1), \ldots, \mathrm{NTK}(x, x_n)) \to (\mathrm{NTK}(x, x_1^*), \ldots, \mathrm{NTK}(x, x_n^*))$, where $\{x_i^*\}$ is a perfect dataset that our training set converges to. Combining this with $\mathrm{NTK}_{\mathrm{train}}^{-1} \to N^{*-1}$ gives
\[ f_{\mathrm{NTK}} \xrightarrow{p} f^*_{\mathrm{NTK}} = g, \]
where $f_{\mathrm{NTK}}$ is the function learned using our training set, and $f^*_{\mathrm{NTK}}$ is that learned using a perfect dataset, which is equal to the true underlying function $g$. This completes the proof.

B.4 Proof of Corollary 8
In order for the GNN with linear aggregations,
\[ h_u^{(k)} = \sum_{v \in \mathcal{N}(u)} \mathrm{MLP}^{(k)}\big( h_u^{(k-1)}, h_v^{(k-1)}, x_{(u,v)} \big), \qquad h_G = \mathrm{MLP}^{(K+1)}\Big( \sum_{u \in G} h_u^{(K)} \Big), \]
to extrapolate in the maximum degree task, it must be able to simulate the underlying function
\[ h_G = \max_{u \in G} \sum_{v \in \mathcal{N}(u)} 1. \]
Because the max function cannot be decomposed as a composition of piecewise linear functions, the $\mathrm{MLP}^{(K+1)}$ module in the GNN must learn a function that is not piecewise linear over domains outside the training data range. Since Theorem 3 is proved for two-layer over-parameterized MLPs, here we also assume $\mathrm{MLP}^{(K+1)}$ is a two-layer over-parameterized MLP, although the result can be extended to more layers. It then follows from Theorem 3 that, for any input and label (and thus gradient), $\mathrm{MLP}^{(K+1)}$ converges to linear functions along directions from the origin. Hence, there are always domains where the GNN cannot learn the correct target function.

B.5 Proof of Theorem 9
Our proof applies proof techniques similar to those of Lemmas 4 and 5 to Graph Neural Networks (GNNs). It is essentially an analysis of the Graph Neural Tangent Kernel (GNTK), i.e., the neural tangent kernel of GNNs.

We first define the simple GNN architecture we will analyze, and then present the GNTK for this architecture. Suppose $G = (V, E)$ is an input graph without edge features, and $x_u \in \mathbb{R}^d$ is the node feature of any node $u \in V$. Let us consider the simple one-layer GNN whose input is $G$ and whose output is $h_G$:
\[ h_G = W^{(2)} \max_{u \in G} \sum_{v \in \mathcal{N}(u)} W^{(1)} x_v \tag{62} \]
Note that our analysis can be extended to other GNN variants, e.g., with non-empty edge features, ReLU activation, and different neighbor aggregation and graph-level pooling architectures; we analyze this GNN for simplicity of exposition.

Next, let us calculate the feature map of the neural tangent kernel for this GNN. Recall from Section A that for a graph neural network $f(\theta, G): \mathcal{G} \to \mathbb{R}$, where $\theta \in \mathbb{R}^m$ are the parameters of the network and $G \in \mathcal{G}$ is the input graph, the neural tangent kernel is
\[ H_{ij} = \Big\langle \frac{\partial f(\theta, G_i)}{\partial \theta}, \frac{\partial f(\theta, G_j)}{\partial \theta} \Big\rangle, \]
where $\theta$ are the infinite-dimensional parameters. Hence, the gradients with respect to all parameters give a natural feature map. Let us denote, for any node $u$,
\[ h_u = \sum_{v \in \mathcal{N}(u)} x_v \tag{63} \]
(with uniform node features $x_v = 1$, $h_u$ is the degree of $u$). It then follows from a simple computation of derivatives that the following is a feature map of the GNTK for equation 62:
\[ \phi(G) = c \cdot \Big( \max_{u \in G} \big( w^{(k)\top} h_u \big), \; \sum_{u \in G} \mathbb{I}\Big( u = \arg\max_{v \in G} w^{(k)\top} h_v \Big) \cdot h_u, \; \ldots \Big), \tag{64} \]
where $w^{(k)} \sim \mathcal{N}(0, I)$ with $k$ going to infinity, $c$ is a constant, and $\mathbb{I}$ is the indicator function.

Next, given training data $\{(G_i, y_i)\}_{i=1}^n$, let us analyze the function learned by the GNN through the min-norm solution in the GNTK feature space; the same proof technique is used in Lemmas 4 and 5.

Recall the assumption that all graphs have uniform node features, i.e., the learning task considers only the graph structure, not the node features. We assume $x_v = 1$ without loss of generality. Observe that in this case there are only two directions, positive and negative, for the one-dimensional Gaussian distribution. Hence, we can simplify our analysis by combining the effect of the linear coefficients for $w$ in the same direction, as in Lemmas 4 and 5.

Similarly, for any $w$, let us define $\hat{\beta}_w \in \mathbb{R}$ as the linear coefficient corresponding to $\sum_{u \in G} \mathbb{I}( u = \arg\max_{v \in G} w^\top h_v ) \cdot h_u$ in the RKHS space, and denote by $\hat{\beta}'_w \in \mathbb{R}$ the weight for $\max_{u \in G} ( w^\top h_u )$. As in Lemmas 4 and 5, we combine the effect of all $\hat{\beta}$ in the same direction, denoting the combined effects by $\beta_w$ and $\beta'_w$. This allows us to reason about $w$ with two directions, $+$ and $-$.

Recall that the underlying reasoning function, maximum degree, is $g(G) = \max_{u \in G} h_u$. We formulate the constrained optimization problem, i.e., the min-norm solution in the GNTK feature space that fits all training data, as
\[ \min_{\hat{\beta}, \hat{\beta}'} \int \hat{\beta}_w^2 + \hat{\beta}_w'^2 \, dP(w) \]
\[ \text{s.t.} \quad \int \sum_{u \in G_i} \mathbb{I}\Big( u = \arg\max_{v \in G_i} w \cdot h_v \Big) \cdot \hat{\beta}_w \cdot h_u + \max_{u \in G_i} ( w \cdot h_u ) \cdot \hat{\beta}'_w \, dP(w) = \max_{u \in G_i} h_u \quad \forall i \in [n], \]
where $G_i$ is the $i$-th training graph and $w \sim \mathcal{N}(0, 1)$. By combining the effect of the $\hat{\beta}$, taking the derivative of the Lagrangian of the constrained optimization problem, and setting it to zero, we get that the global optimum satisfies the following constraints:
\[ \beta_+ = c \cdot \sum_{i=1}^n \lambda_i \cdot \sum_{u \in G_i} h_u \cdot \mathbb{I}\Big( u = \arg\max_{v \in G_i} h_v \Big) \tag{65} \]
\[ \beta_- = c \cdot \sum_{i=1}^n \lambda_i \cdot \sum_{u \in G_i} h_u \cdot \mathbb{I}\Big( u = \arg\min_{v \in G_i} h_v \Big) \tag{66} \]
\[ \beta'_+ = c \cdot \sum_{i=1}^n \lambda_i \cdot \max_{u \in G_i} h_u \tag{67} \]
\[ \beta'_- = c \cdot \sum_{i=1}^n \lambda_i \cdot \min_{u \in G_i} h_u \tag{68} \]
\[ \max_{u \in G_i} h_u = \beta_+ \cdot \sum_{u \in G_i} \mathbb{I}\Big( u = \arg\max_{v \in G_i} h_v \Big) \cdot h_u + \beta'_+ \cdot \max_{u \in G_i} h_u \tag{69} \]
\[ \qquad\qquad + \beta_- \cdot \sum_{u \in G_i} \mathbb{I}\Big( u = \arg\min_{v \in G_i} h_v \Big) \cdot h_u + \beta'_- \cdot \min_{u \in G_i} h_u \quad \forall i \in [n], \tag{70} \]
where $c$ is some constant and the $\lambda_i$ are the Lagrange multipliers. Note that here we used the fact that there are only two directions, $+1$ and $-1$, which enables the simplification of the Lagrangian derivative. For a similar step-by-step derivation of the Lagrangian, refer to the proof of Lemma 4.

Let us consider the solution $\beta'_+ = 1$ and $\beta_+ = \beta_- = \beta'_- = 0$. It is clear that this solution fits the training data, and thus satisfies equation 69. Moreover, this solution is equivalent to the underlying reasoning function, maximum degree, $g(G) = \max_{u \in G} h_u$.

Hence, it remains to show that, given our training data, there exist $\lambda_i$ such that the remaining four constraints are satisfied for this solution. Let us rewrite these constraints as a linear system in the variables $\lambda_i$:
\[ \begin{pmatrix} \beta_+ \\ \beta_- \\ \beta'_+ \\ \beta'_- \end{pmatrix} = c \cdot \sum_{i=1}^n \lambda_i \cdot \begin{pmatrix} \sum_{u \in G_i} h_u \cdot \mathbb{I}\big( u = \arg\max_{v \in G_i} h_v \big) \\ \sum_{u \in G_i} h_u \cdot \mathbb{I}\big( u = \arg\min_{v \in G_i} h_v \big) \\ \max_{u \in G_i} h_u \\ \min_{u \in G_i} h_u \end{pmatrix} \tag{71} \]
By the standard theory of linear systems, there exist $\lambda_i$ solving equation 71 if there are at least four training graphs $G_i$ whose following vectors are linearly independent:
\[ \begin{pmatrix} \sum_{u \in G_i} h_u \cdot \mathbb{I}\big( u = \arg\max_{v \in G_i} h_v \big) \\ \sum_{u \in G_i} h_u \cdot \mathbb{I}\big( u = \arg\min_{v \in G_i} h_v \big) \\ \max_{u \in G_i} h_u \\ \min_{u \in G_i} h_u \end{pmatrix} = \begin{pmatrix} \max_{u \in G_i} h_u \cdot N_i^{\max} \\ \min_{u \in G_i} h_u \cdot N_i^{\min} \\ \max_{u \in G_i} h_u \\ \min_{u \in G_i} h_u \end{pmatrix} \tag{72} \]
Here, $N_i^{\max}$ denotes the number of nodes that achieve the maximum degree in the graph $G_i$, and $N_i^{\min}$ denotes the number of nodes that achieve the minimum degree in the graph $G_i$. By the assumption on our training data, there are at least four $G_i \sim \mathbb{G}$ with linearly independent vectors in equation 72. Hence, our simple GNN learns the underlying function, as desired. This completes the proof.
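A small numpy sketch (ours, not the paper's code) of the one-layer GNN in equation 62 with uniform node features makes the architectural point concrete: with scalar weights satisfying $W^{(2)} W^{(1)} = 1$, which matches the optimum $\beta'_+ = 1$ above, the graph-level max-pooling readout outputs the maximum degree exactly, for graphs of any size. The function name and the example graphs are ours.

```python
# A sketch (ours) of the one-layer GNN in equation 62 with x_v = 1. With
# W2 * W1 = 1, the max-pooling readout returns the maximum degree exactly:
# the max non-linearity is encoded in the architecture, not learned.
import numpy as np

def max_degree_gnn(adj, W1=0.5, W2=2.0):
    """h_G = W2 * max_u sum_{v in N(u)} W1 * x_v, with x_v = 1."""
    x = np.ones(len(adj))                  # uniform node features
    h = adj @ (W1 * x)                     # neighbor sum-pooling: W1 * degree(u)
    return W2 * np.max(h)                  # graph-level max-pooling readout

# A 5-node graph: a star (center degree 4) plus one extra edge.
adj = np.zeros((5, 5))
for u, v in [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2)]:
    adj[u, v] = adj[v, u] = 1
print(max_degree_gnn(adj))                 # 4.0, the max degree

# The same weights extrapolate to a much larger graph (a 100-node cycle).
n = 100
cyc = np.zeros((n, n))
for u in range(n):
    cyc[u, (u + 1) % n] = cyc[(u + 1) % n, u] = 1
print(max_degree_gnn(cyc))                 # 2.0, the max degree
```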
B.6 Proof of Lemma 10

Let $\mathcal{W}$ denote the span of the feature maps of the training data $x_i$, i.e.,
\[ \mathcal{W} = \mathrm{span}\big( \phi(x_1), \phi(x_2), \ldots, \phi(x_n) \big). \]
Then we can decompose the coordinates of $f_{\mathrm{NTK}}$ in the RKHS space, $\beta_{\mathrm{NTK}}$, into a vector $\beta_1$ for the component of $f_{\mathrm{NTK}}$ in the span of the training data features $\mathcal{W}$, and a vector $\beta_2$ for the component in the orthogonal complement $\mathcal{W}^\perp$, i.e.,
\[ \beta_{\mathrm{NTK}} = \beta_1 + \beta_2. \]
First, note that $f_{\mathrm{NTK}}$ must fit the training data (the NTK is a universal kernel, as we discuss next), i.e., $\phi(x_i)^\top \beta_{\mathrm{NTK}} = y_i$. Thus, we have $\phi(x_i)^\top \beta_1 = y_i$. Then $\beta_1$ is uniquely determined by the kernel regression solution with respect to the neural tangent kernel,
\[ f_{\mathrm{NTK}}(x) = \big( \langle \phi(x), \phi(x_1) \rangle, \ldots, \langle \phi(x), \phi(x_n) \rangle \big) \cdot \mathrm{NTK}_{\mathrm{train}}^{-1} Y, \]
where $\mathrm{NTK}_{\mathrm{train}}$ is the $n \times n$ kernel matrix for the training data, $\langle \phi(x), \phi(x_i) \rangle$ is the kernel value between test point $x$ and training point $x_i$, and $Y$ is the vector of training labels.

The kernel regression solution $f_{\mathrm{NTK}}$ is uniquely determined because the neural tangent kernel $\mathrm{NTK}_{\mathrm{train}}$ is positive definite, assuming no two training points are parallel, which can be enforced with a bias term [Du et al., 2019c]. In any case, the solution is a min-norm solution via the pseudo-inverse. Moreover, a unique kernel regression solution $f_{\mathrm{NTK}}$ in the span of the training data features corresponds to a unique representation $\beta_1$ in the RKHS space.

Since $\beta_1$ and $\beta_2$ are orthogonal, we also have
\[ \|\beta_{\mathrm{NTK}}\|^2 = \|\beta_1 + \beta_2\|^2 = \|\beta_1\|^2 + \|\beta_2\|^2. \]
This implies the norm of $\beta_{\mathrm{NTK}}$ is at least as large as the norm of any $\beta$ in the training feature span with $\phi(x_i)^\top \beta = y_i$. Moreover, observe that the solution to kernel regression equation 7 is in the feature span of the training data, given that the kernel matrix for the training data is full rank. Since $\beta_2$ is the component of $f_{\mathrm{NTK}}$ in the orthogonal complement of the training feature span, we must have $\beta_2 = 0$. It follows that $\beta_{\mathrm{NTK}}$ is equivalent to
\[ \min_{\beta} \|\beta\|_2 \quad \text{s.t.} \quad \phi(x_i)^\top \beta = y_i, \text{ for } i = 1, \ldots, n, \]
as desired.

B.7 Proof of Lemma 11

We first compute the neural tangent kernel $\mathrm{NTK}(x, x')$ of a two-layer multi-layer perceptron (MLP) with ReLU activation function, and then show that it is induced by the feature space $\phi(x)$ specified in the lemma, so that $\mathrm{NTK}(x, x') = \langle \phi(x), \phi(x') \rangle$.

Recall that Jacot et al. [2018] derived a general framework for computing the neural tangent kernel of a neural network with a general architecture and activation function. This framework is also described in Arora et al. [2019b] and Du et al. [2019b], which in addition compute the exact kernel formulas for convolutional networks and Graph Neural Networks, respectively. Following the framework in Jacot et al. [2018] and substituting the general activation function $\sigma$ with ReLU gives the kernel formula for a two-layer MLP with ReLU activation; this has also been described in several previous works [Du et al., 2019c, Chizat et al., 2019, Bietti and Mairal, 2019].

Below we describe the general framework of Jacot et al. [2018] and Arora et al. [2019b]. Let $\sigma$ denote the activation function. The neural tangent kernel of an $h$-layer multi-layer perceptron can be defined recursively, via a dynamic programming process.
Here, $\Sigma^{(i)}: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$, for $i = 0 \ldots h$, is the covariance of the $i$-th layer:
\begin{align*} \Sigma^{(0)}(x, x') &= x^\top x', \\ \Lambda^{(i)}(x, x') &= \begin{pmatrix} \Sigma^{(i-1)}(x, x) & \Sigma^{(i-1)}(x, x') \\ \Sigma^{(i-1)}(x', x) & \Sigma^{(i-1)}(x', x') \end{pmatrix}, \\ \Sigma^{(i)}(x, x') &= c \cdot \mathbb{E}_{(u,v) \sim \mathcal{N}(0, \Lambda^{(i)})}[\sigma(u)\,\sigma(v)]. \end{align*}
The derivative covariance is defined similarly:
\[ \dot{\Sigma}^{(i)}(x, x') = c \cdot \mathbb{E}_{(u,v) \sim \mathcal{N}(0, \Lambda^{(i)})}[\dot{\sigma}(u)\,\dot{\sigma}(v)]. \]
The neural tangent kernel of an $h$-layer network is then defined as
\[ \mathrm{NTK}^{(h-1)}(x, x') = \sum_{i=1}^{h} \Big( \Sigma^{(i-1)}(x, x') \cdot \prod_{k=i}^{h} \dot{\Sigma}^{(k)}(x, x') \Big), \]
where we let $\dot{\Sigma}^{(h)}(x, x') = 1$ for convenience of notation.

We compute the explicit NTK formula for a two-layer MLP with ReLU activation by following this framework and substituting the general activation function with ReLU, i.e., $\sigma(a) = \max(0, a) = a \cdot \mathbb{I}(a \geq 0)$ and $\dot{\sigma}(a) = \mathbb{I}(a \geq 0)$:
\[ \mathrm{NTK}^{(1)}(x, x') = \sum_{i=1}^{2} \Big( \Sigma^{(i-1)}(x, x') \cdot \prod_{k=i}^{2} \dot{\Sigma}^{(k)}(x, x') \Big) = \Sigma^{(0)}(x, x') \cdot \dot{\Sigma}^{(1)}(x, x') + \Sigma^{(1)}(x, x'). \]
So we can obtain the NTK from $\Sigma^{(0)}(x, x')$, $\dot{\Sigma}^{(1)}(x, x')$, and $\Sigma^{(1)}(x, x')$. Precisely,
\begin{align*} \Sigma^{(0)}(x, x') &= x^\top x', \\ \Lambda^{(1)}(x, x') &= \begin{pmatrix} x^\top x & x^\top x' \\ x'^\top x & x'^\top x' \end{pmatrix} = \begin{pmatrix} x^\top \\ x'^\top \end{pmatrix} \cdot \begin{pmatrix} x & x' \end{pmatrix}, \\ \Sigma^{(1)}(x, x') &= c \cdot \mathbb{E}_{(u,v) \sim \mathcal{N}(0, \Lambda^{(1)})}\big[ u \cdot \mathbb{I}(u \geq 0) \cdot v \cdot \mathbb{I}(v \geq 0) \big]. \end{align*}
To sample from $\mathcal{N}(0, \Lambda^{(1)})$, let $L$ be a decomposition of $\Lambda^{(1)}$ such that $\Lambda^{(1)} = L L^\top$. Here, we can take $L$ to be the matrix with rows $x^\top$ and $x'^\top$. Thus, sampling from $\mathcal{N}(0, \Lambda^{(1)})$ is equivalent to first sampling $w \sim \mathcal{N}(0, I)$ and outputting $L w = (w^\top x, w^\top x')$; that is, we have the equivalent sampling $(u, v) = (w^\top x, w^\top x')$. It follows that
\[ \Sigma^{(1)}(x, x') = c \cdot \mathbb{E}_{w \sim \mathcal{N}(0, I)}\big[ w^\top x \cdot \mathbb{I}(w^\top x \geq 0) \cdot w^\top x' \cdot \mathbb{I}(w^\top x' \geq 0) \big]. \]
It follows from the same reasoning that
\[ \dot{\Sigma}^{(1)}(x, x') = c \cdot \mathbb{E}_{w \sim \mathcal{N}(0, I)}\big[ \mathbb{I}(w^\top x \geq 0) \cdot \mathbb{I}(w^\top x' \geq 0) \big]. \]
The neural tangent kernel of a two-layer MLP with ReLU activation is then
\begin{align*} \mathrm{NTK}^{(1)}(x, x') &= \Sigma^{(0)}(x, x') \cdot \dot{\Sigma}^{(1)}(x, x') + \Sigma^{(1)}(x, x') \\ &= c \cdot \mathbb{E}_{w \sim \mathcal{N}(0, I)}\big[ x^\top x' \cdot \mathbb{I}(w^\top x \geq 0) \cdot \mathbb{I}(w^\top x' \geq 0) \big] \\ &\quad + c \cdot \mathbb{E}_{w \sim \mathcal{N}(0, I)}\big[ w^\top x \cdot \mathbb{I}(w^\top x \geq 0) \cdot w^\top x' \cdot \mathbb{I}(w^\top x' \geq 0) \big]. \end{align*}
Next, we use this kernel formula to compute a feature map for the two-layer MLP with ReLU activation. Recall that, by definition, a valid feature map must satisfy
\[ \mathrm{NTK}^{(1)}(x, x') = \langle \phi(x), \phi(x') \rangle. \]
The way we have expressed the NTK formula makes it easy to find such a decomposition. The following infinite-dimensional feature map satisfies the requirement, because the inner product of $\phi(x)$ and $\phi(x')$, for any $x, x'$, equals the expected value in the NTK after we integrate with respect to the density function of $w$:
\[ \phi(x) = c' \big( x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; w^{(k)\top} x \cdot \mathbb{I}(w^{(k)\top} x \geq 0),\; \ldots \big), \]
where $w^{(k)} \sim \mathcal{N}(0, I)$ with $k$ going to infinity, $c'$ is a constant, and $\mathbb{I}$ is the indicator function. Note that the density of the features of $\phi(x)$ is determined by the density of $w$, i.e., Gaussian.
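As a sanity check of the derivation above (ours, not from the paper), the following sketch compares a Monte Carlo estimate of $\mathrm{NTK}^{(1)}$ over $w \sim \mathcal{N}(0, I)$ with the closed form obtained from the standard arc-cosine identities $\mathbb{E}[\mathbb{I}(w^\top x \geq 0)\,\mathbb{I}(w^\top x' \geq 0)] = (\pi - \theta)/(2\pi)$ and $\mathbb{E}[w^\top x \cdot w^\top x' \cdot \mathbb{I}(w^\top x \geq 0)\,\mathbb{I}(w^\top x' \geq 0)] = \|x\|\|x'\|(\sin\theta + (\pi - \theta)\cos\theta)/(2\pi)$, where $\theta$ is the angle between $x$ and $x'$. Since the constant $c$ is left unspecified above, we fix $c = 2$ for this check; the function names are ours.

```python
# A numerical check (ours) of the two-layer ReLU NTK derived above, with the
# unspecified constant fixed to c = 2 so the arc-cosine identities apply cleanly.
import numpy as np

rng = np.random.default_rng(3)

def ntk_analytic(x, xp):
    nx, nxp = np.linalg.norm(x), np.linalg.norm(xp)
    theta = np.arccos(np.clip(x @ xp / (nx * nxp), -1.0, 1.0))
    sigma_dot = (np.pi - theta) / np.pi                # c * E[I * I]
    sigma = nx * nxp * (np.sin(theta) + (np.pi - theta) * np.cos(theta)) / np.pi
    return (x @ xp) * sigma_dot + sigma                # Sigma0 * dSigma1 + Sigma1

def ntk_monte_carlo(x, xp, m=1_000_000):
    W = rng.normal(size=(m, len(x)))
    both = (W @ x >= 0) & (W @ xp >= 0)
    term1 = (x @ xp) * both.mean()
    term2 = np.mean((W @ x) * (W @ xp) * both)
    return 2 * (term1 + term2)                         # c = 2

x = np.array([1.0, 0.3, -0.5])
xp = np.array([-0.2, 1.1, 0.4])
print(ntk_analytic(x, xp), ntk_monte_carlo(x, xp))     # agree to a few decimals
```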
C Experimental Details

In this section, we describe the model, data, and training details needed to reproduce our experiments. Our experiments support all of our theoretical claims and insights.
Overview.
We classify our experiments into the following major categories, each of which includes several ablation studies:

1) Learning tasks where the target functions are simple non-linear functions, in various dimensions and under various training/test distributions: quadratic, cosine, square root, and $\ell_1$ norm functions, with MLPs over a wide range of hyper-parameters. This validates our implication that MLPs generally cannot extrapolate in tasks with non-linear target functions, unless the non-linear function is directionally linear out-of-distribution; in the latter case, the extrapolation error is more sensitive to the hyper-parameters.

2) Computation of the R-squared of MLPs' learned functions along (thousands of) randomly sampled directions in the out-of-distribution domain. This validates Theorem 3 and shows that the convergence rate is very high in practice, often happening immediately outside the training range.

3) Learning tasks where the target functions are linear, with MLPs. These validate Theorems 5 and 4, i.e., MLPs can extrapolate if the underlying function is linear, under conditions on the training distribution. This section includes the following ablation studies:
a) The training distribution satisfies the conditions in Theorem 5 and covers all directions; hence, MLPs extrapolate.
b) The training data distribution is restricted in some directions, e.g., restricted to be positive/negative/constant in some feature dimensions. This shows that when the training distribution is directionally restricted, MLPs may fail to extrapolate.
c) Exact extrapolation with infinitely wide neural networks, i.e., exact computation with the neural tangent kernel (NTK) in the data regime of Theorem 4. This is mainly for theoretical understanding.

4) MLPs with cosine, quadratic, and tanh activation functions.

5) Summary statistics: learning the maximum degree of graphs with Graph Neural Networks, with extrapolation on graph structure, number of nodes, and node features. To show the role of architecture in extrapolation, we study the following GNN architecture regimes:
a) GNN with graph-level max-pooling and neighbor-level sum-pooling. By Theorem 9, this GNN architecture extrapolates in max degree given appropriate training data.
b) GNN with graph-level and neighbor-level sum-pooling. By Corollary 8, this default GNN architecture cannot extrapolate in max degree.
To show the importance of the training distribution, i.e., the graph structure in the training set, we study the following training data regimes:
a) Node features are identical, e.g., $1$. In this regime, the learning task considers only the graph structure. We consider training sets sampled from various graph structures, and find that only those satisfying the conditions in Theorem 9 enable GNNs with graph-level max-pooling to extrapolate.
b) Node features are spurious and continuous. This additionally requires extrapolation to OOD node features. GNNs with graph-level max-pooling, given appropriate training sets, also extrapolate to OOD spurious node features.

6) Dynamic programming: learning the length of the shortest path between given source and target nodes, with Graph Neural Networks, with extrapolation on graph structure, number of nodes, and edge weights. We study the following regime:
a) Continuous features. Edge and node features are real values. This regime requires extrapolating to graphs with edge weights outside the training range. Test graphs are all sampled from the "general graphs" family, with a diverse range of structures. Regarding the type of training graph structure, we consider two schemes, both of which show a U-shaped curve of extrapolation error with respect to the sparsity of the training graphs:
i) specific graph structures: path, cycle, tree, expander, ladder, complete graphs, general graphs, 4-regular graphs;
ii) random graphs with a range of probabilities $p$ of an edge between any two nodes, where smaller $p$ samples sparse graphs and larger $p$ samples dense graphs.

7) Dynamic programming: physical reasoning for the $n$-body problem in the orbit setting, with Graph Neural Networks. We show that GNNs on the original features from previous works fail to extrapolate to unseen masses and distances, while extrapolation can be achieved via an improved representation of the input edge features. We consider the following extrapolation regimes:
a) extrapolation on the masses of the objects;
b) extrapolation on the distances between objects.
We consider the following two input representation schemes to compare how representation helps extrapolation:
a) Original features. Following previous works on solving the $n$-body problem with GNNs, the edge features are simply set to a fixed value.
b) Improved features. We show that although our edge features do not bring in new information, they help extrapolation.

C.1 Learning Simple Non-Linear Functions
Dataset details.
We consider four tasks where the underlying functions are simple non-linear functions $g: \mathbb{R}^d \to \mathbb{R}$. Given an input $x \in \mathbb{R}^d$, the label is computed as $y = g(x)$ for all $x$. We consider the following four families of simple functions $g$:

a) Quadratic functions $g(x) = x^\top A x$. In each dataset, we randomly sample $A$. In the simplest case, $A = I$ and $g(x) = \sum_{i=1}^d x_i^2$.
b) Cosine functions $g(x) = \sum_{i=1}^d \cos(2\pi \cdot x_i)$.
c) Square root functions $g(x) = \sum_{i=1}^d \sqrt{x_i}$. Here, the domain $\mathcal{X}$ of $x$ is restricted to the points of $\mathbb{R}^d$ with non-negative values in each dimension.
d) $\ell_1$ norm functions $g(x) = \|x\|_1 = \sum_{i=1}^d |x_i|$.

We sample each dataset of a task by considering the following parameters:

a) The shape and support of the training, validation, and test data distributions:
i) Training, validation, and test data are uniformly sampled from a hyper-cube. Training and validation data are sampled from $[-a, a]^d$ for two values of $a$, i.e., each dimension of $x \in \mathbb{R}^d$ is uniformly sampled from $[-a, a]$; test data are sampled from $[-a, a]^d$ for several larger values of $a$.
ii) Training and validation data are uniformly sampled from a sphere, where every point has $\ell_2$ distance $r$ from the origin, for two values of $r$: we sample a random Gaussian vector $q$ in $\mathbb{R}^d$ and obtain the training or validation point $x = q / \|q\| \cdot r$, which corresponds to uniform sampling from the sphere. Test data are sampled (non-uniformly) from a hyper-ball: we first sample $r$ uniformly from one of several fixed intervals extending beyond the training radius, then sample a random Gaussian vector $q$ in $\mathbb{R}^d$ and obtain the test point $x = q / \|q\| \cdot r$.
b) The numbers of training, validation, and test samples, fixed across datasets.
c) The input dimension $d$, sampled from three values.
d) For quadratic functions, the entries of $A$, sampled uniformly from a symmetric interval around zero.
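For concreteness, the following sketch (ours) generates data for the four target families and the two sampling schemes. Because the specific numeric ranges were fixed per dataset, the values of $a$, the sample count, and the range for $A$ below are illustrative placeholders only, not the paper's settings.

```python
# An illustrative data-generation sketch (ours). a, n, and the range for A
# are placeholder values, not the paper's exact settings.
import numpy as np

rng = np.random.default_rng(4)

TARGETS = {
    "quadratic": lambda X, A: np.einsum("ni,ij,nj->n", X, A, X),  # x^T A x
    "cosine":    lambda X, A: np.cos(2 * np.pi * X).sum(axis=1),  # A unused
    "sqrt":      lambda X, A: np.sqrt(np.abs(X)).sum(axis=1),     # domain: x >= 0
    "l1":        lambda X, A: np.abs(X).sum(axis=1),              # A unused
}

def sample_cube(n, d, a):
    """Uniform on the hyper-cube [-a, a]^d."""
    return rng.uniform(-a, a, size=(n, d))

def sample_sphere(n, d, r):
    """Uniform on the sphere of radius r: x = q / ||q|| * r, q ~ N(0, I)."""
    q = rng.normal(size=(n, d))
    return q / np.linalg.norm(q, axis=1, keepdims=True) * r

d, n = 2, 1000
A = rng.uniform(-1, 1, size=(d, d))               # placeholder range for A
X_train = sample_cube(n, d, a=1.0)                # in-distribution support
X_test = sample_cube(n, d, a=5.0)                 # larger range: extrapolation
y_train = TARGETS["quadratic"](X_train, A)
```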
Model and hyperparameter settings. We consider the multi-layer perceptron (MLP) architecture
\[ \mathrm{MLP}(x) = W^{(L)} \cdot \sigma\big( W^{(L-1)} \sigma\big( \ldots \sigma\big( W^{(1)} x \big) \big) \big). \]
We search over the following hyper-parameters for MLPs:
a) the number of layers $L$, from two values;
b) the width of each $W^{(k)}$, from three values;
c) the initialization scheme:
i) the default initialization in PyTorch;
ii) the initialization scheme from neural tangent kernel theory, i.e., we sample the entries of $W^{(k)}$ from $\mathcal{N}(0, 1)$ and scale the output after each $W^{(k)}$ by a factor of order $1/\sqrt{d_k}$, where $d_k$ is the output dimension of $W^{(k)}$;
d) the activation function $\sigma$, set to ReLU.
We train the MLPs with the mean squared error (MSE) loss and the Adam or SGD optimizer. We consider the following training hyper-parameters:
a) the initial learning rate, swept over four values; the learning rate decays by a constant factor at a fixed epoch interval;
b) the batch size, swept over three values;
c) the weight decay, set to a fixed small value;
d) the number of epochs, fixed across runs.
Test error and model selection. For each dataset, architecture, and training hyper-parameter setting, we perform model selection via the validation set, i.e., we report the test error at the epoch where the model achieves the best validation error. Note that our validation sets always have the same distribution as the training sets.

We train our models with the MSE loss. Because we sample test data from different ranges, the mean absolute percentage error (MAPE), which scales the error by the actual value, better measures extrapolation performance:
\[ \mathrm{MAPE} = \frac{1}{n} \sum_{i=1}^n \Big| \frac{A_i - F_i}{A_i} \Big|, \]
where $A_i$ is the actual value and $F_i$ is the predicted value. Hence, in our experiments, we also report the MAPE.
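For reference, a direct transcription of the MAPE metric (our sketch; the function name is ours):

```python
# A one-liner sketch (ours) of the MAPE metric defined above.
import numpy as np

def mape(actual, predicted):
    """Mean absolute percentage error: mean(|(A_i - F_i) / A_i|)."""
    actual, predicted = np.asarray(actual), np.asarray(predicted)
    return np.mean(np.abs((actual - predicted) / actual))

print(mape([2.0, 4.0, 5.0], [2.2, 3.8, 5.5]))  # 0.0833...
```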
C.2 R-squared for Out-of-distribution Directions

We perform linear regression to fit the predictions of MLPs along randomly sampled directions in out-of-distribution regions, and compute the R-squared (or $R^2$) for these directions. This experiment validates Theorem 3 and shows that the convergence rate (to a linear function) is very high in practice.

Definition. R-squared, also known as the coefficient of determination, assesses how strong the linear relationship between input and output variables is. The closer the R-squared is to $1$, the stronger the linear relationship, with $1$ being perfectly linear.
Datasets and models. We perform the R-squared computation on a large number of combinations of datasets, test/train distributions, and hyper-parameters, e.g., learning rate, batch size, MLP depth and width, and initialization. These are described in Appendix C.1.
Computation.
For each combination of dataset and model hyper-parameters as described in Section C.1, we save the trained MLP model f : R^d → R. For each dataset and model combination, we then randomly sample , directions via Gaussian vectors N(0, I). For each of these directions w, we compute the intersection point x_w of the direction w and the support of the training data distribution (a hyper-sphere or hyper-cube; see Section C.1 for details). We then collect the predictions of the trained MLP f along the (normalized) direction w,

{ ( x_w + k·r·w, f(x_w + k·r·w) ) }_{k = 0, 1, 2, …},    (73)

where r is the range of the training data distribution support (see Section C.1). We perform linear regression on the predictions in equation (73) and obtain the R-squared.
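A sketch of this computation for one direction; the grid size num_steps is a placeholder, since the exact number of evaluation points in equation (73) is elided above.

```python
import numpy as np

def r_squared_along_direction(f, x_w, w, r, num_steps=100):
    # Evaluate the trained model f along the ray x_w + k*r*w, fit a
    # line by least squares, and return the coefficient of determination.
    w = w / np.linalg.norm(w)
    ts = np.arange(num_steps) * r
    xs = x_w[None, :] + ts[:, None] * w[None, :]
    ys = np.array([float(f(x)) for x in xs])
    slope, intercept = np.polyfit(ts, ys, deg=1)
    resid = ys - (slope * ts + intercept)
    ss_res = np.sum(resid ** 2)
    ss_tot = np.sum((ys - ys.mean()) ** 2)
    return 1.0 - ss_res / ss_tot
```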
Results.
We obtain the R-squared for each combination of dataset, model and training setting, and randomly sampled direction. For the tasks of learning the simple non-linear functions, we confirm that more than  of the R-squared results are above . . This empirically confirms Theorem 3 and shows that the convergence is in fact fast in practice: along most directions, the MLP's learned function becomes linear almost immediately outside the training data support.
Dataset details.
We consider tasks where the underlying functions are linear, g : R^d → R. Given an input x ∈ R^d, the label is computed by y = g(x) = A x for all x. For each dataset, we sample the following parameters:

a) We sample , training data, , validation data, and , test data.
b) We sample the input dimension d from { , , }.
c) We sample the entries of A uniformly from [−a, a], where we sample a ∈ { . , . }.
d) The shape and support of the training, validation, and test data distributions.
i) Training, validation, and test data are uniformly sampled from a hyper-cube. Training and validation data are sampled from [−a, a]^d with a ∈ { . , . }, i.e., each dimension of x ∈ R^d is uniformly sampled from [−a, a]. Test data are sampled from [−a, a]^d with a ∈ { . , . }.
ii) Training and validation data are uniformly sampled from a sphere, where every point has L2 distance r from the origin. We sample r from { . , . }. Then, we sample a random Gaussian vector q in R^d and obtain the training or validation data x = q/‖q‖ · r. This corresponds to uniform sampling from the sphere. Test data are sampled (non-uniformly) from a hyper-ball: we first sample r uniformly from [0. , . ] or [0. , . ], then sample a random Gaussian vector q in R^d and obtain the test data x = q/‖q‖ · r. This corresponds to (non-uniform) sampling from a hyper-ball in R^d.
e) We perform an ablation study on how the training distribution support misses directions. The test distributions remain the same as in d).
i) We restrict the first dimension of any training data point x_i to the fixed number . , and randomly sample the remaining dimensions according to d).
ii) We restrict the first k dimensions of any training data point x_i to be positive. For input dimension , we only consider the hyper-cube training distribution, where we sample the first k dimensions from [0, a] and the remaining dimensions from [−a, a]. For input dimensions  and , we consider both hyper-cube and hyper-sphere training distributions by performing rejection sampling (see the sketch after this list). For input dimension , we consider k from { , }. For input dimension , we consider k from { , , }.
iii) We restrict the first k dimensions of any training data point x_i to be negative. For input dimension , we only consider the hyper-cube training distribution, where we sample the first k dimensions from [−a, 0] and the remaining dimensions from [−a, a]. For input dimensions  and , we consider both hyper-cube and hyper-sphere training distributions by performing rejection sampling. For input dimension , we consider k from { , }. For input dimension , we consider k from { , , }.
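A minimal sketch of the rejection-sampling step for the hyper-sphere case in ii); the restriction flips from positive to negative coordinates for iii).

```python
import numpy as np

def sample_sphere_first_k_positive(n, d, r, k):
    # Rejection sampling: draw uniform points on the radius-r sphere
    # and keep only those whose first k coordinates are positive.
    kept = []
    while len(kept) < n:
        q = np.random.randn(d)
        x = q / np.linalg.norm(q) * r
        if np.all(x[:k] > 0):
            kept.append(x)
    return np.stack(kept)
```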
Model and hyperparameter settings.
For the regression task, we search the same set of hyper-parameters as for the simple non-linear functions (Section C.1). We report the test error with the same validation procedure as in Section C.1.
Exact computation with neural tangent kernel.
Our experiments with MLPs validate Theorem 5, asymptotic extrapolation for neural networks trained in regular regimes. Here, we also validate Lemma 4, exact extrapolation in the finite-data regime, by training an infinitely-wide neural network. That is, we directly perform kernel regression with the neural tangent kernel (NTK). This experiment is mainly of theoretical interest.

We sample the same test sets as in our experiments with MLPs. For the training set, we sample 2d training examples according to the conditions in Lemma 4. Specifically, we first sample an orthogonal basis and its opposite vectors, X = {e_i, −e_i}_{i=1}^d. We then randomly sample orthogonal transformation matrices Q via the QR decomposition (see the sketch below). Our training samples are QX, i.e., we multiply each point in X by Q. This gives training sets with 2d data points satisfying the condition in Lemma 4.

We perform kernel regression on these training sets using a two-layer neural tangent kernel (NTK). Our code for the exact computation of the NTK is adapted from Arora et al. [2020] and Novak et al. [2020]. We verify that the test losses are all precisely 0, up to machine precision. This empirically confirms Lemma 4.
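A sketch of this training-set construction; the scale r of the basis vectors is a placeholder, since Lemma 4's condition concerns the directions.

```python
import numpy as np

def lemma4_training_set(d, r=1.0):
    # Build {e_i, -e_i} scaled to radius r, then rotate all points by a
    # random orthogonal matrix Q obtained from a QR decomposition.
    X = np.concatenate([np.eye(d), -np.eye(d)], axis=0) * r  # 2d points
    Q, _ = np.linalg.qr(np.random.randn(d, d))
    return X @ Q.T  # row i is Q applied to the i-th point of X
```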
This section describes the experimental settings for the extrapolation experiments with MLPs that use cosine, quadratic, and tanh activation functions. We train MLPs to learn the following functions:
a) Quadratic function g(x) = x^⊤ A x, where A is a randomly sampled matrix.
b) Cosine function g(x) = Σ_{i=1}^d cos(2π · x_i).
c) Hyperbolic tangent function g(x) = Σ_{i=1}^d tanh(x_i).
d) Linear function g(x) = W x + b.

Dataset details.
We use 20,000 training, 1,000 validation, and 20,000 test data points. For quadratic, we sample the input dimension d from { , }, training and validation data from [− , ]^d, and test data from [− , ]^d. For cosine, we sample the input dimension d from { , }, training and validation data from [− , ]^d, and test data from [− , ]^d. For tanh, we sample the input dimension d from { , }, training and validation data from [− , ]^d, and test data from [− , ]^d. For linear, we use a subset of the datasets from Appendix C.3: 1 and 8 input dimensions with hyper-cube training distributions.

Model and hyperparameter settings.
We use the same hyperparameters as in Appendix C.1, except that we fix the batch size to 128, as the batch size has minimal impact on the models. MLPs with cosine activation are hard to optimize, so we only report models with training MAPE less than 1.
C.5 Max Degree
Dataset details.
We consider the task of finding the maximum degree of a graph. Given any input graph G = (V, E), the label is computed by the underlying function y = g(G) = max_{u∈V} Σ_{v∈N(u)} 1, i.e., the maximum node degree. For each dataset, we sample the graphs and node features with the following parameters (a sampling sketch follows the list):

a) Graph structure for training and validation sets. For each dataset, we consider one of the following graph structures: path graphs, cycles, ladder graphs, 4-regular random graphs, complete graphs, random trees, expanders (random graphs with p = 0. ), and general graphs (random graphs with p = 0. to 0. with equal probability). We use the networkx library for sampling graphs.
b) Graph structure for the test set. We consider the general graphs (random graphs with p = 0. to 0. with equal probability).
c) The number of vertices |V| of graphs in the training and validation sets is sampled uniformly from [20, ]. The number of vertices |V| of graphs in the test set is sampled uniformly from [50, ].
d) We consider two schemes for node features.
i) Identical features. All nodes in the training, validation, and test sets have the same uniform feature.
ii) Spurious (continuous) features. Node features in the training and validation sets are sampled uniformly from [− . , . ]^3, i.e., a three-dimensional vector where each dimension is sampled from [− . , . ]. There are two schemes for the test sets: in the first case we do not extrapolate node features, so we sample node features uniformly from [− . , . ]; in the second case we extrapolate node features, and sample node features uniformly from [− . , . ].
e) We sample , graphs for training, , graphs for validation, and , graphs for testing.
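A minimal networkx sketch of sampling one "general graph" example with its maximum-degree label; the size and edge-probability ranges are placeholders for the elided values above.

```python
import numpy as np
import networkx as nx

def sample_max_degree_example(n_min=20, n_max=30, p_lo=0.1, p_hi=0.9):
    # One "general graph" example: a G(n, p) random graph, labeled by
    # its maximum node degree.
    n = np.random.randint(n_min, n_max + 1)
    p = np.random.uniform(p_lo, p_hi)
    G = nx.gnp_random_graph(n, p)
    y = max(deg for _, deg in G.degree())
    return G, y
```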
Model and hyperparameter settings.
We consider the following Graph Neural Network (GNN) architecture. Given an input graph G, the GNN computes the output h_G by iteratively aggregating and transforming the neighbors of all node vectors h_u^{(k)} (the vector for node u in layer k), and then performing a max- or sum-pooling over all node features h_u^{(K)} to obtain h_G. Formally, we have

h_u^{(k)} = Σ_{v∈N(u)} MLP^{(k)}( h_v^{(k−1)}, h_u^{(k−1)} ),    h_G = MLP^{(K+1)}( graph-pooling { h_u^{(K)} : u ∈ G } ).    (74)

Here, N(u) denotes the neighbors of u, K is the number of GNN iterations, and graph-pooling is a hyper-parameter chosen as max or sum. h_u^{(0)} is the input node feature of node u. (A code sketch of this architecture follows at the end of this subsection.) We search the following hyper-parameters for GNNs:
a) Number of GNN iterations K is .
b) Graph pooling is max or sum.
c) Width of all MLPs is set to .
d) The number of layers of MLP^{(k)} for k = 1..K is set to . The number of layers of MLP^{(K+1)} is set to .

We train the GNNs with the mean squared error (MSE) loss and the Adam and SGD optimizers. We search the following hyper-parameters for training:
a) Initial learning rate is set to . .
b) Batch size is set to .
c) Weight decay is set to e− .
d) Number of epochs is set to  for graphs with continuous node features, and  for graphs with uniform node features.
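A PyTorch sketch of equation (74). The widths and depths are placeholders, and the readout MLP^{(K+1)} is simplified to a single linear layer here; edge_index is assumed to contain both directions of each undirected edge.

```python
import torch
import torch.nn as nn

class MaxDegreeGNN(nn.Module):
    # Sum aggregation over neighbors (equation (74)), followed by a
    # graph-level max- or sum-pooling and a readout.
    def __init__(self, in_dim, hidden, n_iters, pooling="max"):
        super().__init__()
        self.mlps = nn.ModuleList(
            nn.Sequential(
                nn.Linear(2 * hidden if k > 0 else 2 * in_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
            )
            for k in range(n_iters)
        )
        self.readout = nn.Linear(hidden, 1)
        self.pooling = pooling

    def forward(self, h, edge_index):
        # h: [num_nodes, in_dim]; edge_index: [2, num_edges] long tensor
        # of (source v, destination u) pairs.
        src, dst = edge_index
        for mlp in self.mlps:
            msg = mlp(torch.cat([h[src], h[dst]], dim=-1))
            out = msg.new_zeros((h.size(0), msg.size(-1)))
            out.index_add_(0, dst, msg)  # sum over neighbors v of u
            h = out
        hg = h.max(dim=0).values if self.pooling == "max" else h.sum(dim=0)
        return self.readout(hg)
```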
Test error and model selection.
For each combination of dataset, architecture, and training hyper-parameters, we perform model selection via the validation set, i.e., we report the test error at the epoch where the model achieves the best validation error. Note that our validation sets always have the same distribution as the training sets. Again, we report the MAPE for the test error, as for the MLPs.
C.6 Shortest Path
Dataset details.
We consider the task of finding the length of the shortest path in a graph from a given source node to a target node. Given any graph G = (V, E), the node features, besides the regular node features, encode whether a node is the source s and whether a node is the target t. The edge features are a scalar representing the edge weight. For unweighted graphs, all edge weights are 1. The label y = g(G) is then the length of the shortest path from s to t in G. For each dataset, we sample the graphs and the node and edge features with the following parameters:

a) Graph structure for training and validation sets. For each dataset, we consider one of the following graph structures: path graphs, cycles, ladder graphs, 4-regular random graphs, complete graphs, random trees, expanders (random graphs with p = 0. ), and general graphs (random graphs with p = 0. to 0. with equal probability). We use the networkx library for sampling graphs.
b) Graph structure for the test set. We consider the general graphs (random graphs with p = 0. to 0. with equal probability).
c) The number of vertices |V| of graphs in the training and validation sets is sampled uniformly from [20, ]. The number of vertices |V| of graphs in the test set is sampled uniformly from [50, ].
d) We consider the following scheme for node and edge features. All edges have continuous weights. Edge weights for training and validation graphs are sampled from [1. , . ]. There are two schemes for the test sets: in the first case we do not extrapolate edge weights, so we sample edge weights uniformly from [1. , . ]; in the second case we extrapolate edge weights, and sample edge weights uniformly from [1. , . ]. All node features take the form [spurious feature sampled from [− . , . ], I(v = s), I(v = t)].
e) After sampling a graph and its edge weights, we sample the source s and target t by randomly proposing pairs and selecting the first pair (s, t) whose shortest path involves at most  hops (a sampling sketch is given below). This enables us to solve the task using GNNs with  iterations.
f) We sample , graphs for training, , graphs for validation, and , graphs for testing.

We also consider an ablation study of training on random graphs with different p. We consider p = 0. .. 0. and report the test error curve. The other parameters are the same as described above.
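A sketch of the source/target sampling in e); max_hops stands in for the elided hop bound above.

```python
import random
import networkx as nx

def sample_source_target(G, max_hops):
    # Randomly propose (s, t) pairs and keep the first pair whose
    # shortest path uses at most max_hops edges, so a GNN with that
    # many iterations can solve the task. Assumes such a pair exists.
    nodes = list(G.nodes)
    while True:
        s, t = random.sample(nodes, 2)
        try:
            if nx.shortest_path_length(G, s, t) <= max_hops:
                return s, t
        except nx.NetworkXNoPath:
            continue  # disconnected pair; try again
```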
Model and hyperparameter settings.
We consider the following Graph Neural Network (GNN) architecture. Given an input graph G, the GNN computes the output h_G by iteratively aggregating and transforming the neighbors of all node vectors h_u^{(k)} (the vector for node u in layer k), and then performing a graph-level min-pooling over all node features to obtain h_G. Formally, we have

h_u^{(k)} = min_{v∈N(u)} MLP^{(k)}( h_v^{(k−1)}, h_u^{(k−1)}, w_{(u,v)} ),    h_G = MLP^{(K+1)}( min_{u∈G} h_u^{(K)} ).    (75)

Here, N(u) denotes the neighbors of u, K is the number of GNN iterations, and for neighbor aggregation we run both min and sum. h_u^{(0)} is the input node feature of node u, and w_{(u,v)} is the input edge feature of edge (u, v). (A scatter-min sketch of this aggregation follows at the end of this subsection.) We search the following hyper-parameters for GNNs:
a) Number of GNN iterations K is set to .
b) Graph pooling is set to min.
c) Neighbor aggregation is selected from min and sum.
d) Width of all MLPs is set to .
e) The number of layers of MLP^{(k)} for k = 1..K is set to . The number of layers of MLP^{(K+1)} is set to .

We train the GNNs with the mean squared error (MSE) loss and the Adam and SGD optimizers. We consider the following hyper-parameters for training:
a) Initial learning rate is set to . .
b) Batch size is set to .
c) Weight decay is set to e− .
d) Number of epochs is set to .

We perform the same model selection and validation as in Section C.5.
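A sketch of the min-aggregation step in equation (75), replacing the index_add_ sum in the earlier GNN sketch; it uses scatter-min, available as scatter_reduce in PyTorch ≥ 1.12.

```python
import torch

def min_aggregate(msg, dst, num_nodes):
    # Combine the per-edge messages msg [num_edges, dim] at their
    # destination nodes dst [num_edges] with a minimum, as in (75).
    # Nodes with no incoming message keep the +inf sentinel.
    out = torch.full((num_nodes, msg.size(-1)), float("inf"))
    idx = dst.unsqueeze(-1).expand_as(msg)
    return out.scatter_reduce(0, idx, msg, reduce="amin",
                              include_self=True)
```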
C.7 N-Body Problem

Task description.
The n-body problem asks a neural network to predict how n stars in a physical system evolve according to the laws of physics. That is, we train neural networks to predict properties of the future state of each star at the next frame, e.g., . seconds later.

Mathematically, in an n-body system S = {X_i}_{i=1}^n, such as a solar system, all n stars exert distance- and mass-dependent gravitational forces on each other, so there are n(n−1) relations, or forces, in the system. Suppose X_i at time t is at position x_i^t and has velocity v_i^t. The overall force a star X_i receives from the other stars is determined by the laws of physics as

F_i^t = G · Σ_{j≠i} (m_i · m_j) / ‖x_i^t − x_j^t‖³ · (x_j^t − x_i^t),    (76)

where G is the gravitational constant and m_i is the mass of star X_i. The acceleration a_i^t is then determined by the net force F_i^t and the mass m_i:

a_i^t = F_i^t / m_i.    (77)

Suppose the velocity of star X_i at time t is v_i^t. Then, assuming the time step dt between frames is sufficiently small, the velocity at the next time frame t + 1 can be approximated by

v_i^{t+1} = v_i^t + a_i^t · dt.    (78)

Given m_i, x_i^t, and v_i^t, our task asks the neural network to predict v_i^{t+1} for all stars X_i. In our task, we consider two extrapolation schemes:
a) The distances between stars ‖x_i^t − x_j^t‖ are out-of-distribution for the test set, i.e., sampled from different ranges than the training set.
b) The masses of the stars m_i are out-of-distribution for the test set, i.e., sampled from different ranges than the training set.
Here, we use a physics engine that we code in Python to simulate and sample the inputs and labels (a sketch of one simulation step is given below). We describe the dataset details next.
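A minimal NumPy sketch of one step of this engine, i.e., equations (76)–(78); the Euler position update is an assumption, since the text above only specifies the velocity update.

```python
import numpy as np

G_CONST = 6.674e-11  # gravitational constant (SI units)

def step(x, v, m, dt):
    # x, v: [n, dim] positions and velocities; m: [n] masses.
    # Pairwise gravity (76), acceleration (77), velocity update (78).
    n = len(m)
    a = np.zeros_like(x)
    for i in range(n):
        for j in range(n):
            if i != j:
                diff = x[j] - x[i]
                dist = np.linalg.norm(diff)
                # a_i = F_i / m_i = G * m_j * (x_j - x_i) / ||x_j - x_i||^3
                a[i] += G_CONST * m[j] * diff / dist ** 3
    v_next = v + a * dt
    x_next = x + v * dt  # Euler position update (assumed here)
    return x_next, v_next
```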
Dataset details.
We first describe the simulation and sampling of our training set. We sample  videos of n-body system evolution, each with a rollout of  time steps. We consider the orbit setting: there is one huge center star and several other stars. We sample the initial states, i.e., positions, velocities, masses, accelerations, etc., according to the following parameters:
a) The mass of the center star is  kg.
b) The masses of the other stars are sampled from [0. , . ] kg.
c) The number of stars is .
d) The initial position of the center star is (0. , 0. ).
e) The initial positions x_i^t of the other stars are sampled from all angles, with a distance in [10. , . ] m.
f) The velocity of the center star is 0.
g) The velocities of the other stars are perpendicular to the gravitational force between the center star and themselves. The scale is precisely determined by the laws of physics to ensure the initial state is an orbit system (a sketch is given below).

For each video, after we obtain the initial states, we roll out the next frames according to the physics engine described above. We perform rejection sampling of the frames to ensure that all pairwise distances between stars in a frame are at least  m. We guarantee that there are , data points in the training set. The validation set has the same sampling and simulation parameters as the training set. We have , data points in the validation set.

For the test set, we consider two datasets, which respectively have OOD distances and OOD masses. We have , data points in each dataset.
a) We sample the distance-OOD test set to ensure that all pairwise distances between stars in a frame are in [1 .. ] m, but the masses are in-distribution.
b) We sample the mass-OOD test set as follows:
i) The mass of the center star is  kg, i.e., twice that in the training set.
ii) The masses of the other stars are sampled from [0. , . ] kg, compared to [0. , . ] kg in the training set.
iii) The distances are in-distribution, i.e., the same sampling process as the training set.
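A sketch of the initial-velocity computation in g), assuming a 2D system and a circular orbit around the much heavier center star, so the speed is √(G·M/r) and the direction is perpendicular to the line to the center.

```python
import numpy as np

G_CONST = 6.674e-11  # gravitational constant (SI units)

def orbit_velocity(center_mass, pos, center_pos=np.zeros(2)):
    # Circular-orbit initial velocity for a star at pos: perpendicular
    # to the direction to the center star, with speed sqrt(G*M/r).
    # Neglects the other stars' (much smaller) masses.
    rel = pos - center_pos
    r = np.linalg.norm(rel)
    speed = np.sqrt(G_CONST * center_mass / r)
    perp = np.array([-rel[1], rel[0]]) / r  # rotate 90 degrees
    return perp * speed
```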
Model and hyperparameter settings.
We consider the following one-iteration Graph Neural Network (GNN) architecture, a.k.a. Interaction Networks. Given a collection of stars S = {X_i}_{i=1}^n, our GNN runs on a complete graph whose nodes are the stars X_i. The GNN learns the star (node) representations by aggregating and transforming the interactions (forces) from all other node vectors:

o_u = MLP^{(2)}( Σ_{v∈S∖{u}} MLP^{(1)}( h_v, h_u, w_{(u,v)} ) ).    (79)

Here, h_v is the input feature of node v, including mass, position, and velocity: h_v = (m_v, x_v, v_v); w_{(u,v)} is the input edge feature of edge (u, v). The loss is computed and backpropagated via the MSE loss ‖[o_1, ..., o_n] − [ans_1, ..., ans_n]‖², where o_i denotes the output of the GNN for node i, and ans_i denotes the true label of node i at the next frame. (A code sketch of this architecture follows at the end of this subsection.) We search the following hyper-parameters for GNNs:
a) Number of GNN iterations is set to 1.
b) Width of all MLPs is set to .
c) The number of layers of MLP^{(1)} is set to . The number of layers of MLP^{(2)} is set to .
d) We consider two representations of the edge/relation features w_{(i,j)}.
i) The first one is simply .
ii) The better representation, which makes the underlying target function more linear, is

w_{(i,j)} = m_j / ‖x_i^t − x_j^t‖³ · (x_j^t − x_i^t).

We train the GNN with the mean squared error (MSE) loss and the Adam optimizer. We search the following hyper-parameters for training:
a) Initial learning rate is set to . . The learning rate decays by . every  epochs.
b) Batch size is set to .
c) Weight decay is set to e− .
d) Number of epochs is set to , .
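A PyTorch sketch of equation (79) on the complete graph of stars; the feature and hidden sizes are placeholders.

```python
import torch
import torch.nn as nn

class InteractionNetwork(nn.Module):
    # One-iteration GNN on the complete graph of stars (equation (79)).
    # h_v = (m_v, x_v, v_v) is the node feature; w holds edge features.
    def __init__(self, node_dim, edge_dim, hidden, out_dim):
        super().__init__()
        self.mlp1 = nn.Sequential(nn.Linear(2 * node_dim + edge_dim, hidden),
                                  nn.ReLU(), nn.Linear(hidden, hidden))
        self.mlp2 = nn.Sequential(nn.Linear(hidden, hidden),
                                  nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, h, w):
        # h: [n, node_dim] star features; w: [n, n, edge_dim] edge features.
        n = h.size(0)
        hu = h.unsqueeze(1).expand(n, n, -1)  # hu[u, v] = h_u (receiver)
        hv = h.unsqueeze(0).expand(n, n, -1)  # hv[u, v] = h_v (sender)
        pair = self.mlp1(torch.cat([hv, hu, w], dim=-1))
        mask = (1.0 - torch.eye(n)).unsqueeze(-1)  # zero out v == u
        return self.mlp2((pair * mask).sum(dim=1))  # sum over senders v
```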
D Visualization and Additional Experimental Results

D.1 Visualization Results
In this section, we show additional visualization results of the MLP's learned function outside the training distribution (in black) vs. the underlying true function (in grey). We color the predictions within the training distribution in blue.

In general, the MLP's learned functions agree with the underlying true functions in the training range (blue); this is explained by in-distribution generalization arguments. Outside the training distribution, the MLP's learned functions become linear along directions from the origin; we explain this OOD directional-linearity behavior in Theorem 3. Finally, we show additional experimental results for the graph-based reasoning tasks.

Figure 8: (Quadratic function).
Both panels show the learned vs. true y = x_1² + x_2². In each figure, we color OOD predictions by MLPs in black, the underlying function in grey, and in-distribution predictions in blue. The support of the training distribution is a square (cube) for the top panel and a circle (sphere) for the bottom panel.

Figure 9: (Cos function). Both panels show the learned vs. true y = cos(2π · x_1) + cos(2π · x_2). In each figure, we color OOD predictions by MLPs in black, the underlying function in grey, and in-distribution predictions in blue. The support of the training distribution is a square (cube) for both the top and bottom panels, but with different ranges.

Figure 10: (Cos function). The top panel shows the learned vs. true y = cos(2π · x_1) + cos(2π · x_2), where the support of the training distribution is a circle (sphere). The bottom panel shows results for the cosine function in 1D, i.e., y = cos(2π · x). In each figure, we color OOD predictions by MLPs in black, the underlying function in grey, and in-distribution predictions in blue.

Figure 11: (Sqrt function). The top panel shows the learned vs. true y = √x_1 + √x_2, where the support of the training distribution is a square (cube). The bottom panel shows results for the square root function in 1D, i.e., y = √x. In each figure, we color OOD predictions by MLPs in black, the underlying function in grey, and in-distribution predictions in blue.

Figure 12: (L1 function). Both panels show the learned vs. true y = |x|. In the top panel, the MLP successfully learns to extrapolate the absolute-value function. In the bottom panel, an MLP with different hyper-parameters fails to extrapolate. In each figure, we color OOD predictions by MLPs in black, the underlying function in grey, and in-distribution predictions in blue.

Figure 13: (L1 function). Both panels show the learned vs. true y = |x_1| + |x_2|. In the top panel, the MLP successfully learns to extrapolate the L1 norm function. In the bottom panel, an MLP with different hyper-parameters fails to extrapolate. In each figure, we color OOD predictions by MLPs in black, the underlying function in grey, and in-distribution predictions in blue.

Figure 14: (Linear function). Both panels show the learned vs. true y = x_1 + x_2, with the support of the training distribution being a square (cube) for the top panel and a circle (sphere) for the bottom panel. MLPs successfully extrapolate the linear function with both training distributions. This is explained by Theorem 5: both the sphere and the cube intersect all directions. In each figure, we color OOD predictions by MLPs in black, the underlying function in grey, and in-distribution predictions in blue.

D.2 Experimental Results
Figure 15:
Density plot of the test errors in MAPE.
The underlying functions are linear, but we train MLPs on different distributions whose support potentially misses some directions. The training supports for "all" are hyper-cubes that intersect all directions. In "fix1", we set the first dimension of the training data to a fixed number. In "posX", we restrict the first X dimensions of the training data to be positive. We can see that MLPs trained on "all" extrapolate the underlying linear functions, but MLPs trained on datasets with missing directions, i.e., "fix1" and "posX", often cannot extrapolate well.
Figure 16:
Maximum degree: spurious (real-valued) node features.
Here, each node has a spurious real-valued node feature that should not contribute to the answer of the maximum degree. GNNs with graph-level max-pooling extrapolate to graphs with OOD node features, graph structures, and graph sizes, if trained on graphs that satisfy the condition in Theorem 9.

Figure 17:
Maximum degree: max-pooling vs. sum-pooling.
In each sub-figure, the left column shows test errors for GNNs with graph-level max-pooling; the right column shows test errors for GNNs with graph-level sum-pooling. The x-axis shows the graph structure covered in the training set. GNNs with sum-pooling fail to extrapolate, validating Corollary 8. GNNs with max-pooling encode appropriate non-linear operations, and thus extrapolate under appropriate training sets (Theorem 9).
Figure 18:
Shortest path: random graphs.
We train GNNs with neighbor-level and graph-level min-pooling on training sets whose graphs are random graphs with probability p of an edge between any two vertices. The x-axis denotes the p for the training set, and the y-axis denotes the test/extrapolation error on unseen graphs. The test errors follow a U-shape: errors are high if the training graphs are very sparse (small p) or very dense (large p).