Model-Agnostic Meta-Learning using Runge-Kutta Methods
Daniel Jiwoong Im∗, Yibo Jiang, and Nakul Verma
Janelia Research Campus, HHMI, Virginia; Harvard University, Massachusetts; Columbia University, New York
Abstract
Meta-learning has emerged as an important framework for learning new tasks from just a few examples. The success of any meta-learning model depends on (i) its fast adaptation to new tasks, as well as (ii) having a shared representation across similar tasks. Here we extend the model-agnostic meta-learning (MAML) framework introduced by Finn et al. (2017) to achieve improved performance by analyzing the temporal dynamics of the optimization procedure via the Runge-Kutta method. This method enables us to gain fine-grained control over the optimization and helps us achieve both the adaptation and representation goals across tasks. By leveraging this refined control, we demonstrate that there are multiple principled ways to update MAML and show that the classic MAML optimization is simply a special case of a second-order Runge-Kutta method that mainly focuses on fast adaptation. Experiments on benchmark classification, regression, and reinforcement learning tasks show that this refined control helps attain improved results.
Introduction

Building an intelligent system that can learn quickly on a new task from few examples or few experiences is one of the central goals of machine learning. Achieving this goal requires an agent that learns continuously while having the ability to adapt to new tasks with limited data. Meta-learning (Biggs, 1985) has emerged as a compelling framework that strives to attain this challenging goal.

There are two main approaches to meta-learning: learning-to-optimize and learning-to-initialize the meta-model (usually encoded as a deep network). Learning-to-optimize refers to having a model that encodes the learning algorithm and predicts the direction of the parameter updates (Hochreiter et al., 2001). Learning-to-initialize refers to learning a representation that can quickly adapt to solve multiple tasks (Vinyals et al., 2016; Ravi and Larochelle, 2017; Finn et al., 2017). Here we focus on understanding and improving the latter approach, which has found a wide range of applications (Li et al., 2017; Nichol et al., 2018; Antoniou et al., 2019).

As one expects, over-fitting makes learning any new task from a few examples very difficult. Meta-learning overcomes the problem of data scarcity by jointly learning from a collection of related tasks, where each task corresponds to a learning problem on a different (but related) dataset.

Model-Agnostic Meta-Learning (MAML) has emerged as a popular state-of-the-art model in this framework (Finn et al., 2017). It is a gradient-based optimization model that learns the meta-parameters (that help to generalize to new tasks) in two update phases: fast adaptation (inner-updates) and meta-updates (outer-updates).
Roughly, the inner-updates optimize the parameters to maximize the performance on new tasks using few examples, and the outer-updates optimize the meta-parameters to find an effective initialization within the few parameter updates.

Finding a model that can yield good prediction accuracy within a few updates requires (i) fast-adaptation, that is, finding model parameters that can quickly change the internal representation (so as to maximize the sensitivity of the loss function of the new task), and/or (ii) shared-representation, that is, developing high quality joint latent feature representations (so as to maximize the mutual information between different tasks). Motivated by this, we propose new learning dynamics for MAML optimization that give better flexibility and improve the model on both these fronts. Specifically, we apply the class of Runge-Kutta methods to MAML optimization, which can take advantage of computing the gradients multiple steps ahead when updating the meta-model. This allows us to generalize MAML to using higher-order gradients. Furthermore, we show that the current update rule of MAML corresponds to a specific type of second-order explicit Runge-Kutta method called the midpoint method (Hairer et al., 1987).

The main contributions of this work are as follows.

(i) We propose a novel Runge-Kutta method for MAML optimization. This new viewpoint is advantageous as it helps get a more refined control over the optimization. (Section 3.1)

(ii) By leveraging this refined control, we demonstrate that there are multiple principled ways to update MAML and show that the original MAML update rule corresponds to one of the class of second-order Runge-Kutta methods. (Section 3.2)

(iii) The Runge-Kutta framework helps understand the MAML learning dynamics from the lens of explicit ODE integrators. To the best of our knowledge, this is the first work that successfully applies ODE solvers to meta-learning.

∗ [email protected]
(Section 3)

(iv) We show that the refinement obtained by the Runge-Kutta method is empirically effective as well. We obtain significant improvements in performance on benchmark classification, regression, and reinforcement learning tasks. (Section 5)

Model-Agnostic Meta-Learning

Model-Agnostic Meta-Learning (MAML) was introduced by Finn et al. (2017) with the aim to train a model that can adapt to a large set of new tasks with only a few data points in a few learning iterations. Given a meta-model f parameterized by meta-parameters θ, one wants to find a gradient learning rule that can make rapid progress on a new task. This can be formalized as follows: for a task T_i drawn from a distribution of tasks p(T), we require that small changes in parameters using gradient-based parameter updates, that is (for any given learning rate h),

    θ′ = θ − h ∇_θ L_{T_i}(f(θ)),    (1)

would result in a large improvement on the loss function L of a task T_i. This requirement implies that the meta-parameter θ is an initialization that produces a more transferable internal representation and achieves good prediction with only a few examples from a new task.

MAML finds such a meta-parameter by simultaneously minimizing the loss functions associated with each task. The MAML meta-objective is defined as (Finn et al., 2017)

    min_θ L(f(θ′)) = min_θ Σ_{T_i ∼ p(T)} L_{T_i}(f(θ − h ∇_θ L_{T_i}(f(θ)))),    (2)

where θ′ is as per Eq. (1), and the total loss L(f(θ′)) is simply the aggregate of the individual task-specific losses Σ_{T_i ∼ p(T)} L_{T_i}(f(θ′)). Thus, the meta-parameter gets updated by taking a gradient descent step in the direction that minimizes the loss for all the given tasks.

Algorithm 1 provides an overview of the MAML training procedure. A batch of tasks is sampled from the task distribution p(T). The model parameters are then updated for each task (inner-updates).
These updated parameters are then used to update the meta-parameter (outer-update). For simplicity and ease of subsequent discussion, we will hide the model f inside the loss L and refer to it as L(θ′) henceforth. That is, L(θ′) := L_{T_i}(θ − h ∇_θ L_{T_i}(θ)).

Algorithm 1 MAML
Require: α and β are learning rates; p(T) is a distribution over tasks
  Randomly initialize the meta-parameter θ
  while not done do:
    Sample a batch of tasks T_i ∼ p(T)
    for all T_i do:
      Evaluate ∇_θ L_{T_i}(f(θ))
      Update model parameters: θ′ = θ − α ∇_θ L_{T_i}(f(θ))
    Update meta-parameter: θ = θ − β ∇_θ Σ_{T_i ∼ p(T)} L_{T_i}(f(θ′))

Algorithm 2 MAML-RK2
Require: h is the step size; a_1, a_2, q_21 are RK2 coefficients; p(T) is a distribution over tasks
  Randomly initialize the meta-parameter θ
  while not done do:
    Sample a batch of tasks T_i ∼ p(T)
    for all T_i do:
      Evaluate ∇_θ L_{T_i}(f(θ))
      Update model parameters: θ′ = θ − q_21 h ∇_θ L_{T_i}(f(θ))
    Update meta-parameter: θ = θ − h Σ_{T_i ∼ p(T)} ( a_1 ∇_θ L_{T_i}(f(θ)) + a_2 ∇_θ L_{T_i}(f(θ′)) )

MAML through the Lens of ODE Integrators

One of the central objectives of this paper is understanding the learning dynamics of MAML through the lens of explicit ODE integrators. Consider the vector field X that maps the space of parameters to the descent directions induced by the gradient of the MAML objective. So long as X is sufficiently smooth, we can look for solutions of the form

    dθ/dt = X(θ).

The temporal dynamics of this ODE constitute the ideal path that the meta-learner takes during training given an initial parameter θ. We can therefore use a numerical ODE solver to approximate this ideal solution. A higher-order explicit ODE integrator can be viewed as a black box that transforms X into a new vector field X̄ such that the parameters θ evolve more closely along the streamlines of X. For gradient descent, X = ∇_θ L and the resulting X̄ advects θ more closely along the path of steepest descent. One of the central goals of this work is precisely to discern the efficacy of applying integrators on X̄ rather than X. To do this we will replace the gradient X(θ_t) = ∇_θ L(θ_t) with the appropriate X̄(θ_t) by calling a chosen explicit ODE integrator. This viewpoint generalizes the MAML optimization framework to optimize with respect to temporal parameters, in contrast to considering the spatial parameters as done in the previous literature (Park and Oliva, 2019; Chen et al., 2018; Song et al., 2019).

For a given timestep t, an explicit integrator can be seen as a morphism over vector fields X → X̄_h (for a fixed stepsize h). Hence, for a true gradient g_t = ∇_θ L(θ_t) (at time t), we solve for the modified Runge-Kutta gradient ḡ_t = ∇_θ L(θ′_t) as follows.
Define

    advect^RK_{g_t}(θ, h) := θ_t + ḡ_t h = θ_{t+1}.

The general form of advect^RK_{g_t}(θ, h) is the Runge-Kutta (RK) equation of order N (Butcher, 2008), given by

    θ_{t+1} = θ_t + h Σ_{i=1}^{N} a_i k_i,    (3)

where

    k_1 := ∇_θ L(t, θ_t),
    k_2 := ∇_θ L(t + p_2 h, θ_t + q_21 k_1 h),
    k_3 := ∇_θ L(t + p_3 h, θ_t + q_31 k_1 h + q_32 k_2 h),
    ...
    k_N := ∇_θ L(t + p_N h, θ_t + q_N1 k_1 h + q_N2 k_2 h + · · · + q_{N,N−1} k_{N−1} h),    (4)

where (i) the a_i are combination weights (that should sum to 1), (ii) the p_j are the so-called nodes that scale the timestep (for better numerical approximation), (iii) the q_ij are the coefficients that scale the step towards the gradient k_j (for fine-grained optimization control), and (iv) Σ_{j=1}^{i−1} q_ij = p_i for all i = 2, ..., N. The specific choice of these parameters (a_i, p_j and q_ij) gives rise to various popular instantiations of the RK optimization method.

Figure 1: Illustration of the Runge-Kutta slopes k_i. The update direction from θ_t to θ_{t+1} is composed of a linear combination of the different k_i's.

Figure 2: Diagram of how the parameters θ evolve over time under MAML-RK. The higher-order MAML-RK takes a path closer to the continuous path X (bold curve), which leads to a better parameter initialization that can quickly adapt to multiple tasks.

Figure 1 demonstrates the slopes k_i for a quadratic function for 1 ≤ i ≤ N = 4. Notice that it takes a linear combination of the k_i's to calculate the next step vector. These k_i's can be thought of as forward multi-steps, since they compute the gradients at future timesteps (c.f. Figure 1). Rearranging the terms with respect to ḡ_t, we get

    ḡ_t^{MAML-RK} := (advect^RK_{g_t}(θ, h) − θ_t) / h = (θ_t + h Σ_{i=1}^{N} a_i k_i − θ_t) / h = Σ_{i=1}^{N} a_i k_i.

We can use this refined Runge-Kutta gradient in the MAML optimization for better convergence, and will call the result the generalized Runge-Kutta MAML (MAML-RK) method. Observe that ḡ_t^{MAML-RK} is a linear combination of the gradients of the forward multi-steps k_i over the sum of individual task-specific loss functions L_{T_i}(t, θ_t). Unlike the standard MAML update, which takes the gradient only one step forward (inner-update), MAML-RK has the ability to take the gradient multiple steps ahead. This encourages the meta-learner to find parameters that adapt to new tasks with very few (sometimes even just one) gradient steps on a new task, just as we desire. Note that although we advect g_t with RK in this work, any other explicit ODE solver can also be used.

Figure 2 illustrates the optimization path of MAML-RK. The bold path represents the continuous learning dynamics of the meta-parameter X(θ). The three parameters θ_1, θ_2, and θ_3 are the optimal parameters for three tasks, and the best meta-parameter θ* is chosen in such a way that it lies close to all of them. The path of standard MAML optimization deviates from X(θ) due to crude local approximations.
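The general RK step in Eqs. (3)-(4) can be written directly from the coefficients (a_i, p_j, q_ij). A minimal sketch (the function name `rk_step` and the use of NumPy are our own; the paper's implementation may differ):

```python
import numpy as np

def rk_step(grad_fn, theta, h, a, p, q):
    """One explicit Runge-Kutta step of Eqs. (3)-(4).

    grad_fn(t, theta) evaluates the vector field X at (t, theta).
    a : (N,) combination weights a_i (should sum to 1)
    p : (N,) nodes p_j scaling the timestep, with p[0] = 0
    q : (N, N) strictly lower-triangular coefficients q_ij
    """
    N = len(a)
    k = np.zeros((N,) + np.shape(theta))
    for i in range(N):
        # advance theta by the earlier slopes before evaluating the next slope
        theta_i = theta + h * sum(q[i, j] * k[j] for j in range(i))
        k[i] = grad_fn(p[i] * h, theta_i)
    g_bar = sum(a[i] * k[i] for i in range(N))  # the RK gradient: sum_i a_i k_i
    return theta + h * g_bar

# Usage example: the classic fourth-order (RK4) tableau on dtheta/dt = -theta.
a4 = np.array([1/6, 1/3, 1/3, 1/6])
p4 = np.array([0.0, 0.5, 0.5, 1.0])
q4 = np.zeros((4, 4))
q4[1, 0], q4[2, 1], q4[3, 2] = 0.5, 0.5, 1.0
theta_next = rk_step(lambda t, th: -th, np.array([1.0]), 0.1, a4, p4, q4)
```

With N = 2 the same routine covers all the second-order methods discussed next; the step error of the RK4 example above is far smaller than that of a single Euler step of the same size.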
Using higher-order explicit RK methods, the optimization follows the ideal path more closely due to the better quality of approximation.

(The original image in Figure 1 is from the Runge-Kutta Wikipedia entry: en.wikipedia.org/wiki/File:RungeKutta_slopes.svg)

Next we show that the original MAML optimization is a special case of MAML-RK (with a specific setting of a_i, p_j, q_ij) and also explore a wider setting of these parameters, giving us other types of first- and second-order optimizations for MAML.

The key technical component of MAML is that it computes the gradient with respect to the meta-parameter θ while computing the objective on the updated model parameter θ′ (c.f. Equation 2 and Algorithm 1). This makes the meta-parameter move towards the direction of ∇_θ Σ_{T_i ∼ p(T)} L_{T_i}(θ′_i). Observe that this precisely corresponds to the k_2 component of MAML-RK with a specific setting of p_2 and q_21. The following proposition tells us that MAML is a special case of second-order MAML-RK.

Proposition 1. MAML's gradient corresponds to the second-order explicit Runge-Kutta equation with the parameters a_1 = 0, a_2 = 1, q_21 = 1/2, p_2 = 1/2.

Proof. For simplicity, let us denote ∇_θ L(t, θ_t) as ∇L(t, θ_t). Consider the Taylor expansion of MAML's gradient:

    dθ/dt = ∇L(t + h/2, θ_t − (h/2) ∇L(t, θ_t))
          = ∇L(t, θ_t) + (h/2) ( d∇L(t, θ_t)/dt − (∂∇L(t, θ_t)/∂θ) ∇L(t, θ_t) ) + O(h²).    (5)

We can compare Equation 5 with the second-order explicit Runge-Kutta equation for MAML, which is

    θ_{t+h} = θ_t + ( a_1 ∇L(t, θ_t) + a_2 ∇L(t + p_2 h, θ_t − q_21 k_1 h) ) h
            = θ_t + (a_1 + a_2) ∇L(t, θ_t) h + a_2 h² ( p_2 d∇L(t, θ_t)/dt − q_21 (∂∇L(t, θ_t)/∂θ) ∇L(t, θ_t) ) + O(h³),    (6)

such that a_1 + a_2 = 1, a_2 q_21 = 1/2, a_2 p_2 = 1/2. We can now see that Equation 6 matches Equation 5 with a_1 = 0, a_2 = 1, p_2 = 1/2, and q_21 = 1/2.

The optimization done with the specific setting of the RK parameters as shown in Proposition 1 is usually called the midpoint method (Butcher, 2008). This shows that the original MAML objective is essentially a midpoint optimization method.

The setting a_1 = 0 and a_2 = 1 implies that the classic MAML objective relies solely on ∇L(t + p_2 h, θ_t − q_21 k_1 h). Moreover, the setting q_21 = 1/2 corresponds to a learning rate of h/2 for the inner-update (i.e., α in Algorithm 1) and a learning rate of h for the outer meta-update (i.e., β in Algorithm 1). This results in a meta-model optimization that stays close to the ideal trajectory path X(θ) with a small error per step (c.f. proof of Proposition 1).

The Runge-Kutta gradient ḡ_t^{MAML-RK} discussed so far has been generic. One can instantiate it (i) at various degrees of order (by choosing N), and (ii) by varying the RK parameters in Equation 4. In this section, we examine some example instantiations of MAML-RK.

Example 1 (First-order MAML). Choose N = 1 and a_1 = 1. Then,

    ḡ_t^{MAML-RK1} = ∇_θ Σ_{T_i ∼ p(T)} L_{T_i}(t, θ_t).    (7)

The first-order MAML is simply Euler's method. It sums over the gradients of the loss functions for every task. It is similar in flavor to other first-order approaches, such as FOMAML (Biswas and Agrawal, 2018) and Reptile (Nichol et al., 2018).

Table 1: The coefficients of various 2nd-order RK methods (Hairer et al., 1987)

    Method     a_1    a_2    q_21, p_2
    Midpoint   0      1      1/2
    Heun       1/2    1/2    1
    Ralston    1/3    2/3    3/4
    ITB        2/3    1/3    3/2
    Generic    x      1−x    1/(2(1−x))

In the previous section, we showed that MAML is a special case of second-order MAML-RK by advecting g_t with MAML's gradient. Here, we illustrate several other popular second-order Runge-Kutta methods that we can also apply to MAML gradients, thus extending the variety of optimization techniques that are currently in use for meta-learning.

Example 2 (Second-order MAMLs). Choose N = 2. The methods in Table 1 satisfy the constraints a_1 + a_2 = 1, a_2 q_21 = 1/2, and a_2 p_2 = 1/2. Thus

    ḡ_t^{MAML-RK2} = a_1 ∇_θ L(t, θ_t) + a_2 ∇_θ L(t + p_2 h, θ_t − q_21 h ∇_θ L(t, θ_t)),    (8)

where the parameters a_1, a_2, p_2, and q_21 can be substituted accordingly.
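The order conditions in Example 2 and the Table 1 coefficients can be checked programmatically. Below is a small sketch (the name `rk2_meta_gradient` is ours, and the inner step uses the descent convention of Algorithm 2):

```python
# Second-order RK coefficients (a1, a2, q21) from Table 1; p2 = q21 in every row.
RK2_METHODS = {
    "midpoint": (0.0, 1.0, 0.5),   # classic MAML update
    "heun":     (0.5, 0.5, 1.0),
    "ralston":  (1/3, 2/3, 3/4),
    "itb":      (2/3, 1/3, 3/2),
}

def rk2_meta_gradient(grad_fn, theta, h, a1, a2, q21):
    """Eq. (8): g_bar = a1*k1 + a2*k2, with an inner descent step of size q21*h."""
    k1 = grad_fn(theta)                   # gradient at the current meta-parameters
    k2 = grad_fn(theta - q21 * h * k1)    # gradient after the fast-adaptation step
    return a1 * k1 + a2 * k2

# Every row satisfies the second-order conditions a1 + a2 = 1 and a2*q21 = 1/2.
for a1, a2, q21 in RK2_METHODS.values():
    assert abs(a1 + a2 - 1.0) < 1e-12 and abs(a2 * q21 - 0.5) < 1e-12
```

With the midpoint row, `rk2_meta_gradient` reduces to the gradient evaluated at the adapted parameters, matching MAML; the other rows additionally mix in the first-order term k1.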
We can thus generalize Algorithm 1 by substituting the learning rate α as q_21 h, and β as h. Note that q_21 = p_2 for all second-order Runge-Kutta methods. Algorithm 2 presents the generalized MAML-RK2 training algorithm. It is worth noting that the existing literature only discusses the midpoint method for training MAML, and our extension enables the practitioner to explore other methods, like Heun's, Ralston, and ITB (Hairer et al., 1987).

Using the gradient ḡ^{MAML-RK1} = ∇L(θ) from a model pre-trained on a large dataset and then fine-tuning it on a smaller new dataset is popular in transfer learning and has become a common technique in various application domains. For example, in computer vision it is common to take the parameters and gradients of a network pre-trained on ImageNet (Deng et al., 2009) and use it to perform, say, bird classification (Zhang et al., 2014). Hence, there is some evidence that pre-training with the first-order gradient ḡ^{MAML-RK1} helps the model learn a shared feature representation that can be applied across similar tasks, although it does not encourage the model to learn a meta-parameter that can rapidly adapt to a new task. In contrast, part of the success of the classic MAML optimization ḡ^{MAML} (a specific second-order method) is that it can directly optimize for this rapid new-task adaptation by differentiating through the fine-tuning process with respect to the meta-parameter, ∇L(θ′), but it does not make use of the first-order gradient ∇L(θ). It is important to note that our second-order Runge-Kutta generalization ḡ^{MAML-RK2} (with specific instantiations such as Heun's, Ralston, and ITB) considers both terms, ∇L(θ) and ∇L(θ′), and thus has the potential benefit of encouraging both rapid adaptation and a shared feature representation (Raghu et al., 2019).
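The correspondence α = q_21 h, β = h can be sanity-checked numerically. In the sketch below (function names are ours, and the outer gradient is evaluated at the adapted parameters as in the RK view, rather than differentiated through the inner step), one MAML-style update with α = h/2, β = h coincides with a midpoint MAML-RK2 step:

```python
def maml_style_update(grad_fn, theta, alpha, beta):
    """Two-phase update of Algorithm 1 on one task, with the outer gradient
    evaluated at the adapted parameters theta_prime."""
    theta_prime = theta - alpha * grad_fn(theta)   # inner fast-adaptation step
    return theta - beta * grad_fn(theta_prime)     # outer meta-update

def rk2_update(grad_fn, theta, h, a1, a2, q21):
    """One MAML-RK2 step (Algorithm 2, descent convention)."""
    k1 = grad_fn(theta)
    k2 = grad_fn(theta - q21 * h * k1)
    return theta - h * (a1 * k1 + a2 * k2)

# With the midpoint coefficients (a1, a2, q21) = (0, 1, 1/2), the two updates
# agree exactly for any loss gradient.
grad = lambda t: 3.0 * t   # gradient of a toy quadratic loss L = 1.5 * theta**2
h = 0.1
lhs = rk2_update(grad, 2.0, h, 0.0, 1.0, 0.5)
rhs = maml_style_update(grad, 2.0, 0.5 * h, h)
```

Any other row of Table 1 plugged into `rk2_update` yields the corresponding MAML-RK2 variant with no other change to the training loop.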
See our experiments in Section 5.

For every meta-learning update, the first-order method performs one evaluation of L, and the second-order method performs two evaluations. The number of evaluations grows linearly up to the fourth order, after which it grows faster, making higher orders computationally prohibitive. The fourth-order Runge-Kutta method is often the method of choice for solving initial value ODE problems; in our case, however, even the second-order optimization requires Hessian-vector products during the MAML updates, and more evaluations are impractical. We therefore limit ourselves to second-order RK methods in our experiments.

Related Work

Early approaches to meta-learning go back to the late 1980s and early 1990s (Schmidhuber et al., 1987; Bengio et al., 1991), studying evolutionary principles in self-referential learning using genetic programming. There has been a recent surge in interest, with meta-models applied to a wide array of tasks from architecture search (Zoph and Le, 2016) and hyperparameter search (Maclaurin et al., 2015) to learning optimization (Chen et al., 2017). It has become a popular approach in few-shot supervised learning (Hariharan and Girshick, 2016) and fast reinforcement learning (Wang et al., 2017).

Among many approaches to meta-learning, Finn et al. (2017) proposed the MAML framework, which uses a meta-objective function that performs a two-step gradient-based optimization of the meta-parameters for fast task-specific optimization. Various studies (Li et al., 2017; Antoniou et al., 2019) show that a fast-adaptation update rule can heavily influence performance, which has initiated several related investigations. Biswas and Agrawal (2018) and Nichol et al. (2018) use first-order update methods to reduce the computational burden, while Park and Oliva (2019) […]

Table 2: The performance of MAML-RK for sinusoid regression tasks on the 10-shot adaptation problem. MAML-RK1 (the first-order method) corresponds to the standard pre-training model on all training tasks.
MAML-RK2 (midpoint) corresponds to the standard MAML method.

    Sinusoid 10-shot (mean squared error)
    MAML-RK1 (pretrained)    2.72 ±
    MAML-RK2 (Heun's)        0.19 ±
    MAML-RK2 (ITB)           0.18 ±

Figure 3: Fast-adaptation of MAML-RK on the Omniglot test datasets: (a) 5-way 1-shot (Ralston), (b) 5-way 5-shot (Ralston), (c) 20-way 1-shot (Heun's), (d) 20-way 5-shot (ITB), each compared against MAML.

Table 3: The performance of MAML-RK for Omniglot and MiniImagenet image classification tasks on 1- and 5-shot adaptation problems. The midpoint method corresponds to MAML.

    Omniglot (accuracy %)   5-way 1-shot   5-way 5-shot   20-way 1-shot   20-way 5-shot
    Midpoint (MAML)         98.26 ±
    Heun's                  99.24 ±
    Ralston
    ITB
    MiniImagenet (accuracy %)   5-way 1-shot   5-way 5-shot
    Ralston                 44.66 ±

Raghu et al. (2019) ask whether the success of MAML is due to rapid learning or due to feature reuse. Their analysis shows MAML partly achieves both: a large portion of the lower layers of the MAML model helps in feature reuse, and a large portion of the upper layers helps in rapid learning. Part of the goal of our work is to reason about these aspects of a meta-model more effectively. Our RK extension makes fast-adaptation and shared-representation more explicit, giving practitioners fine-grained control over the optimization.

Experiments

To study the effectiveness of our extended Runge-Kutta MAML framework, we conducted detailed empirical studies of various explicit instantiations of second-order MAML-RK methods (as detailed in Section 3) on various classification, regression, and reinforcement meta-learning benchmarks. Throughout the experiments, we compare the midpoint (i.e., the original MAML optimization), Heun's, Ralston, and ITB methods. The data, models, and the optimizer for all our experiments are built upon the original MAML code.

The following standard setup is used in all our experiments. All models are trained on the training task dataset and evaluated on the test task dataset. For each task, we have a support set of K examples, which are used for fast-adaptation updates. During the evaluation phase, the model is initialized with the learned meta-parameters from the training phase, and is fine-tuned on the K samples from the test tasks. The model architecture for each experiment can be found in the Appendix.

Regression - Following the experiments in MAML (Finn et al., 2017), we consider the sinusoid regression problem.
For each task, the amplitude and phase of a 1-dimensional sinusoid wave are varied within fixed ranges (the phase over [0, π]), and the goal is to regress on an unseen sinusoid wave. The datapoints are sampled from a fixed input range, and we use a batch size of ten (K = 10) for every gradient update with a fixed step size. The mean squared error is used as the loss function.

(MAML regression and classification code: http://github.com/cbfinn/maml ; MAML reinforcement learning code: http://github.com/cbfinn/maml_rl .)

Figure 4: Illustration of fine-tuning using MAML-RK2 versus MAML on the 2D navigation task: (a) Midpoint (MAML), (b) Heun's, (c) Ralston, (d) ITB. Each panel shows the pre-update policy, the policy after 9 gradient steps, and the goal position.

Table 2 presents the regression performance. MAML-RK1 (the first-order method) corresponds to the model pre-trained on all training tasks, which is the baseline method. Note that MAML-RK2 (midpoint) corresponds to the original MAML method. We observe that MAML-RK2 (Ralston) performs marginally better than the midpoint method for this simple task.
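The sinusoid task family can be sketched as follows. The amplitude and input ranges below follow the setup of Finn et al. (2017) and should be treated as illustrative, as should the function names:

```python
import math
import random

def sample_sinusoid_task(rng):
    """Sample one regression task y = A * sin(x + b), with task-specific
    amplitude A and phase b (ranges are illustrative, after Finn et al. 2017)."""
    A = rng.uniform(0.1, 5.0)          # task-specific amplitude
    b = rng.uniform(0.0, math.pi)      # task-specific phase

    def sample_batch(k):
        """Draw a K-shot support batch of (x, y) pairs for this task."""
        xs = [rng.uniform(-5.0, 5.0) for _ in range(k)]
        return [(x, A * math.sin(x + b)) for x in xs]

    return sample_batch

# Example: one task, one 10-shot support batch (K = 10).
rng = random.Random(0)
batch = sample_sinusoid_task(rng)(10)
```

Each call to `sample_sinusoid_task` fixes a new (A, b) pair, so the inner loop adapts to one wave while the outer loop averages over the task distribution.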
Classification - Next, we evaluate MAML-RK on classification tasks. MAML (Finn et al., 2017) has been shown to achieve state-of-the-art performance compared to prior meta-learning and few-shot learning algorithms (Koch, 2015; Vinyals et al., 2016; Ravi and Larochelle, 2017) on few-shot Omniglot (Lake et al., 2011) and MiniImagenet (Ravi and Larochelle, 2017) image recognition tasks. Therefore, we only report MAML's results against the other extended MAML-RK2 methods. The standard setup for few-shot classification is to consider 5 and 20 classes (N-way) with 1-shot and 5-shot learning, and to evaluate on new classes.

The Omniglot images were downsampled and augmented with up to 90 degrees of rotation. The training task classes were randomly selected from 1200 out of the 1623 characters, and the rest were used as test task classes. We followed the same model architectures as previous studies (Vinyals et al., 2016; Finn et al., 2017). The MiniImagenet dataset consists of 60,000 colored images of size 84 × 84. There are 100 different classes, where each class consists of 600 images. Out of the 100 classes, the training, validation, and test classes were split as 64, 12, and 24 respectively.

Table 3 presents the performance on the Omniglot dataset for 5-way and 20-way classification with 1-shot and 5-shot learning. For 5-way 1-shot and 20-way 1-shot learning, we observe that the Ralston method performs the best, followed by Heun's and ITB; the midpoint method performs the worst. We suspect that this is because only the midpoint method has a zero coefficient for a_1, which means that it cannot take the first-order gradient information ∇L(θ) into account. Hence, setting the coefficient a_1 to be greater than 0 is important for achieving better results on this dataset. For 5-way 5-shot learning, we again observe that the Ralston method performs the best, with the midpoint method second best. Figure 3 (a-c) illustrates that Ralston outperforms MAML throughout the training process. For 20-way 5-shot learning, we found that all MAML-RK2 models perform more or less the same, with overlapping ranges of mean plus and minus one standard deviation. Heun's method performed the best on the MiniImagenet dataset.

Figure 5: The 2D navigation (point robot) performance of MAML-RK2 (midpoint/MAML, Heun's, Ralston, ITB) over different numbers of gradient steps, reported as average return on a log scale, compared with random-initialization and oracle baselines.
Reinforcement Learning - We evaluate MAML-RK2 on two types of reinforcement learning environments: 2D navigation and locomotion. For training, REINFORCE (Williams, 1992) was used for the policy updates and trust-region policy optimization (TRPO) was used for meta-optimization (Schulman et al., 2015).

In 2D navigation, there is a set of tasks where a point agent must reach different goal locations in 2D space. The state, the action, and the reward correspond to the 2D location, the motion velocity, and the distance to the goal, respectively. The simulation terminates when the agent navigates to within 0.01 distance of the goal.

Figure 5 presents the performance of MAML-RK2 for up to nine gradient updates with 40 samples. A model trained with random initialization (black triangle line) performs poorly even with nine gradient updates. The red curve corresponds to the performance of the model with the oracle policy, which receives the goal position as input. Recall that a_2 corresponds to the fast-adaptation coefficient and a_1 to good feature representation (the higher a_2, the more the update focuses on ∇L(θ′) and the less on ∇L(θ)). Interestingly, according to the plots in Figure 5, the learning speed is ordered as follows: the midpoint, Ralston, Heun's, and ITB methods. This ordering corresponds to the ordering of the coefficient a_1, which is 0, 1/3, 1/2, and 2/3 respectively. The midpoint method suffers from poor performance the most. We suspect that this is because it only emphasizes the ∇L(θ′) (fast-adaptation) part and ignores the ∇L(θ) (shared-representation) part. The Ralston method seems to achieve the right balance between the two terms, as its performance reaches the oracle level (red) the fastest.

Lastly, Figure 4 illustrates the actual trajectories of learning towards the goal location. The results show that the midpoint method takes a long time to find the goal location and jitters a lot (showing suboptimal temporal dynamics).
On the other hand, ITB finds the goal location within very few steps. Although Ralston and Heun's fall in between the midpoint method and ITB, both methods still take far fewer steps than the midpoint method to converge to the goal. This clearly demonstrates the trade-off between optimizing under ∇L(θ′) versus ∇L(θ).

Conclusion and Future Work

In this paper, we extend the fast-adaptation stage (the inner loop) of MAML to higher-order Runge-Kutta methods to gain finer control over the optimization, and show that the original fast-adaptation update corresponds to the second-order midpoint method. The refined RK optimization helped us control various important aspects of the meta-learning process (fast adaptation and shared representation), achieving improved performance on regression, classification, and reinforcement learning tasks.

It is worth noting that our proposed generalization is not specific to MAML, but can also be applied to other meta-models. We share some potential directions of future work.
Exploring other ODE integrators - We applied explicit Runge-Kutta ODE integrators to generalize the stochastic gradient optimization of MAML. One can also explore other variations on gradient-based updates, such as AdaGrad and Adam (Duchi et al., 2011; Kingma and Ba, 2015), and their effects on meta-learning models. (Similar types of analysis have been done for image classification via neural networks; Im et al., 2017.)

Besides RK integrators, one can also apply other integrators, such as exponential and leapfrog integrators. Since different integrators focus on different aspects of the optimization, one expects that they would benefit different types of tasks. We believe that a thorough analysis of this would be an interesting direction to explore in the future and would be extremely beneficial to the practitioner.
Extension to ANIL - Raghu et al. (2019) recently showed that feature reuse is the dominant factor in MAML optimization, and proposed to train adaptation updates only on the last layer (i.e., the softmax layer for classification) of the model. It would be instructive to study how different RK methods affect the various layers of the MAML network. For example, we can use the midpoint method for the last layer and apply Ralston, Heun's, ITB, and the plain gradient for the lower layers. This essentially has the effect of shifting the balance from a_2 to a_1 as we move down to the lower layers.

Extension to Bayesian MAML - There are several works on Bayesian MAML (Kim et al., 2018; Finn et al., 2018) that help in adding robustness and preventing overfitting in few-shot learning. It would be interesting to combine MAML-RK optimization with these frameworks.

Bibliography
Antreas Antoniou, Harrison Edwards, and Amos Storkey. How to train your MAML. In International Conference on Learning Representations (ICLR), 2019.
Yoshua Bengio, Samy Bengio, and Jocelyn Cloutier. Learning a synaptic learning rule. Technical report, Université de Montréal, Département d'informatique et de recherche opérationnelle, 1991.
John B. Biggs. The role of metalearning in study processes. British Journal of Educational Psychology, 55(3):185–212, 1985.
Abhijat Biswas and Shubham Agrawal. First-order meta-learned initialization for faster adaptation in deep reinforcement learning. 2018.
John C. Butcher. Numerical Methods for Ordinary Differential Equations. Wiley, New York, 2008.
Boyu Chen, Wenlian Lu, and Ernest Fokoue. Meta-learning with Hessian-free approach in deep neural nets training. arXiv preprint arXiv:1805.08462, 2018.
Yutian Chen, Matthew W. Hoffman, Sergio Gómez Colmenarejo, Misha Denil, Timothy P. Lillicrap, Matt Botvinick, and Nando de Freitas. Learning to learn without gradient descent by gradient descent. In Proceedings of the International Conference on Machine Learning (ICML), 2017.
Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In Computer Vision and Pattern Recognition (CVPR), 2009.
John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods for online learning and stochastic optimization. The Journal of Machine Learning Research, 12:2121–2159, 2011.
Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. arXiv preprint arXiv:1703.03400, 2017.
Chelsea Finn, Kelvin Xu, and Sergey Levine. Probabilistic model-agnostic meta-learning. In Neural Information Processing Systems (NIPS), 2018.
Ernst Hairer, Syvert P. Nørsett, and Gerhard Wanner. Solving Ordinary Differential Equations I: Nonstiff Problems. Springer, 1987.
Bharath Hariharan and Ross Girshick. Low-shot visual recognition by shrinking and hallucinating features. arXiv preprint arXiv:1606.02819, 2016.
Sepp Hochreiter, Steven Younger, and Peter R. Conwell. Learning to learn using gradient descent. In International Conference on Artificial Neural Networks (ICANN), 2001.
Daniel Jiwoong Im, Michael Tao, and Kristin Branson. An empirical analysis of the optimization of deep network loss surfaces. arXiv preprint arXiv:1612.04010, 2017.
Taesup Kim, Jaesik Yoon, Ousmane Dia, Sungwoong Kim, Yoshua Bengio, and Sungjin Ahn. Bayesian model-agnostic meta-learning. In Neural Information Processing Systems (NIPS), 2018.
Diederik P. Kingma and Jimmy Lei Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations (ICLR), 2015.
Gregory Koch. Siamese neural networks for one-shot image recognition. In ICML Deep Learning Workshop, 2015.
Brenden M. Lake, Ruslan Salakhutdinov, Jason Gross, and Joshua B. Tenenbaum. One shot learning of simple visual concepts. In Conference of the Cognitive Science Society (CogSci), 2011.
Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li. Meta-SGD: Learning to learn quickly for few shot learning. arXiv preprint arXiv:1707.09835, 2017.
Dougal Maclaurin, David Duvenaud, and Ryan P. Adams. Gradient-based hyperparameter optimization through reversible learning. arXiv preprint arXiv:1502.03492, 2015.
Alex Nichol, Joshua Achiam, and John Schulman. On first-order meta-learning algorithms. In Neural Information Processing Systems (NIPS), 2018.
Eunbyung Park and Junier B. Oliva. Meta-curvature. arXiv preprint arXiv:1902.03356, 2019.
Aniruddh Raghu, Maithra Raghu, Samy Bengio, and Oriol Vinyals. Rapid learning or feature reuse? Towards understanding the effectiveness of MAML. arXiv preprint arXiv:1909.09157, 2019.
Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. In International Conference on Learning Representations (ICLR), 2017.
Jürgen Schmidhuber. Evolutionary principles in self-referential learning. Technische Universität München, 1987.
John Schulman, Sergey Levine, Pieter Abbeel, Michael I. Jordan, and Philipp Moritz. Trust region policy optimization. In International Conference on Machine Learning (ICML), 2015.
Xingyou Song, Wenbo Gao, Yuxiang Yang, Krzysztof Choromanski, Aldo Pacchiano, and Yunhao Tang. ES-MAML: Simple Hessian-free meta learning. arXiv preprint arXiv:1910.01215, 2019.
Oriol Vinyals, Charles Blundell, Tim Lillicrap, and Daan Wierstra. Matching networks for one shot learning. In Neural Information Processing Systems (NIPS), 2016.
Jane X. Wang, Zeb Kurth-Nelson, Dhruva Tirumala, Hubert Soyer, Joel Z. Leibo, Remi Munos, Charles Blundell, Dharshan Kumaran, and Matt Botvinick. Learning to reinforcement learn. arXiv preprint arXiv:1611.05763, 2017.
Ronald J. Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8:229–256, 1992.
Ning Zhang, Jeff Donahue, Ross Girshick, and Trevor Darrell. Part-based R-CNNs for fine-grained category detection. In European Conference on Computer Vision (ECCV), 2014.
Barret Zoph and Quoc V. Le. Neural architecture search with reinforcement learning. arXiv preprint arXiv:1611.01578, 2016.
A Experimental Details
Here, we include some of the experimental set-up details. Note that we built on the original MAML code base, so the experimental set-up is the same as in the original paper.
A.1 Regression
We used multilayer perceptrons with two ReLU hidden layers of size 40. We trained each model for 100,000 iterations. Each gradient update uses a batch of 10 samples. The learning rate (step size) is set to 0.01.
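A minimal sketch of this regression network (the 1-D input/output dimensions are an assumption based on the usual sine-wave regression task; this is illustrative, not the actual code base):

```python
import numpy as np

def init_mlp(rng, sizes=(1, 40, 40, 1)):
    """Initialize an MLP with two ReLU hidden layers of size 40,
    matching the regression set-up described above."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        W = rng.normal(0.0, np.sqrt(2.0 / d_in), size=(d_in, d_out))
        params.append((W, np.zeros(d_out)))
    return params

def mlp_forward(params, x):
    """Forward pass: ReLU hidden layers, linear output layer."""
    h = x
    for W, b in params[:-1]:
        h = np.maximum(h @ W + b, 0.0)
    W, b = params[-1]
    return h @ W + b

rng = np.random.default_rng(0)
params = init_mlp(rng)
y = mlp_forward(params, np.ones((10, 1)))  # batch of 10 scalar inputs
```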
A.2 Classification
Omniglot dataset
Following the same model architectures as previous studies (Vinyals et al., 2016; Finn et al., 2017), we used four convolutional blocks with 64 3 × 3 convolution filters, followed by batch normalization and ReLU activations, and 2 × 2 max pooling. The last hidden layer has dimension 64. We ran 30,000 epochs for the 5-way 1-shot and 5-way 5-shot set-ups for all models, 50,000 epochs for the 20-way 1-shot set-up, and 90,000 epochs for the 20-way 5-shot set-up. Because there are a total of NK examples, where N is the number of tasks and K is the number of examples per task, we were able to use a batch size of 32 for each meta-parameter gradient update. The network was evaluated using 3 gradient updates with step size 0.4 for the 5-way set-up and 5 gradient updates with step size 0.1 for the 20-way set-up.

MiniImagenet dataset
The same convolutional blocks are used as for the Omniglot dataset, except that the number of filters was reduced to 32. The batch size was set to 4 and 2 for meta-parameter updates. All models were trained for 90,000 epochs with 5 gradient updates and a 0.01 learning rate during training. At test time, 10 gradient updates were used, with 15 examples per class.
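As a sanity check on these architectures, the feature-map arithmetic can be sketched as follows (assuming 'same'-padded convolutions, so that only the 2 × 2 max pooling shrinks the spatial dimensions, and the standard 28 × 28 Omniglot and 84 × 84 MiniImagenet input sizes):

```python
def feature_map_size(size, n_blocks=4, pool=2):
    """Spatial size after n_blocks conv blocks, where each block's
    pool x pool max pooling halves (floor division) each dimension."""
    for _ in range(n_blocks):
        size //= pool
    return size

# Omniglot (28x28):     28 -> 14 -> 7 -> 3 -> 1, so 64 channels x 1 x 1 = 64 features,
#                       consistent with the last hidden layer of dimension 64.
# MiniImagenet (84x84): 84 -> 42 -> 21 -> 10 -> 5, so 32 channels x 5 x 5 features.
print(feature_map_size(28), feature_map_size(84))
```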