Learning to Stop While Learning to Predict
Xinshi Chen Hanjun Dai Yu Li Xin Gao Le Song
Abstract
There is a recent surge of interest in designing deep architectures based on the update steps in traditional algorithms, or learning neural networks to improve and replace traditional algorithms. While traditional algorithms have certain stopping criteria for outputting results at different iterations, many algorithm-inspired deep models are restricted to a "fixed-depth" for all inputs. Similar to algorithms, the optimal depth of a deep architecture may be different for different input instances, either to avoid "over-thinking", or because we want to compute less for operations that have already converged. In this paper, we tackle this varying depth problem using a steerable architecture, where a feed-forward deep model and a variational stopping policy are learned together to sequentially determine the optimal number of layers for each input instance. Training such an architecture is very challenging. We provide a variational Bayes perspective and design a novel and effective training procedure which decomposes the task into an oracle model learning stage and an imitation stage. Experimentally, we show that the learned deep model along with the stopping policy improves performance on a diverse set of tasks, including learning sparse recovery, few-shot meta learning, and computer vision tasks.
1. Introduction
Recently, researchers are increasingly interested in the connections between deep learning models and traditional algorithms: deep learning models are viewed as parameterized algorithms that operate on each input instance iteratively, and traditional algorithms are used as templates for designing deep learning architectures. While an important concept in traditional algorithms is the stopping criterion for outputting the result, which can be either a convergence condition or an early stopping rule, such stopping criteria have been more or less ignored in algorithm-inspired deep learning models. A "fixed-depth" deep model is used to operate on all problem instances (Fig. 1 (a)). Intuitively, for deep learning models, the optimal depth (or the optimal number of steps to operate on an input) can also be different for different input instances, either because we want to compute less for operations that have already converged, or because we want to generalize better by avoiding "over-thinking". This motivation aligns well with both the cognitive science literature (Jones et al., 2009) and the examples below:
• In learning to optimize (Andrychowicz et al., 2016; Li & Malik, 2016), neural networks are used as the optimizer to minimize some loss function. Depending on the initialization and the objective function, an optimizer should converge in a different number of steps;
• In learning to solve statistical inverse problems such as compressed sensing (Chen et al., 2018; Liu et al., 2019), inverse covariance estimation (Shrivastava et al., 2020), and image denoising (Zhang et al., 2019), deep models are learned to directly predict the recovery results. In traditional algorithms, problem-dependent early stopping rules are widely used to achieve regularization for a variance-bias trade-off. Deep learning models for solving such problems may also achieve better recovery accuracy by allowing instance-specific computation steps;
• In meta learning, MAML (Finn et al., 2017) used an unrolled and parametrized algorithm to adapt a common parameter to a new task. However, depending on the similarity of the new task to the old tasks, or, in a more realistic task-imbalanced setting where different tasks have different numbers of data points (Fig. 1 (b)), a task-specific number of adaptation steps is more favorable to avoid under- or over-adaptation.

Figure 1. Motivation for learning to stop. [Panels: (a) learning-based algorithm design, where a fixed-depth learned algorithm maps $x$ to $x_T$; (b) task-imbalanced meta learning, where a common parameter $\theta_0$ is adapted to task 1 and task 2 via $\nabla \mathcal{L}_1$ and $\nabla \mathcal{L}_2$; and a dynamic-depth traditional algorithm, which repeats a hand-designed update step until a stopping criterion is satisfied and then outputs $x_t$.]

Figure 2. Two-component model: learning to predict (blue) while learning to stop (green). [At each step the policy $\pi_\phi$ observes $(x, x_t)$; if $\pi_t < 0.5$ the computation continues, and if $\pi_t \ge 0.5$ it stops and outputs $x_t$.]

Georgia Institute of Technology, USA; Google Research, USA; King Abdullah University of Science and Technology, Saudi Arabia; Ant Financial, China. Correspondence to: Xinshi Chen < [email protected] >, Le Song < [email protected] >. Proceedings of the th International Conference on Machine Learning, Vienna, Austria, PMLR 108, 2020. Copyright 2020 by the author(s).

To address the varying depth problem, we propose to learn a steerable architecture, where a shared feed-forward model for normal prediction and an additional stopping policy are learned together to sequentially determine the optimal number of layers for each input instance. In our framework, the model consists of (see Fig. 2):
• A feed-forward or recurrent mapping $F_\theta$, which transforms the input $x$ to generate a path of features (or states) $x_1, \dots, x_T$; and
• A stopping policy $\pi_\phi : (x, x_t) \mapsto \pi_t \in [0, 1]$, which sequentially observes the states and then determines the probability of stopping the computation of $F_\theta$ at layer $t$.
These two components allow us to sequentially predict the next targeted state while at the same time determining when to stop. In this paper, we propose a single objective function for learning both $\theta$ and $\phi$, and we interpret it from the perspective of variational Bayes, where the stopping time $t$ is viewed as a latent variable conditioned on the input $x$. With this interpretation, learning $\theta$ corresponds to maximizing the marginal likelihood, and learning $\phi$ corresponds to the inference step for the latent variable, where a variational distribution $q_\phi(t)$ is optimized to approximate the posterior. A natural algorithm for solving this problem is the Expectation-Maximization (EM) algorithm, which can be very hard to train and inefficient.

How can $\theta$ and $\phi$ be learned effectively and efficiently? We propose a principled and effective training procedure, where we decompose the task into an oracle model learning stage and an imitation learning stage (Fig. 3).
More specifically,
• During the oracle model learning stage, we utilize a closed-form oracle stopping distribution $q^*_\theta$, which can leverage label information not available at testing time.
• In the imitation learning stage, we use a sequential policy $\pi_\phi$ to mimic the behavior of the oracle policy obtained in the first stage. The sequential policy does not have access to the label, so it can be used during the testing phase.
This procedure provides us with a very good initial predictive model and stopping policy. We can either directly use these learned models, or plug them back into the variational EM framework and reiterate to further optimize both together.

Figure 3. Two-stage training framework. [VAE-based method: alternating updates of $F_\theta$ and $q_\phi$ to maximize $J_{\beta\text{-VAE}}(F_\theta, q_\phi)$. Our method: Stage I maximizes $J_{\beta\text{-VAE}}(F_\theta, q^*_\theta)$ with the oracle $q^*_\theta$ to obtain the optimal $\theta^*$; Stage II minimizes $\mathrm{KL}(q^*_{\theta^*}, q_\phi)$ to obtain the optimal $\phi^*$.]

Our proposed learning-to-stop method is a generic framework that can be applied to a diverse range of applications. To summarize, our contributions in this paper include:
1. a variational Bayes perspective to understand the proposed model for learning both the predictive model and the stopping policy together;
2. a principled and efficient algorithm for jointly learning the predictive model and the stopping policy, and the relation of this algorithm to reinforcement learning;
3. promising experiments on various tasks including learning to solve sparse recovery problems, task-imbalanced few-shot meta learning, and computer vision tasks, where we demonstrate the effectiveness of our method in terms of both prediction accuracy and inference efficiency.
2. Related Works
Unrolled algorithm.
A line of recent works unfold and truncate iterative algorithms to design neural architectures. These algorithm-based deep models can be used to automatically learn a better algorithm from data. This idea has been demonstrated in different problems including sparse signal recovery (Gregor & LeCun, 2010; Sun et al., 2016; Borgerding et al., 2017; Metzler et al., 2017; Zhang & Ghanem, 2018; Chen et al., 2018; Liu et al., 2019), sparse inverse covariance estimation (Shrivastava et al., 2020), sequential Bayesian inference (Chen et al., 2019), parameter learning in graphical models (Domke, 2011), non-negative matrix factorization (Yakar et al., 2013), etc. Unrolled-algorithm-based deep modules have also been used for structured prediction (Belanger et al., 2017; Ingraham et al., 2019; Chen et al., 2020). Before the training phase, all these works need to assign a fixed number of iterations that is used for every input instance regardless of its difficulty level. Our proposed method is orthogonal and complementary to all these works, taking the variety of the input instances into account via an adaptive stopping time.
Meta learning.
Optimization-based meta learning techniques are widely applied for solving challenging few-shot learning problems (Ravi & Larochelle, 2017; Finn et al., 2017; Li et al., 2017). Several recent advances proposed task-adaptive meta-learning models which incorporate task-specific parameters (Qiao et al., 2018; Lee & Choi, 2018; Na et al., 2020) or task-dependent metric scaling (Oreshkin et al., 2018). In parallel with these task-adaptive methods, we propose a task-specific number of adaptation steps and demonstrate the effectiveness of this simple modification under task-imbalanced scenarios.
Other adaptive-depth deep models.
In image recognition, 'early exits' have been proposed, mainly to improve computational efficiency during the inference phase (Teerapittayanon et al., 2016; Zamir et al., 2017; Huang et al., 2018), but these methods are based on specific architectures. Kaya et al. (2019) proposed avoiding "over-thinking" by early stopping. However, like all the other 'early exits' models, it adopts heuristic policies that choose the output layer by the confidence scores of internal classifiers. Also, their algorithms for training the feed-forward model $F_\theta$ do not take into account the effect of the stopping policy.

Optimal stopping.
In the optimal control literature, optimal stopping is the problem of choosing a time to take a given action based on sequentially observed random variables in order to maximize an expected payoff (Shiryaev, 2007). When a policy for controlling the evolution of the random variables (corresponding to the output of $F_\theta$) is also involved, it is called a "mixed control" problem, which is highly related to our work. Existing works in this area find the optimal controls by solving the Hamilton-Jacobi-Bellman (HJB) equation, which is theoretically grounded (Pham, 1998; Ceci & Bassan, 2004; Dumitrescu et al., 2018). However, they focus on models based on stochastic differential equations, and the proposed algorithms suffer from the curse of dimensionality. Becker et al. (2019) use deep learning to learn the optimal stopping policy, but the learning of $\theta$ is not considered. Besides, Becker et al. (2019) use reinforcement learning (RL) to solve the problem. In Section 4, we will discuss how our variational inference formulation is related to RL.
3. Problem Formulation
In this section, we will introduce how we model the stopping policy together with the predictive deep model, define the joint optimization objective, and interpret this framework from a variational Bayes perspective.
The predictive model, $F_\theta$, is a typical $T$-layer deep model that generates a path of embeddings $(x_1, \dots, x_T)$ through:

Predictive model: $$x_t = f_{\theta_t}(x_{t-1}), \quad \text{for } t = 1, \dots, T, \qquad (1)$$

where the initial $x_0$ is determined by the input $x$. We denote it by $F_\theta = \{f_{\theta_1}, \dots, f_{\theta_T}\}$, where $\theta \in \Theta$ are the parameters. Standard supervised learning methods learn $\theta$ by optimizing an objective estimated on the final state $x_T$. In our model, the operations in Eq. 1 can be stopped earlier, and the stopping time $t$ can differ across input instances $x$.

Our stopping policy, $\pi_\phi$, determines whether to stop at the $t$-th step after observing the input $x$ and its first $t$ states $x_{1:t}$ transformed by $F_\theta$. If we assume the Markov property, then $\pi_\phi$ only needs to observe the most recent state $x_t$. In this paper, we only input $x$ and $x_t$ to $\pi_\phi$ at each step $t$, but it is trivial to generalize it to $\pi_\phi(x, x_{1:t})$. More precisely, $\pi_\phi$ is defined as a randomized policy as follows:

Stopping policy: $$\pi_t = \pi_\phi(x, x_t), \quad \text{for } t = 1, \dots, T-1, \qquad (2)$$

where $\pi_t \in [0, 1]$ is the probability of stopping. We abuse the notation $\pi$ to represent both the parametrized policy and the probability mass.

This stopping policy makes a decision sequentially whenever a new state $x_t$ is observed. Conditioned on the states observed up to step $t$, whether to stop before $t$ is independent of the states after $t$. Therefore, once it decides to stop at $t$, the remaining computations can be saved, which is a favorable property when inference time is a concern, or for some optimal stopping problems such as option trading, where getting back to earlier states is not allowed.

The stopping policy $\pi_\phi$ takes sequential actions based on the observations, where $\pi_t := \pi_\phi(x, x_t)$ is the probability of stopping when $x_t$ is observed. These sequential actions $\pi_1, \dots, \pi_{T-1}$ jointly determine the random time $t$ at which the stop occurs.
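A minimal sketch of the sequential test-time procedure described above, assuming a deterministic rule that stops at the first layer where $\pi_t \ge 0.5$ (the threshold illustrated in Fig. 2) and otherwise runs to layer $T$. The layer functions and the policy are hypothetical stand-ins passed in as plain Python callables, and $x_0$ is assumed to be the raw input; the point is only that layers after the stopping step are never evaluated.

```python
from typing import Callable, Sequence, Tuple, TypeVar

S = TypeVar("S")


def sequential_predict(
    x: S,
    layers: Sequence[Callable[[S], S]],
    policy: Callable[[S, S], float],
    threshold: float = 0.5,
) -> Tuple[int, S]:
    """Apply x_t = f_t(x_{t-1}) one layer at a time and stop early.

    layers: hypothetical stand-ins for f_{theta_1}, ..., f_{theta_T}.
    policy: hypothetical stand-in for pi_phi(x, x_t), returning a stop probability.
    Returns (t, x_t), where t is the 1-based layer at which computation stopped.
    """
    if not layers:
        raise ValueError("need at least one layer")
    state = x  # assumption: x_0 is the raw input itself
    T = len(layers)
    for t, f in enumerate(layers, start=1):
        state = f(state)
        # The policy is only queried for t = 1..T-1; layer T always outputs.
        if t == T or policy(x, state) >= threshold:
            return t, state
    raise AssertionError("unreachable")
```

For example, with six layers that each halve a scalar and a toy policy that returns 1.0 once the state drops below 0.2, the call on input 1.0 stops at $t = 3$ with output 0.125, and the last three layers are never called.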
Induced by $\pi_\phi$, the probability mass function of the stop time $t$, denoted as $q_\phi$, can be computed by

Variational stop time distribution: $$q_\phi(t) = \begin{cases} \pi_t \prod_{\tau=1}^{t-1} (1 - \pi_\tau) & \text{if } t < T, \\ \prod_{\tau=1}^{T-1} (1 - \pi_\tau) & \text{if } t = T. \end{cases} \qquad (3)$$

In this equation, the product $\prod_{\tau=1}^{t-1} (1 - \pi_\tau)$ is the probability of 'not stopped before $t$', which is the survival probability. Multiplying this survival probability by $\pi_t$ gives the stop time distribution $q_\phi(t)$. For the last time step $T$, the stop probability $q_\phi(T)$ simply equals the survival probability at $T$: if the process has not stopped before $T$, then it must stop at $T$.

Note that we only use $\pi_\phi$ in our model to sequentially determine whether to stop. However, we use the induced probability mass $q_\phi$ to help design the training objective and also the algorithm.

Note that the stop time $t$ is a discrete random variable with distribution determined by $q_\phi(t)$. Given the observed label $y$ of an input $x$, the loss of the predictive model stopped at position $t$ can be computed as $\ell(y, x_t; \theta)$, where $\ell(\cdot)$ is a loss function. Taking into account all possible stopping positions, we are interested in the loss in expectation over $t$,

$$\mathcal{L}(\theta, q_\phi; x, y) := \mathbb{E}_{t \sim q_\phi}\, \ell(y, x_t; \theta) - \beta H(q_\phi), \qquad (4)$$

where $H(q_\phi) := -\sum_t q_\phi(t) \log q_\phi(t)$ is an entropy regularization and $\beta$ is the regularization coefficient. Given a data set $\mathcal{D} = \{(x, y)\}$, the parameters of the predictive model and the stopping policy can be estimated by

$$\min_{\theta, \phi} \frac{1}{|\mathcal{D}|} \sum_{(x, y) \in \mathcal{D}} \mathcal{L}(\theta, q_\phi; x, y). \qquad (5)$$

To better interpret the model and objective, in the following we make a connection to variational Bayes and show that the objective function defined in Eq. 4 is equivalent to the $\beta$-VAE objective.
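The survival-probability construction in Eq. 3 and the regularized objective in Eq. 4 are easy to check numerically. The NumPy sketch below is an illustration rather than the authors' implementation; it assumes the stopping probabilities $\pi_1, \dots, \pi_{T-1}$ and the per-layer losses $\ell(y, x_t; \theta)$ for one sample are already available as arrays.

```python
import numpy as np


def stop_time_distribution(pi):
    """Eq. 3: map stopping probabilities pi_1..pi_{T-1} to q_phi(1..T)."""
    pi = np.asarray(pi, dtype=float)
    # survival[k] = prod_{tau <= k} (1 - pi_tau), i.e. P(not stopped before step k + 1)
    survival = np.concatenate(([1.0], np.cumprod(1.0 - pi)))
    q = np.empty(len(pi) + 1)
    q[:-1] = pi * survival[:-1]  # stop exactly at t < T
    q[-1] = survival[-1]         # not stopped before T, so stop at T
    return q


def entropy(q):
    """H(q) = -sum_t q(t) log q(t), skipping zero-probability steps."""
    q = np.asarray(q, dtype=float)
    nz = q > 0
    return float(-np.sum(q[nz] * np.log(q[nz])))


def expected_loss(q, losses, beta):
    """Eq. 4: E_{t ~ q}[loss_t] - beta * H(q), where losses[t-1] = l(y, x_t; theta)."""
    return float(np.dot(q, losses) - beta * entropy(q))
```

For instance, `stop_time_distribution([0.5, 0.5, 0.5])` gives $q_\phi = (0.5, 0.25, 0.125, 0.125)$, which sums to one by construction.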
In the Bayesian framework, a probabilistic model typically consists of a prior, a likelihood function and a posterior over the latent variable. We find the following correspondence between our model and a probabilistic model (also see Table 1):
• We view the adaptive stopping time $t$ as a latent variable, which is unobserved;
• The conditional prior $p(t \mid x)$ of $t$ is a uniform distribution over all the layers in this paper. However, if one wants to reduce the computation cost and penalize stopping decisions at deeper layers, a prior with smaller probability on deeper layers can be defined to regularize the results;
• The likelihood function $p_\theta(y \mid t, x)$ of the observed label $y$ is controlled by $\theta$, since $F_\theta$ determines the states $x_t$;
• The posterior distribution over the stopping time $t$ can be computed by Bayes' rule $p_\theta(t \mid y, x) \propto p_\theta(y \mid t, x)\, p(t \mid x)$, but it requires the observation of the label $y$, which is infeasible during the testing phase.
Table 1.
Correspondence between our model and Bayes' model.

| Our model | Bayes' model |
|---|---|
| stop time $t$ | latent variable |
| label $y$ | observation |
| loss $\ell(y, x_t; \theta)$ | likelihood $p_\theta(y \mid t, x)$ |
| stop time distribution $q_\phi$ | posterior $p_\theta(t \mid y, x)$ |
| regularization | prior $p(t \mid x)$ |

In this probabilistic model, we need to learn $\theta$ to better fit the observed data, and learn a variational distribution $q_\phi$ over $t$, which only takes $x$ and the transformed internal states as inputs, to approximate the true posterior.

More specifically, the parameters in the likelihood function and the variational posterior can be optimized using the variational autoencoder (VAE) framework (Kingma & Welling, 2013). Here we consider a generalized version called $\beta$-VAE (Higgins et al., 2017), and obtain the optimization objective for a data point $(x, y)$:

$$J_{\beta\text{-VAE}}(\theta, q_\phi; x, y) := \mathbb{E}_{q_\phi} \log p_\theta(y \mid t, x) - \beta\, \mathrm{KL}\big(q_\phi(t) \,\|\, p(t \mid x)\big), \qquad (6)$$

where $\mathrm{KL}(\cdot \| \cdot)$ is the KL divergence. When $\beta = 1$, it becomes the original VAE objective, i.e., the evidence lower bound (ELBO). We are now ready to present the equivalence between the $\beta$-VAE objective and the loss defined in Eq. 4. See Appendix A.1 for the proof.

Lemma 1.
Under the assumptions that (i) the loss function $\ell$ in Eq. 4 is defined as the negative log-likelihood (NLL), i.e., $\ell(y, x_t; \theta) := -\log p_\theta(y \mid t, x)$, and (ii) the prior $p(t \mid x)$ is a uniform distribution over $t$, minimizing the loss $\mathcal{L}$ in Eq. 4 is equivalent to maximizing the $\beta$-VAE objective $J_{\beta\text{-VAE}}$ in Eq. 6.

For classification problems, the cross-entropy loss is aligned with the NLL. For regression problems with the mean squared error (MSE) loss, we can define the likelihood as $p_\theta(y \mid t, x) = \mathcal{N}(y; x_t, I)$. Then the NLL of this Gaussian distribution is $-\log p_\theta(y \mid t, x) = \frac{1}{2}\|y - x_t\|_2^2 + C$, which is equivalent to the MSE loss. More generally, we can always define $p_\theta(y \mid t, x) \propto \exp(-\ell(y, x_t; \theta))$.

This VAE view allows us to design a two-step procedure to effectively learn $\theta$ and $\phi$ in the predictive model and stopping policy, which is presented in the next section.
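A short calculation, sketched here under the two assumptions of Lemma 1, shows where the equivalence comes from: with $p(t \mid x) = 1/T$, the KL term reduces to a negative entropy plus a constant.

$$
\begin{aligned}
\mathrm{KL}\big(q_\phi(t) \,\|\, p(t \mid x)\big) &= \sum_{t=1}^{T} q_\phi(t) \log \frac{q_\phi(t)}{1/T} = -H(q_\phi) + \log T, \\
J_{\beta\text{-VAE}}(\theta, q_\phi; x, y) &= -\mathbb{E}_{t \sim q_\phi}\, \ell(y, x_t; \theta) + \beta H(q_\phi) - \beta \log T \\
&= -\mathcal{L}(\theta, q_\phi; x, y) - \beta \log T.
\end{aligned}
$$

Since $\beta \log T$ depends on neither $\theta$ nor $\phi$, maximizing $J_{\beta\text{-VAE}}$ and minimizing $\mathcal{L}$ have the same optimizers. For the Gaussian case with $y \in \mathbb{R}^d$, $-\log \mathcal{N}(y; x_t, I) = \frac{1}{2}\|y - x_t\|_2^2 + \frac{d}{2}\log(2\pi)$, so the constant is $C = \frac{d}{2}\log(2\pi)$.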
4. Effective Training Algorithm
VAE-based methods perform optimization steps over $\theta$ (M step for learning) and $\phi$ (E step for inference) alternately until convergence, which has two limitations in our case:
i. The alternating training can be slow to converge and requires tuning the training schedule;
ii. The inference step for learning $q_\phi$ may suffer from mode collapse, which in this case means $q_\phi$ only captures the time step $t$ with the highest averaged frequency.
To overcome these limitations, we design a training procedure followed by an optional fine-tuning stage using the variational lower bound in Eq. 6. More specifically,
Stage I. Find the optimal $\theta$ by maximizing the conditional marginal likelihood when the stop time distribution follows an oracle distribution $q^*_\theta$.
Stage II. Fix the optimal $\theta$ learned in Stage I, and only learn the distribution $q_\phi$ to mimic the oracle by minimizing the KL divergence between $q_\phi$ and $q^*_\theta$.
Stage III. (Optional) Fine-tune $\theta$ and $\phi$ jointly towards the joint objective in Eq. 6.
The overall algorithm steps are summarized in Algorithm 1. In the following sections, we focus on the derivation of the first two training stages. Then we discuss several methods to further improve memory and computation efficiency during training.

We first give the definition of the oracle stop time distribution $q^*_\theta$. For each fixed $\theta$, we can find a closed-form solution for the optimal $q^*_\theta$ that optimizes the joint objective:

$$q^*_\theta(\cdot \mid y, x) := \arg\max_{q \in \Delta^{T-1}} J_{\beta\text{-VAE}}(\theta, q; x, y).$$

Alternatively, $q^*_\theta(\cdot \mid y, x) = \arg\min_{q \in \Delta^{T-1}} \mathcal{L}(\theta, q; x, y)$. Under the mild assumptions in Lemma 1, these two optimizations lead to the same optimal oracle distribution.
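For completeness, here is a short sketch of why the oracle has the closed form stated next (Eq. 7 and 8), assuming the conditions of Lemma 1 (NLL loss, uniform prior) and $\beta > 0$. Write $p_t := p_\theta(y \mid t, x)$ and add a Lagrange multiplier $\lambda$ for the constraint $\sum_t q(t) = 1$. The objective is strictly concave in $q$ (the entropy term is strictly concave), so its stationary point is the maximizer:

$$
\begin{aligned}
J_{\beta\text{-VAE}}(\theta, q; x, y) &= \sum_{t=1}^{T} q(t) \log p_t + \beta H(q) - \beta \log T, \\
\frac{\partial}{\partial q(t)} \Big[ J_{\beta\text{-VAE}} + \lambda \Big( \sum_{t'=1}^{T} q(t') - 1 \Big) \Big] &= \log p_t - \beta \big( \log q(t) + 1 \big) + \lambda = 0, \\
\Longrightarrow\; q^*_\theta(t \mid y, x) &= \frac{p_t^{1/\beta}}{\sum_{t'=1}^{T} p_{t'}^{1/\beta}} = \frac{\exp\big(-\ell(y, x_t; \theta)/\beta\big)}{\sum_{t'=1}^{T} \exp\big(-\ell(y, x_{t'}; \theta)/\beta\big)}.
\end{aligned}
$$

The resulting $q^*_\theta$ is strictly positive, so the non-negativity constraints of the simplex are inactive.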
Oracle stop time distribution:

$$q^*_\theta(t \mid y, x) = \frac{p_\theta(y \mid t, x)^{1/\beta}}{\sum_{t'=1}^{T} p_\theta(y \mid t', x)^{1/\beta}} \qquad (7)$$
$$= \frac{\exp\big(-\frac{1}{\beta}\ell(y, x_t; \theta)\big)}{\sum_{t'=1}^{T} \exp\big(-\frac{1}{\beta}\ell(y, x_{t'}; \theta)\big)}. \qquad (8)$$

This closed-form solution makes it clear that the oracle favors steps $t$ with smaller loss, or equivalently larger likelihood, with $\beta$ acting as an exploration coefficient: a larger $\beta$ spreads the probability mass more evenly over the layers.

Remark: When $\beta = 1$, $q^*_\theta$ is the same as the posterior distribution $p_\theta(t \mid y, x) \propto p_\theta(y \mid t, x)\, p(t \mid x)$ under the uniform prior.

Note that there are no new parameters in the oracle distribution. Instead, it depends on the parameters $\theta$ in the predictive model. Overall, the oracle $q^*_\theta$ is a function of $\theta$, $t$, $y$ and $x$ that has a closed form. Next, we introduce how we use this oracle in the first two training stages.

In Stage I, we optimize the parameters $\theta$ in the predictive model by taking into account the oracle stop distribution $q^*_\theta$. This step corresponds to the M step for learning $\theta$ by maximizing the marginal likelihood. The difference from the normal M step is that here $q_\phi$ is replaced by the oracle $q^*_\theta$, which gives the optimal stopping distribution, so that the marginal likelihood is independent of $\phi$. More precisely, Stage I finds the optimum of:

$$\max_\theta \frac{1}{|\mathcal{D}|} \sum_{(x, y) \in \mathcal{D}} \sum_{t=1}^{T} q^*_\theta(t \mid y, x) \log p_\theta(y \mid t, x). \qquad (9)$$

Algorithm 1
Overall Algorithm
Randomly initialize $\theta$ and $\phi$.
for itr = 1 to … do ▷ Stage I.
    Sample a batch of data points $\mathcal{B} \sim \mathcal{D}$.
    Take an optimization step to update $\theta$ towards the marginal likelihood function defined in Eq. 9.
for itr = 1 to … do ▷ Stage II.
    Sample a batch of data points $\mathcal{B} \sim \mathcal{D}$.
    Take an optimization step to update $\phi$ towards minimizing the forward KL divergence defined in Eq. 10.
for itr = 1 to … do ▷ Optional step.
    Sample a batch of data points $\mathcal{B} \sim \mathcal{D}$.
    Update both $\theta$ and $\phi$ towards the $\beta$-VAE objective in Eq. 6.
return $\theta$, $\phi$

The summation over $t$ in Eq. 9 is the expectation of the likelihood, $\mathbb{E}_{t \sim q^*_\theta(t \mid y, x)} \log p_\theta(y \mid t, x)$. Since $q^*_\theta$ has a differentiable closed-form expression in terms of $\theta$, $x$, $y$ and $t$, the gradient can also propagate through $q^*_\theta$, which is another difference from the normal M step.

To summarize, in Stage I we learn the predictive model parameter $\theta$ by assuming that the stop time always follows the best stopping distribution that depends on $\theta$. In this case, the learning of $\theta$ already takes into account the effect of the data-specific stop time.

However, the oracle $q^*_\theta$ is not in the form of sequential actions as in Eq. 2, and it requires access to the true label $y$, so it cannot be used for testing. Nevertheless, it plays an important role in obtaining a sequential policy, which is explained next.

In Stage II, we learn the sequential policy $\pi_\phi$ that best mimics the oracle distribution $q^*_\theta$, where $\theta$ is fixed to the optimal $\theta$ learned in Stage I. We do so by minimizing the divergence between the oracle $q^*_\theta$ and the variational stop time distribution $q_\phi$ induced by $\pi_\phi$ (Eq. 3). There are various variational divergence minimization approaches that we can use (Nowozin et al., 2016). For example, a widely used objective for variational inference is the reverse KL divergence:

$$\mathrm{KL}(q_\phi \,\|\, q^*_\theta) = -\sum_{t=1}^{T} q_\phi(t) \log q^*_\theta(t \mid y, x) - H(q_\phi).$$

Remark.
We write $q_\phi(t)$ instead of $q_\phi(t \mid x_{1:T}, x)$ for notational simplicity, but $q_\phi$ depends on $x$ and $x_{1:T}$ (Eq. 3).

If we rewrite $q_\phi$ using $\pi_1, \dots, \pi_{T-1}$ as defined in Eq. 3, we find that minimizing the reverse KL is equivalent to finding the optimal policy $\pi_\phi$ in a reinforcement learning (RL) environment, where the state is $x_t$, the action $a_t \sim \pi_t := \pi_\phi(x, x_t)$ is a stop/continue decision, the state transition is determined by $\theta$ and $a_t$, and the reward is defined as

$$r(x_t, a_t; y) := \begin{cases} -\frac{1}{\beta}\ell(y, x_t; \theta) & \text{if } a_t = 0 \text{ (i.e. stop)}, \\ 0 & \text{if } a_t = 1 \text{ (i.e. continue)}, \end{cases}$$

where $\ell(y, x_t; \theta) = -\log p_\theta(y \mid t, x)$. More details and the derivation are given in Appendix A.2, showing that minimizing $\mathrm{KL}(q_\phi \| q^*_\theta)$ is equivalent to solving the following maximum-entropy RL problem:

$$\max_\phi \mathbb{E}_{\pi_\phi} \sum_{t=1}^{T} \big[ r(x_t, a_t; y) + H(\pi_t) \big].$$

In some related literature, the optimal stopping problem is often formulated as an RL problem (Becker et al., 2019). Above, we bridge our variational inference formulation and the RL-based optimal stopping literature.

Although the reverse KL divergence is a widely used objective, it suffers from the mode collapse issue, which in our case may lead to a distribution $q_\phi$ that captures only a common stopping time $t$ for all $x$ that performs best on average, instead of a more spread-out stopping time. Therefore, we consider the forward KL divergence:

$$\mathrm{KL}(q^*_\theta \,\|\, q_\phi) = -\sum_{t=1}^{T} q^*_\theta(t \mid y, x) \log q_\phi(t) - H(q^*_\theta), \qquad (10)$$

which is equivalent to the cross-entropy loss, since the term $H(q^*_\theta)$ can be ignored as $\theta$ is fixed in this step. Experimentally, we find that the forward KL leads to better performance.

It is easy to see that our two-stage training procedure also has an EM flavor. However, with the oracle $q^*_\theta$ incorporated, the training of $\theta$ already takes into account the effect of the optimal stopping distribution.
Therefore, we can save many alternation steps. After the two-stage training, we can fine-tune $\theta$ and $\phi$ jointly towards the $\beta$-VAE objective. Experimentally, we find that this additional stage does not improve much over the performance obtained after the first two stages.

Since both objectives in the oracle learning stage (Eq. 9) and the imitation stage (Eq. 10) involve a summation over $T$ layers, the computation and memory costs during training are higher than those of standard learning methods. The memory issue is especially important in meta learning. In the following, we introduce several ways of improving training efficiency.

Fewer output channels.
Instead of allowing the model to output $x_t$ at any layer, we can choose a smaller number of output channels that are evenly spaced along the layers.

Stochastic sampling in Stage I.
A Monte Carlo method can be used to approximate the expectation over $q^*_\theta$ in Stage I. More precisely, for each $(x, y)$ we can randomly sample a layer $t_s \sim q^*_\theta(t \mid y, x)$ from the oracle, and only compute $\log p_\theta(y \mid t_s, x)$ at $t_s$, instead of summing over all $t \in [T]$. Note that, in this case, the gradient will not back-propagate through $q^*_\theta(t \mid y, x)$.

MAP estimate in Stage II.
Instead of approximating the distribution $q^*_\theta$, we can approximate the maximum a posteriori (MAP) estimate $\hat{t}(x, y) = \arg\max_{t \in [T]} q^*_\theta(t \mid y, x)$, so that the objective for each sample is $-\log q_\phi(\hat{t}(x, y))$, which does not involve the summation over $t$. Besides efficiency, we also find that this MAP estimate can lead to higher accuracy, by encouraging the learning of $q_\phi$ to focus more on the sample-wise best layer.
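The per-sample quantities used in Stages I and II, together with the two efficiency variants above, can be sketched with NumPy as follows. This is an illustrative sketch, not the authors' code: it assumes the per-layer losses $\ell(y, x_t; \theta)$ of one sample are already available as an array (with $\log p_\theta(y \mid t, x) = -\ell$ as in Lemma 1), and it only computes values. In an actual Stage I update, an autodiff framework would differentiate Eq. 9 with respect to $\theta$, including through $q^*_\theta$ except in the Monte Carlo variant. All function names are hypothetical.

```python
import numpy as np


def oracle(losses, beta):
    """Eq. 7-8: q*_theta(t | y, x) = softmax(-loss_t / beta) over t = 1..T."""
    z = -np.asarray(losses, dtype=float) / beta
    z -= z.max()  # subtract the max for numerical stability
    w = np.exp(z)
    return w / w.sum()


def stage1_objective(losses, beta):
    """Eq. 9 for one (x, y): sum_t q*(t) log p(y | t, x), using log p = -loss."""
    losses = np.asarray(losses, dtype=float)
    return float(np.dot(oracle(losses, beta), -losses))


def forward_kl_loss(q_star, q_phi, eps=1e-12):
    """Eq. 10 up to the constant -H(q*): cross-entropy -sum_t q*(t) log q_phi(t)."""
    return float(-np.dot(q_star, np.log(np.asarray(q_phi, dtype=float) + eps)))


def reverse_kl(q_phi, q_star, eps=1e-12):
    """KL(q_phi || q*) = sum_t q_phi(t) (log q_phi(t) - log q*(t))."""
    q_phi = np.asarray(q_phi, dtype=float)
    q_star = np.asarray(q_star, dtype=float)
    return float(np.sum(q_phi * (np.log(q_phi + eps) - np.log(q_star + eps))))


def sample_oracle_layer(losses, beta, rng):
    """Monte Carlo variant of Stage I: draw one layer t_s ~ q* (1-based index)."""
    q = oracle(losses, beta)
    return int(rng.choice(len(q), p=q)) + 1


def map_target(losses, beta):
    """MAP variant of Stage II: t_hat = argmax_t q*(t), i.e. the lowest-loss layer."""
    return int(np.argmax(oracle(losses, beta))) + 1


def map_loss(q_phi, t_hat, eps=1e-12):
    """Per-sample MAP objective -log q_phi(t_hat) for a 1-based t_hat."""
    return float(-np.log(np.asarray(q_phi, dtype=float)[t_hat - 1] + eps))
```

The forward KL is written as a plain cross-entropy because $H(q^*_\theta)$ is constant once $\theta$ is fixed, and the MAP target coincides with the layer of smallest loss because the oracle is monotone in $-\ell$.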
5. Experiments
We conduct experiments on (i) learning-based algorithm forsparse recovery, (ii) few-shot meta learning, and (iii) imagedenoising. The comparison is in an ablation study fashion tobetter examine whether the stopping policy can improve theperformances given the same architecture for the predictivemodel, and whether our training algorithm is more effectivecompared to the alternating EM algorithm. In the end, wealso discuss our exploration of the image recognition task.
We consider a sparse recovery task, which aims at recovering $x^* \in \mathbb{R}^n$ from its noisy linear measurements $b = A x^* + \epsilon$, where $A \in \mathbb{R}^{m \times n}$, $\epsilon \in \mathbb{R}^m$ is Gaussian white noise, and $m \ll n$. A popular approach is to model the problem as the LASSO formulation $\min_x \frac{1}{2}\|b - A x\|_2^2 + \rho \|x\|_1$ and solve it using iterative methods such as ISTA (Blumensath & Davies, 2008) and FISTA (Beck & Teboulle, 2009). We choose the most popular model, Learned ISTA (LISTA), as the baseline and also as our predictive model. LISTA is a $T$-layer network with update steps:

$$x_t = \eta_{\lambda_t}\big(W_t^1 b + W_t^2 x_{t-1}\big), \quad t = 1, \dots, T, \qquad (11)$$

where $\eta_\lambda$ is the soft-thresholding operator and $\theta = \{(\lambda_t, W_t^1, W_t^2)\}_{t=1}^{T}$ are learnable parameters.

Experiment setting.
We follow Chen et al. (2018) to generate the samples. The signal-to-noise ratio (SNR) for each sample is uniformly sampled from 20, 30, and 40 dB. The training loss for LISTA is $\sum_{t=1}^{T} \gamma^{T-t} \|x_t - x^*\|_2^2$, where $\gamma \le 1$. This loss is commonly used for algorithm-based deep learning, so that there is a supervision signal for every layer. For ISTA and FISTA, we use the training set to tune the hyperparameters by grid search. See Appendix B.1 for more details.

Recovery performance. (Table 2) We report the NMSE (in dB) results for each model/algorithm evaluated on 1000 fixed test samples per SNR level. Table 2 shows that learning-based methods have better recovery performance, especially for the more difficult tasks (i.e. when the SNR is 20). Compared to LISTA, our proposed adaptive-stopping method (LISTA-stop) significantly improves recovery performance. Also, LISTA-stop with at most 20 iterations performs better than ISTA and FISTA with 100 iterations, which indicates better convergence.

Table 2.
Recovery performances of different algorithms/models (NMSE in dB).

| SNR | mixed | 20 | 30 | 40 |
|---|---|---|---|---|
| FISTA ($T = 100$) | -18.96 | -16.75 | -20.46 | -20.97 |
| ISTA ($T = 100$) | -14.66 | -13.99 | -14.99 | -15.07 |
| ISTA ($T = 20$) | -9.17 | -9.12 | -9.24 | -9.16 |
| FISTA ($T = 20$) | -11.12 | -10.98 | -11.19 | -11.19 |
| LISTA ($T = 20$) | -17.53 | -16.53 | -18.07 | -18.20 |
| LISTA-stop ($T \le 20$) | -22.41 | -20.29 | -23.90 | -24.21 |

Stopping distribution.
The stop time distribution $q_\phi(t)$ induced by $\pi_\phi$ can be computed via Eq. 3. We report in Fig. 4 the stopping distribution averaged over the test samples, from which we can see that with high probability LISTA-stop terminates the process before arriving at the 20th iteration.
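To make Eq. 11 concrete, the sketch below runs a LISTA-style forward pass with NumPy and records the path $x_1, \dots, x_T$ that a stopping policy would observe. As an assumption for illustration, every layer is initialized from ISTA ($W_t^1 = A^\top / L$, $W_t^2 = I - A^\top A / L$, $\lambda_t = \rho / L$ with $L = \|A\|_2^2$), so the untrained network reproduces ISTA on the LASSO problem, and $x_0 = 0$. The helpers `ista_init`, `layerwise_loss` and `nmse_db` are hypothetical names; the last two mirror the layer-wise training loss and the NMSE (in dB) described in the experiment setting. The toy problem sizes are arbitrary and not the paper's setup.

```python
import numpy as np


def soft_threshold(v, lam):
    """Soft-thresholding eta_lambda(v) = sign(v) * max(|v| - lambda, 0)."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)


def ista_init(A, rho, T):
    """Untied per-layer parameters initialised to ISTA's update (an assumption)."""
    L = np.linalg.norm(A, 2) ** 2  # largest eigenvalue of A^T A
    W1 = A.T / L
    W2 = np.eye(A.shape[1]) - A.T @ A / L
    return [(rho / L, W1.copy(), W2.copy()) for _ in range(T)]


def lista_path(b, params):
    """Eq. 11: x_t = eta_{lambda_t}(W1_t b + W2_t x_{t-1}), starting from x_0 = 0."""
    x = np.zeros(params[0][2].shape[0])
    path = []
    for lam, W1, W2 in params:
        x = soft_threshold(W1 @ b + W2 @ x, lam)
        path.append(x)
    return path


def layerwise_loss(path, x_star, gamma):
    """Training loss sum_t gamma^(T-t) ||x_t - x*||_2^2 with gamma <= 1."""
    T = len(path)
    return sum(gamma ** (T - t) * np.sum((x_t - x_star) ** 2)
               for t, x_t in enumerate(path, start=1))


def nmse_db(x_hat, x_star):
    """Normalised MSE in dB: 10 log10(||x_hat - x*||^2 / ||x*||^2)."""
    return 10 * np.log10(np.sum((x_hat - x_star) ** 2) / np.sum(x_star ** 2))


# Toy problem: a 3-sparse signal in R^50 observed through 30 noisy measurements.
rng = np.random.default_rng(0)
m, n = 30, 50
A = rng.normal(size=(m, n)) / np.sqrt(m)
x_star = np.zeros(n)
x_star[rng.choice(n, size=3, replace=False)] = rng.normal(size=3)
b = A @ x_star + 0.01 * rng.normal(size=m)
path = lista_path(b, ista_init(A, rho=0.01, T=300))
```

Feeding `path` into a stopping policy (for example the sequential rule sketched in Section 3) and Eq. 3 gives a per-sample $q_\phi(t)$; averaging those over a test set yields the kind of histogram reported in Fig. 4 (a).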
[Figure 4 panels: (a) stop time distribution, $q(t)$ versus iteration $t$ (iterations 14 to 20 shown); (b) convergence, averaged NMSE versus iteration $t$ for LISTA-stop, LISTA, ISTA and FISTA.]
Figure 4.
Left: Stop time distribution $\frac{1}{|\mathcal{D}_{test}|} \sum_{x \in \mathcal{D}_{test}} q_\phi(t \mid x)$ averaged over the test set. Right: Convergence of different algorithms. For LISTA-stop, the NMSE weighted by the stopping distribution $q_\phi$ is plotted. In the first 13 iterations $q_\phi(t) = 0$, so no red dots are plotted.

Convergence comparison.
Fig. 4 (b) shows the change in NMSE as the number of iterations increases. Since LISTA-stop outputs results at different iteration steps, it is not meaningful to draw a unified convergence curve. Therefore, we plot the NMSE weighted by the stopping distribution $q_\phi$, i.e.,
$$10 \log_{10} \left( \frac{\sum_{i=1}^{N} q_\phi(t \mid i)\, \|x_{t,i} - x^*_i\|_2^2 \,\big/\, \sum_{i=1}^{N} q_\phi(t \mid i)}{\sum_{i=1}^{N} \|x^*_i\|_2^2 \,/\, N} \right),$$

using the red dots. We observe that for LISTA-stop the expected NMSE increases with the number of iterations; this might indicate that problems stopped later are more difficult to solve. Besides, at the 15th iteration, the NMSE in Fig. 4 (b) is the smallest, while the averaged stop probability mass $q_\phi(15)$ in Fig. 4 (a) is the highest.

Table 3.
Different algorithms for training LISTA-stop (NMSE in dB).

| SNR | mixed | 20 | 30 | 40 |
|---|---|---|---|---|
| AEVB algorithm | -21.92 | -19.92 | -23.27 | -23.58 |
| Stage I. + II. | -22.41 | -20.29 | -23.90 | -24.21 |
| Stage I. + II. + III. | -22.78 | -20.59 | -24.29 | -24.73 |
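The weighted NMSE that produces the red dots for LISTA-stop in Fig. 4 (b) follows the formula given with the convergence comparison. A NumPy sketch, assuming all recovered paths and the per-sample stop distributions $q_\phi(t \mid i)$ are stacked into arrays (the array layout and the function name are our own choices for illustration), is:

```python
import numpy as np


def weighted_nmse_db(paths, x_stars, q):
    """Per-iteration NMSE (dB) weighted by the stop-time distribution.

    paths:   (N, T, n) array, paths[i, t-1] is x_t for test sample i
    x_stars: (N, n) array of ground-truth signals x*_i
    q:       (N, T) array, q[i, t-1] is q_phi(t | i)
    Returns a length-T array; iterations with zero total stop mass give NaN
    (these are the iterations without red dots in the plot).
    """
    paths, x_stars, q = map(np.asarray, (paths, x_stars, q))
    sq_err = np.sum((paths - x_stars[:, None, :]) ** 2, axis=2)  # (N, T)
    mass = q.sum(axis=0)                                          # sum_i q_phi(t | i)
    mean_power = np.mean(np.sum(x_stars ** 2, axis=1))            # sum_i ||x*_i||^2 / N
    with np.errstate(divide="ignore", invalid="ignore"):
        weighted_err = (q * sq_err).sum(axis=0) / mass
        return 10 * np.log10(weighted_err / mean_power)
```

Normalizing by the average signal power over all $N$ samples, rather than per sample, matches the denominator of the formula above.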
Ablation study on training algorithms.
To show the effectiveness of our two-stage training, in Table 3 we compare the results with the auto-encoding variational Bayes (AEVB) algorithm (Le et al., 2018), which jointly optimizes $F_\theta$ and $q_\phi$. We observe that the distribution $q_\phi$ in AEVB gradually becomes concentrated on one layer and does not escape this local minimum, so its final result is not as good as that of our two-stage training. Moreover, Stage III only marginally improves the performance of the two-stage training, which in turn shows the effectiveness of the oracle-based two-stage training.

In this section, we perform meta learning experiments in the few-shot learning domain (Ravi & Larochelle, 2017).
Experiment setting.
We follow the setting in MAML (Finn et al., 2017) for the few-shot learning tasks. Each task is an N-way classification problem that contains meta-{train, valid, test} sets. On top of this, the macro dataset with multiple tasks is split into train, valid and test sets. We consider the more realistic task-imbalanced setting proposed by Na et al. (2020). Unlike the standard setting, where the meta-train set of each task contains k shots for each class, here we vary the number of observations to perform $k_1$-$k_2$-shot learning, where $k_1 < k_2$ are the minimum and maximum numbers of observations per class, respectively. Building on top of MAML, we denote our variant as MAML-stop, which learns how many adaptation gradient descent steps are needed for each task. Intuitively, tasks with less training data would prefer fewer gradient-update steps to prevent overfitting. As we mainly focus on the effect of learning to stop, the neural architecture and other hyperparameters are largely the same as in MAML. Please refer to Appendix B.2 for more details. Dataset.
We use the benchmark datasets Omniglot (Lake et al., 2011) and MiniImagenet (Ravi & Larochelle, 2017). Omniglot consists of 20 instances of each of 1623 characters from 50 different alphabets, while MiniImagenet involves 64 training classes, 16 validation classes, and 20 test classes. We use exactly the same data split as Finn et al. (2017). To construct the imbalanced tasks, we perform 20-way 1-5-shot classification on Omniglot and 5-way 1-10-shot classification on MiniImagenet. The number of observations per class in each meta-test set is 1 and 5 for Omniglot and MiniImagenet, respectively. For evaluation, we construct 600 tasks from the held-out test set for each setting.
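The construction of one imbalanced task can be sketched as below, where k1 and k2 denote the minimum and maximum numbers of observations per class. This is an illustrative NumPy sketch, not the authors' data pipeline: the dictionary format of `class_pool`, the helper name `sample_imbalanced_task`, and drawing each class's shot count uniformly from {k1, ..., k2} are assumptions; the text only states that the number of meta-train observations per class varies between k1 and k2 while the meta-test size per class stays fixed.

```python
import numpy as np

def sample_imbalanced_task(class_pool, n_way, k1, k2, n_query, rng):
    """Build one N-way, k1-k2-shot task (hypothetical helper).

    class_pool: dict mapping class id -> 1-D array of example ids
    Returns (support, query) as lists of (example_id, task_label) pairs.
    """
    keys = list(class_pool)
    chosen = rng.choice(len(keys), size=n_way, replace=False)  # distinct classes
    support, query = [], []
    for label, j in enumerate(chosen):
        ids = rng.permutation(class_pool[keys[j]])
        k = int(rng.integers(k1, k2 + 1))         # shots for this class, in [k1, k2]
        support += [(int(i), label) for i in ids[:k]]
        query += [(int(i), label) for i in ids[k:k + n_query]]  # fixed meta-test size
    return support, query
```

Under this sketch, the Omniglot setting above would correspond to `n_way=20, k1=1, k2=5, n_query=1`, and the MiniImagenet setting to `n_way=5, k1=1, k2=10, n_query=5`.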
Table 4.
Task-imbalanced few-shot image classification.
Method    Omniglot 20-way, 1-5 shot    MiniImagenet 5-way, 1-10 shot
MAML      97.96 ± …                    …
Results.
Table 4 summarizes the accuracy and the 95% confidence interval on the held-out tasks for each dataset. The maximum number of adaptation gradient descent steps is 10 for both MAML and MAML-stop. The optimal-stopping variant of MAML consistently outperforms vanilla MAML. For the more difficult task on MiniImagenet, where the imbalance issue is more severe, the accuracy improvement is 3.5%. For completeness, Table 5 reports the performance in the vanilla meta learning setting, where all tasks have the same number of observations. MAML-stop still achieves comparable or better performance.
Table 5.
Few-shot classification in the vanilla meta learning setting (Finn et al., 2017), where all tasks have the same number of data points.
Method    Omniglot 5-way      Omniglot 20-way     MiniImagenet 5-way
          1-shot    5-shot    1-shot    5-shot    1-shot    5-shot
MAML      98.7 ± …  …         …         …         …         …
In this section, we perform the image denoising experiments. More implementation details are provided in Appendix B.3.
Dataset.
The models are trained on BSD500 (400 images) (Arbelaez et al., 2010), validated on BSD12, and tested on BSD68 (Martin et al., 2001). We follow the standard setting in (Zhang et al., 2019; Lefkimmiatis, 2018; Zhang et al., 2017) and add Gaussian noise to the images with a random noise level σ ≤ … during the training and validation phases. Experiment setting.
We compare with two DL models, DnCNN (Zhang et al., 2017) and UNLNet (Lefkimmiatis, 2018), and two traditional methods, BM3D (Dabov et al., 2007) and WNNM (Gu et al., 2014). Since DnCNN is one of the most widely used models for image denoising, we use it as our predictive model. All deep models, including ours, are evaluated in the blind Gaussian denoising setting, which means the noise level is not given to the model, while BM3D and WNNM require the noise level to be known.
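In the blind setting the network only ever sees the noisy image, never σ. A minimal sketch of how one training pair could be formed, assuming σ is drawn uniformly from [0, σ_max], pixel values on a [0, 255] scale, and no clipping; `sigma_max` and the helper name `make_blind_training_pair` are hypothetical, since the text only states that σ is random and bounded.

```python
import numpy as np

def make_blind_training_pair(clean, sigma_max, rng):
    """Return (noisy, sigma) for blind Gaussian denoising (hypothetical helper).

    sigma ~ Uniform[0, sigma_max] is an assumed sampling scheme; pixel scale [0, 255].
    Only `noisy` is fed to the network; `sigma` is returned for logging/analysis.
    """
    sigma = rng.uniform(0.0, sigma_max)
    noisy = clean + rng.normal(0.0, sigma, size=clean.shape)  # additive white Gaussian noise
    return noisy, sigma
```

Keeping σ out of the model's inputs is what distinguishes this setting from the noise-specific one used for BM3D and WNNM.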
Table 6.
PSNR performance comparison. The sign * indicates that noise levels 65 and 75 do not appear in the training set.
σ    DnCNN-stop    DnCNN    UNLNet    BM3D    WNNM
The performance is evaluated by the mean peak signal-to-noise ratio (PSNR). Table 6 shows that DnCNN-stop performs better than the original DnCNN. In particular, for images with noise levels 65 and 75, which are unseen during the training phase, DnCNN-stop generalizes significantly better than DnCNN alone. Since there is no released code for UNLNet, its performance is copied from the paper (Lefkimmiatis, 2018), where results are not reported for σ = 65 and 75. For the traditional methods BM3D and WNNM, the test is in the noise-specific setting; that is, the noise level is given to both BM3D and WNNM, so the comparison is not completely fair to the learning-based methods in the blind denoising setting.
Panels: Ground Truth, WNNM, DnCNN, DnCNN-stop.
Figure 5.
Denoising results of an image with noise level 65. (See Appendix B.3.2 for more visualization results.)
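Table 6 reports the mean PSNR over the test images. A minimal NumPy sketch of that metric, assuming 8-bit images (peak value 255) and averaging the per-image PSNR values; the function names `psnr` and `mean_psnr` are hypothetical.

```python
import numpy as np

def psnr(reference, estimate, peak=255.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(peak^2 / MSE)."""
    mse = np.mean((np.asarray(reference, float) - np.asarray(estimate, float)) ** 2)
    return float("inf") if mse == 0 else float(10.0 * np.log10(peak ** 2 / mse))

def mean_psnr(pairs, peak=255.0):
    """Average of per-image PSNR over (reference, estimate) pairs."""
    return float(np.mean([psnr(r, e, peak) for r, e in pairs]))
```

Averaging per-image PSNR (rather than computing one PSNR from the pooled MSE) is the common convention for reporting on test sets such as BSD68; the two can differ when image errors vary widely.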
We explore the potential of our idea for improving recognition performance on Tiny-ImageNet, using VGG16 (Simonyan & Zisserman, 2014) as the predictive model. With 14 internal classifiers, after Stage I training, if the oracle $q^*_\theta$ is used to determine the stop time $t$, the accuracy of VGG16 can be improved to 83.26%. A similar observation is reported for SDN (Kaya et al., 2019), but their loss $\sum_t w_t \ell_t$ depends on very careful hand-tuning of the weight $w_t$ for each layer, while we directly take an expectation under the oracle, which is more principled and leads to higher accuracy (Table 7). However, it turns out to be very hard to mimic the behavior of the oracle $q^*_\theta$ with $\pi_\phi$ in Stage II, either because $\pi_\phi$ needs a better parametrization or for other, less obvious reasons. Our learned $\pi_\phi$ reaches accuracy similar to that of the heuristic policy in SDN, which becomes the bottleneck in our exploration. Nevertheless, given the large performance gap between the oracle and the original VGG16, our result still points to a potential direction for breaking the performance bottleneck of DL on image recognition. Table 7.
Image recognition with oracle stop distribution.
VGG16     SDN training           Our Stage I. training
58.60%    77.78% (best layer)    83.26% (best layer)
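The gap in Table 7 comes from letting a label-aware oracle choose the exit layer. A minimal NumPy sketch of the quantities contrasted above, under stated assumptions: `nll[i, t]` is the per-example loss ℓ_t of internal classifier t, the oracle weights are assumed to take the form q*(t) ∝ exp(−ℓ_t/β), and the oracle exit layer is taken to be the mode of q*, i.e. the layer with the highest p_θ(y|t,x). Because the oracle uses the label y, it is an upper bound that a learned policy π_φ can only try to imitate. All function names are hypothetical.

```python
import numpy as np

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def sdn_style_loss(nll, w):
    """Fixed per-layer weights w_t shared by all examples: mean_i sum_t w_t * nll[i, t]."""
    return float(np.mean(nll @ w))

def oracle_expected_loss(nll, beta):
    """Per-example expectation of the loss under q*(t) proportional to exp(-nll[i, t] / beta)."""
    q = np.exp(log_softmax(-nll / beta, axis=1))
    return float(np.mean(np.sum(q * nll, axis=1)))

def oracle_accuracy(logits, labels):
    """logits: (N, T, C) from T internal classifiers; exit where p(y | t, x) is largest."""
    n = logits.shape[0]
    logp_true = log_softmax(logits, axis=2)[np.arange(n), :, labels]  # (N, T)
    t_star = logp_true.argmax(axis=1)                                  # oracle exit layer
    preds = logits[np.arange(n), t_star].argmax(axis=1)
    return float(np.mean(preds == labels))
```

With uniform weights w_t = 1/T, the SDN-style loss is never smaller than the oracle expectation, since q* shifts mass toward low-loss layers for every example.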
6. Conclusion
In this paper, we introduce a generic framework for modelling and training a deep learning model with input-specific depth, which is determined by a stopping policy $\pi_\phi$. Extensive experiments on a wide range of applications demonstrate the effectiveness of both the model and the training algorithm. In the future, it will be interesting to see whether other aspects of algorithms can be incorporated into deep learning models, either to improve performance or to gain a better theoretical understanding.
References
Andrychowicz, M., Denil, M., Gomez, S., Hoffman, M. W., Pfau, D., Schaul, T., Shillingford, B., and De Freitas, N. Learning to learn by gradient descent by gradient descent. In Advances in Neural Information Processing Systems, pp. 3981–3989, 2016.
Arbelaez, P., Maire, M., Fowlkes, C., and Malik, J. Contour detection and hierarchical image segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 33(5):898–916, 2010.
Beck, A. and Teboulle, M. A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM Journal on Imaging Sciences, 2(1):183–202, 2009.
Becker, S., Cheridito, P., and Jentzen, A. Deep optimal stopping. Journal of Machine Learning Research, 20(74):1–25, 2019.
Belanger, D., Yang, B., and McCallum, A. End-to-end learning for structured prediction energy networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 429–439. JMLR.org, 2017.
Blumensath, T. and Davies, M. E. Iterative thresholding for sparse approximations. Journal of Fourier Analysis and Applications, 14(5-6):629–654, 2008.
Borgerding, M., Schniter, P., and Rangan, S. AMP-inspired deep networks for sparse linear inverse problems. IEEE Transactions on Signal Processing, 65(16):4293–4308, 2017.
Ceci, C. and Bassan, B. Mixed optimal stopping and stochastic control problems with semicontinuous final reward for diffusion processes. Stochastics and Stochastic Reports, 76(4):323–337, 2004.
Chen, X., Liu, J., Wang, Z., and Yin, W. Theoretical linear convergence of unfolded ISTA and its practical weights and thresholds. In Advances in Neural Information Processing Systems, pp. 9061–9071, 2018.
Chen, X., Dai, H., and Song, L. Particle flow Bayes rule. In International Conference on Machine Learning, pp. 1022–1031, 2019.
Chen, X., Li, Y., Umarov, R., Gao, X., and Song, L. RNA secondary structure prediction by learning unrolled algorithms. arXiv preprint arXiv:2002.05810, 2020.
Dabov, K., Foi, A., Katkovnik, V., and Egiazarian, K. Image denoising by sparse 3-D transform-domain collaborative filtering. IEEE Transactions on Image Processing, 16(8):2080–2095, 2007.
Domke, J. Parameter learning with truncated message-passing. In CVPR 2011, pp. 2937–2943. IEEE, 2011.
Dumitrescu, R., Reisinger, C., and Zhang, Y. Approximation schemes for mixed optimal stopping and control problems with nonlinear expectations and jumps. arXiv preprint arXiv:1803.03794, 2018.
Finn, C., Abbeel, P., and Levine, S. Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 1126–1135. JMLR.org, 2017.
Gregor, K. and LeCun, Y. Learning fast approximations of sparse coding. In Proceedings of the 27th International Conference on International Conference on Machine Learning, pp. 399–406. Omnipress, 2010.
Gu, S., Zhang, L., Zuo, W., and Feng, X. Weighted nuclear norm minimization with application to image denoising. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2862–2869, 2014.
Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., and Lerchner, A. beta-VAE: Learning basic visual concepts with a constrained variational framework. ICLR, 2(5):6, 2017.
Huang, G., Chen, D., Li, T., Wu, F., van der Maaten, L., and Weinberger, K. Multi-scale dense networks for resource efficient image classification. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Hk2aImxAb.
Ingraham, J., Riesselman, A., Sander, C., and Marks, D. Learning protein structure with a differentiable simulator. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=Byg3y3C9Km.
Jones, M., Kinoshita, S., and Mozer, M. C. Optimal response initiation: Why recent experience matters. In Advances in Neural Information Processing Systems, pp. 785–792, 2009.
Kaya, Y., Hong, S., and Dumitras, T. Shallow-deep networks: Understanding and mitigating network overthinking. In International Conference on Machine Learning, pp. 3301–3310, 2019.
Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.
Lake, B., Salakhutdinov, R., Gross, J., and Tenenbaum, J. One shot learning of simple visual concepts. In Proceedings of the Annual Meeting of the Cognitive Science Society, volume 33, 2011.
Le, T. A., Igl, M., Rainforth, T., Jin, T., and Wood, F. Auto-encoding sequential Monte Carlo. In International Conference on Learning Representations, 2018.
Lee, Y. and Choi, S. Gradient-based meta-learning with learned layerwise metric and subspace. arXiv preprint arXiv:1801.05558, 2018.
Lefkimmiatis, S. Universal denoising networks: a novel CNN architecture for image denoising. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3204–3213, 2018.
Li, K. and Malik, J. Learning to optimize. arXiv preprint arXiv:1606.01885, 2016.
Li, Z., Zhou, F., Chen, F., and Li, H. Meta-SGD: Learning to learn quickly for few-shot learning. arXiv preprint arXiv:1707.09835, 2017.
Liu, J., Chen, X., Wang, Z., and Yin, W. ALISTA: Analytic weights are as good as learned weights in LISTA. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=B1lnzn0ctQ.
Martin, D., Fowlkes, C., Tal, D., and Malik, J. A database of human segmented natural images and its application to evaluating segmentation algorithms and measuring ecological statistics. In Proceedings Eighth IEEE International Conference on Computer Vision. ICCV 2001, volume 2, pp. 416–423. IEEE, 2001.
Metzler, C., Mousavi, A., and Baraniuk, R. Learned D-AMP: Principled neural network based compressive image recovery. In Advances in Neural Information Processing Systems, pp. 1772–1783, 2017.
Na, D., Lee, H. B., Lee, H., Kim, S., Park, M., Yang, E., and Hwang, S. J. Learning to balance: Bayesian meta-learning for imbalanced and out-of-distribution tasks. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=rkeZIJBYvr.
Nowozin, S., Cseke, B., and Tomioka, R. f-GAN: Training generative neural samplers using variational divergence minimization. In Advances in Neural Information Processing Systems, pp. 271–279, 2016.
Oreshkin, B., López, P. R., and Lacoste, A. TADAM: Task dependent adaptive metric for improved few-shot learning. In Advances in Neural Information Processing Systems, pp. 721–731, 2018.
Pham, H. Optimal stopping of controlled jump diffusion processes: a viscosity solution approach. In Journal of Mathematical Systems, Estimation and Control. Citeseer, 1998.
Qiao, S., Liu, C., Shen, W., and Yuille, A. L. Few-shot image recognition by predicting parameters from activations. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7229–7238, 2018.
Ravi, S. and Larochelle, H. Optimization as a model for few-shot learning. 2017.
Shiryaev, A. N. Optimal stopping rules, volume 8. Springer Science & Business Media, 2007.
Shrivastava, H., Chen, X., Chen, B., Lan, G., Aluru, S., Liu, H., and Song, L. GLAD: Learning sparse graph recovery. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=BkxpMTEtPB.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Sun, J., Li, H., Xu, Z., et al. Deep ADMM-Net for compressive sensing MRI. In Advances in Neural Information Processing Systems, pp. 10–18, 2016.
Teerapittayanon, S., McDanel, B., and Kung, H.-T. BranchyNet: Fast inference via early exiting from deep neural networks. In …, pp. 2464–2469. IEEE, 2016.
Yakar, T. B., Litman, R., Sprechmann, P., Bronstein, A. M., and Sapiro, G. Bilevel sparse models for polyphonic music transcription. In ISMIR, pp. 65–70, 2013.
Zamir, A. R., Wu, T.-L., Sun, L., Shen, W. B., Shi, B. E., Malik, J., and Savarese, S. Feedback networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1308–1317, 2017.
Zhang, J. and Ghanem, B. ISTA-Net: Interpretable optimization-inspired deep network for image compressive sensing. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1828–1837, 2018.
Zhang, K., Zuo, W., Chen, Y., Meng, D., and Zhang, L. Beyond a Gaussian denoiser: Residual learning of deep CNN for image denoising. IEEE Transactions on Image Processing, 26(7):3142–3155, 2017.
Zhang, X., Lu, Y., Liu, J., and Dong, B. Dynamically unfolding recurrent restorer: A moving endpoint control method for image restoration. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=SJfZKiC5FX.
A. Derivations
A.1. Proof of Lemma 1
Proof.
Under the assumptions that $\ell(y, x_t; \theta) := -\log p_\theta(y \mid t, x)$ and that the prior $p(t \mid x)$ is a uniform distribution over $t$, the β-VAE objective can be written as
$$
\begin{aligned}
J_{\beta\text{-VAE}}(\theta, q_\phi; x, y) &:= \mathbb{E}_{q_\phi} \log p_\theta(y \mid t, x) - \beta\, \mathrm{KL}\big(q_\phi(t) \,\|\, p(t \mid x)\big) \\
&= -\mathbb{E}_{q_\phi} \ell(y, x_t; \theta) - \beta\, \mathbb{E}_{q_\phi(t)} \log \frac{q_\phi(t)}{p(t \mid x)} \\
&= -\mathbb{E}_{q_\phi} \ell(y, x_t; \theta) - \beta\, \mathbb{E}_{q_\phi(t)} \log q_\phi(t) + \beta\, \mathbb{E}_{q_\phi(t)} \log p(t \mid x) \\
&= -\big(\mathbb{E}_{q_\phi} \ell(y, x_t; \theta) - \beta H(q_\phi)\big) + \beta\, \mathbb{E}_{q_\phi(t)} \log \frac{1}{T} \\
&= -\mathcal{L}(\theta, q_\phi; x, y) - \beta \log T.
\end{aligned}
$$
Since the second term $-\beta \log T$ is a constant, maximizing $J_{\beta\text{-VAE}}(\theta, q_\phi; x, y)$ is equivalent to minimizing $\mathcal{L}(\theta, q_\phi; x, y)$.
A.2. Equivalence of reverse KL and maximum-entropy RL
The variational distribution $q_\phi$ actually depends on the input instance $x$; for notational simplicity, we write $q_\phi(t)$ instead of $q_\phi(t \mid x)$. Then
$$
\begin{align}
&\min_{\phi}\ \mathrm{KL}\big(q_\phi(t)\,\|\,q^*_\theta(t \mid y, x)\big) \tag{12}\\
&= \min_{\phi}\ -\sum_{t=1}^{T} q_\phi(t)\log q^*_\theta(t \mid y, x) - H(q_\phi) \tag{13}\\
&= \min_{\phi}\ -\sum_{t=1}^{T} q_\phi(t)\log p_\theta(y \mid t, x)^{1/\beta} - H(q_\phi) \tag{14}\\
&\qquad + \sum_{t=1}^{T} q_\phi(t)\log\sum_{\tau=1}^{T} p_\theta(y \mid \tau, x)^{1/\beta} \tag{15}\\
&= \min_{\phi}\ -\sum_{t=1}^{T} q_\phi(t)\log p_\theta(y \mid t, x)^{1/\beta} - H(q_\phi) \tag{16}\\
&\qquad + \sum_{t=1}^{T} q_\phi(t)\, C(x, y) \tag{17}\\
&= \min_{\phi}\ -\sum_{t=1}^{T} q_\phi(t)\log p_\theta(y \mid t, x)^{1/\beta} - H(q_\phi) \tag{18}\\
&\qquad + C(x, y) \tag{19}\\
&= \min_{\phi}\ -\sum_{t=1}^{T} q_\phi(t)\log p_\theta(y \mid t, x)^{1/\beta} - H(q_\phi) \tag{20}\\
&= \max_{\phi}\ \sum_{t=1}^{T} q_\phi(t)\log p_\theta(y \mid t, x)^{1/\beta} + H(q_\phi) \tag{21}\\
&= \max_{\phi}\ \sum_{t=1}^{T} q_\phi(t)\Big(-\frac{1}{\beta}\ell(y, x_t; \theta)\Big) + H(q_\phi) \tag{22}\\
&= \max_{\phi}\ \mathbb{E}_{t\sim q_\phi}\Big[-\frac{1}{\beta}\ell(y, x_t; \theta) - \log q_\phi(t)\Big] \tag{23}
\end{align}
$$
where $C(x, y) := \log\sum_{\tau=1}^{T} p_\theta(y \mid \tau, x)^{1/\beta}$ does not depend on $\phi$, and (19) uses $\sum_{t=1}^{T} q_\phi(t) = 1$. Define the action as $a_t \sim \pi_t = \pi_\phi(x, x_t)$, the reward function as
$$
r(x_t, a_t; y) := \begin{cases} -\frac{1}{\beta}\ell(y, x_t; \theta) & \text{if } a_t = 1 \text{ (i.e., stop)},\\ 0 & \text{if } a_t = 0 \text{ (i.e., continue)}, \end{cases}
$$
and the transition probability as
$$
P(x_{t+1} \mid x_t, a_t) = \begin{cases} 1 & \text{if } x_{t+1} = \mathcal{F}_\theta(x_t) \text{ and } a_t = 0,\\ 0 & \text{otherwise}. \end{cases}
$$
Then the above optimization can be written as
$$
\begin{align}
&\max_{\phi}\ \mathbb{E}_{t\sim q_\phi}\Big[-\frac{1}{\beta}\ell(y, x_t; \theta) - \log q_\phi(t)\Big] \tag{24}\\
&= \max_{\phi}\ \mathbb{E}_{\pi_\phi}\sum_{t=1}^{T}\big[r(x_t, a_t; y) - \log \pi_t(a_t \mid x, x_t)\big] \tag{25}\\
&= \max_{\phi}\ \mathbb{E}_{\pi_\phi}\sum_{t=1}^{T}\big[r(x_t, a_t; y) + H(\pi_t)\big]. \tag{26}
\end{align}
$$
B. Experiment Details
B.1. Learning To Learn: Sparse Recovery
Synthetic data.
We follow Chen et al. (2018) and choose $m = 250$, $n = 500$, sample the entries of $A$ i.i.d. from a Gaussian distribution, $A_{ij} \sim \mathcal{N}(0, 1/m)$, and then normalize its columns to have unit $\ell_2$ norm. To generate $y^*$, we decide whether each of its entries is non-zero following a Bernoulli distribution with $p_b = 0.…$ The values of the non-zero entries are sampled from the standard Gaussian distribution. The noise $\epsilon$ is Gaussian white noise. The signal-to-noise ratio (SNR) of each sample is uniformly sampled from 20, 30 and 40 dB. For the testing phase, a test set of 3000 samples is generated, with 1000 samples for each noise level. This test set is fixed for all experiments in our simulations. Evaluation metric.
The performance is evaluated by NMSE (in dB), which is defined as
$$\mathrm{NMSE} = 10\log_{10}\left(\frac{\sum_{i=1}^{N}\|\hat{x}_i - x^{*,i}\|_2^2}{\sum_{i=1}^{N}\|x^{*,i}\|_2^2}\right),$$
where $\hat{x}_i$ is the estimate returned by an algorithm or deep model.
B.2. Task-imbalanced Meta Learning
B.2.1. Details of Setup
Hyperparameters
We train MAML with batch size 16 on the imbalanced Omniglot dataset and batch size 2 on the imbalanced MiniImagenet dataset. In both scenarios, we train with 60000 mini-batch updates for the outer loop of MAML. We report the results with 5 inner SGD steps for imbalanced Omniglot and 10 inner SGD steps for imbalanced MiniImagenet, with the other hyperparameters set to the best values suggested by Finn et al. (2017). For MAML-stop, we run 10 inner SGD steps for both datasets, with the inner learning rate set to … and … for Omniglot and MiniImagenet, respectively. The outer learning rate for MAML-stop is …, as we use batch size 1 for training. When generating each meta-training dataset, we randomly select the number of observations between $k_1$ and $k_2$ for $k_1$-$k_2$-shot learning. The number of observations in the test set is always kept the same within each round of experiments.
B.2.2. Memory Efficient Implementation
As MAML-stop decides the optimal stopping step automatically, it is preferable to set the maximum number of SGD updates per task to a larger number to fully utilize the capacity of the approach. This poses a challenge during training, as the loss on each meta-test set must be computed for every single inner update step. That is to say, if we allow at most 10 steps of inner SGD updates, then the memory cost of running CNN prediction on the meta-test set is 10x larger than in vanilla MAML. Thus a straightforward implementation does not give us a feasible training mechanism.
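The memory saving comes from differentiating a single sampled step instead of all T steps. The toy below sketches that stochastic-EM pattern under stated assumptions: a hypothetical scalar model with outputs x_t = tθ and squared loss ℓ_t, the oracle form q*(t) ∝ exp(−ℓ_t/β), and a hand-written gradient. In PyTorch, the E-step losses would be computed under `torch.no_grad()` and only the sampled step would be recomputed with autograd; here plain NumPy stands in for both passes. All names are hypothetical.

```python
import numpy as np

def step_losses(theta, y, T):
    """Toy unrolled model (hypothetical): x_t = t * theta, loss ell_t = 0.5 * (y - x_t)^2."""
    t = np.arange(1, T + 1)
    return 0.5 * (y - t * theta) ** 2

def oracle_weights(ell, beta):
    """q*(t) proportional to exp(-ell_t / beta), shifted by min(ell) for stability."""
    q = np.exp(-(ell - ell.min()) / beta)
    return q / q.sum()

def expected_oracle_loss_toy(theta, y, T, beta):
    ell = step_losses(theta, y, T)
    return float(np.sum(oracle_weights(ell, beta) * ell))

def stochastic_em_step(theta, y, T, beta, lr, rng):
    # E-step: losses for all T steps, nothing kept for backprop
    # (this pass would run under torch.no_grad() in PyTorch).
    q = oracle_weights(step_losses(theta, y, T), beta)
    t = int(rng.choice(np.arange(1, T + 1), p=q))  # sample a single stop step
    # M-step: gradient of ell_t w.r.t. theta for the sampled t only.
    grad = -t * (y - t * theta)
    return theta - lr * grad
```

Only one step's computation needs gradients per update, so memory grows with a single unrolled step rather than with all T of them.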
To make the training of MAML-stop feasible on a single GPU, we utilize the following techniques:
• We use stochastic EM for learning the predictive model as well as the stopping policy. Specifically, we sample $t \sim q^*_\theta(\cdot \mid y, x)$ in each round of training, and only maximize $p_\theta(y \mid t, x)$ in this round.
• As automatic differentiation in PyTorch is unable to distinguish between 'no gradient' and 'zero gradient', it incurs extra storage for unnecessary gradient computation. To overcome this, we first calculate $q^*_\theta(t \mid y, x)$ for each $t$ without any gradient storage (which corresponds to torch.no_grad() in PyTorch), then recompute $p_\theta(y \mid t, x)$ for the sampled $t$.
With the above techniques, we can train MAML-stop almost as memory-efficiently as MAML.
B.2.3. Standard Meta-learning Tasks
For completeness, we also evaluate MAML-stop in the standard few-shot learning setting. We mainly compare with vanilla MAML, as an ablation study.
Hyperparameters
The hyperparameter setup mainly follows the vanilla MAML paper. For both MAML and MAML-stop, we use the same batch size, number of training epochs and learning rate. For the Omniglot 20-way and MiniImagenet 5-way experiments, we tune the number of unrolling steps in {…}, β in {…} and the learning rate of the inner update in {…}. We use a simple grid search with a random held-out set of 600 tasks to select the best model configuration.
B.3. Image Denoising
B.3.1. Implementation Details
When training the denoising models, the raw images were cropped and augmented into 403K patches of size …. The training batch size was …. We used the Adam optimizer with an initial learning rate of …. We first trained the deep learning model with the unweighted loss for … epochs. Then, we further trained the model with the weighted loss for another … epochs. After a hyper-parameter search, we set the exploration coefficient β to 0.1. When training the policy network, we used the Adam optimizer with a learning rate of …. We reused the above hyper-parameters during joint training.
B.3.2. Visualization
Panels: Ground Truth, Noisy Image, BM3D, WNNM, DnCNN, DnCNN-stop.
Figure 6.
Denoising results of an image with noise level 65.
Panels: Ground Truth, Noisy Image, BM3D, WNNM, DnCNN, DnCNN-stop.
Figure 7.
Denoising results of an image with noise level 65.