Few-Shot Learning with Embedded Class Models and Shot-Free Meta Training
Avinash Ravichandran, Amazon Web Services
Rahul Bhotika, Amazon Web Services
Stefano Soatto, Amazon Web Services and UCLA
Abstract
We propose a method for learning embeddings for few-shot learning that is suitable for use with any number of shots (shot-free). Rather than fixing the class prototypes to be the Euclidean average of sample embeddings, we allow them to live in a higher-dimensional space (embedded class models) and learn the prototypes along with the model parameters. The class representation function is defined implicitly, which allows us to deal with a variable number of shots per class with a simple constant-size architecture. The class embedding encompasses metric learning, which facilitates adding new classes without crowding the class representation space. Despite being general and not tuned to the benchmark, our approach achieves state-of-the-art performance on the standard few-shot benchmark datasets.
Figure 1. One image of a mushroom (Muscaria) may be enough to recognize it in the wild (left); in other cases, there may be more subtle differences between an edible (Russula, shown in the center) and a deadly one (Phalloides, shown on the right), but still few samples are enough for humans.
1. Introduction
Consider Figure 1: Given one or few images of an Amanita Muscaria (left), one can easily recognize it in the wild. Identifying a Russula (center) may require more samples, enough to distinguish it from the deadly Amanita Phalloides (right), but likely not millions of them. We refer to this as few-shot learning.
This ability comes from having seen and touched millions of other objects, in different environments, under different lighting conditions, partial occlusions and other nuisances. We refer to this as meta-learning.
We wish to exploit the availability of large annotated datasets to meta-train models so they can learn new concepts from few samples, or "shots." We refer to this as meta-training for few-shot learning.
In this paper we develop a framework for both meta-training (learning a potentially large number of classes from a large annotated dataset) and few-shot learning (using the learned model to train new concepts from few samples), designed to have the following characteristics.
Open set:
Accommodate an unknown, growing, and possibly unbounded number of new classes in an "open set" or "open universe" setting. Some of the simpler methods available in the literature, for instance those based on nearest neighbors of fixed embeddings [15], do so in theory. In these methods, however, there is no actual few-shot learning per se, as all learnable parameters are set at meta-training.
Continual:
Enable leveraging few-shot data to improve the model parameters, even those inferred during meta-training. While each class may only have few samples, as the number of classes grows, the few-shot training set may grow large. We want a model flexible enough to enable "lifelong" or "continual" learning.
Shot Free:
Accommodate a variable number of shots for each new category. Some classes may have a few samples, others a few hundred; we do not want to meta-train different models for different numbers of shots, nor to restrict ourselves to all new classes having the same number of shots, as many recent works do. This may be a side-effect of the available benchmarks, which only test a few combinations of shots and "ways" (classes).
Embedded Class Models:
Learn a representation of the classes that is not constrained to live in the same space as the representation of the data. All known methods for few-shot learning choose an explicit function to compute class representatives (a.k.a. "prototypes" [15], "proxies," "means," "modes," or "templates") as some form of averaging in the embedding (feature) space of the data. By decoupling the data (feature space) from the classes (class embedding), we free the latter to live in a richer space, where they can better represent complex distributions, and possibly grow over time.
To this end, our contributions are as follows:
• Shot-free: A meta-learning model and sampling scheme that is suitable for use with any number of ways and any number of shots, and can operate in an open-universe, life-long setting. When we fix the shots, as done in the benchmarks, we achieve essentially state-of-the-art performance, but with a model that is far more flexible.
• Embedded Identities: We abstract the identities to a different space than the features, thus enabling the capture of more complex classes.
• Implicit Class Representation: The class representation function has a variable number of arguments, the shots in the class. Rather than fixing the number of shots, or choosing a complex architecture to handle variable numbers, we show that learning an implicit form of the class function enables seamless meta-training, while requiring a relatively simple optimization problem to be solved at few-shot time. We use neither recurrent architectures that impose artificial ordering nor complex set functions.
• Metric Learning is incorporated in our model, enabling us to add new classes without crowding the class representation space.
• Performance: Since there is no benchmark to showcase all the features of our model, we use existing benchmarks for few-shot learning that fix the number of ways and shots to a few samples. Some of the top performing methods are tailored to the benchmark, training different models for different numbers of shots, which does not scale and does not handle the standard case where each way comes with its own number of shots. Our approach, while not tuned to any benchmark, achieves state-of-the-art performance and is more general.
In the next section we present a formalism for ordinary classification that, while somewhat pedantic, allows us to generalize to life-long, open-universe, meta- and few-shot training. The general model allows us to analyze existing work under a common language, and highlights limitations that motivate our proposed solution in Sect. 2.3.
In ordinary classification, we call $\mathcal{B} = \{(x_i, y_i)\}_{i=1}^{M}$, with $y_i \in \{1, \dots, B\}$, a "large-scale" training set, and $(x_j, y_j) \sim P(x, y)$ a sample from the same distribution. If it is in the training set, we write formally $P(y = k \mid x_i) = \delta(k - y_i)$. Outside the training set, we approximate this probability with
$$P_w(y = k \mid x) := \frac{\exp(-\phi_w(x)_k)}{\sum_{k'} \exp(-\phi_w(x)_{k'})} \tag{1}$$
where the discriminant $\phi_w : X \to \mathbb{R}^K$ is an element of a sufficiently rich parametric class of functions with parameters, or "weights," $w$, and the subscript $k$ indicates the $k$-th component. The empirical cross-entropy loss is defined as
$$L(w) := \sum_{k=1}^{K} \sum_{(x_i, y_i) \in \mathcal{B}} -P(y = k \mid x_i) \log P_w(y = k \mid x_i) = \sum_{(x_i, y_i) \in \mathcal{B}} -\log P_w(y_i \mid x_i) \tag{2}$$
minimizing which is equivalent to maximizing $\prod_i P_w(y_i \mid x_i)$. If $\mathcal{B}$ is i.i.d., this yields the maximum-likelihood estimate $\hat w$, which depends on the dataset $\mathcal{B}$ and approximates $\phi_{\hat w}(x)_y \simeq \log P(y \mid x)$. We write cross-entropy explicitly as a function of the discriminant as
$$L(w) = \sum_{(x_i, y_i) \in \mathcal{B}} \ell(\phi_w(x_i)_{y_i}) \tag{3}$$
by substituting (1) into (2), where $\ell$ is given, with a slight abuse of notation, by
$$\ell(v_i) := -v_i + \mathrm{LSE}(v) \tag{4}$$
with the log-sum-exp $\mathrm{LSE}(v) := \log\left(\sum_{k=1}^{K} \exp(v_k)\right)$.
Next, we introduce the general form for few-shot and life-long learning, used later to taxonomize modeling choices made by different approaches in the literature. Let $\mathcal{F} = \{(x_j, y_j)\}_{j=1}^{N(k)}$ be the few-shot training set, with $k \in \mathbb{N}$ the classes, or "ways," and $N(k)$ the "shots," or samples per class. We assume that meta- and few-shot data $x_i, x_j$ live in the same domain (e.g., natural images), while the meta- and few-shot classes are disjoint, which we indicate with $y \in B + \{1, \dots, K\}$.
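The per-sample loss of Eqs. (3) and (4) can be checked numerically with a minimal NumPy sketch. It assumes a score vector v for which P(y = k | x) is proportional to exp(v_k); under the sign convention of Eq. (1) that vector would be v = -φ_w(x). The function names are illustrative and not taken from the paper.

```python
import numpy as np

def log_sum_exp(v):
    """Numerically stable LSE(v) = log(sum_k exp(v_k))."""
    m = np.max(v)
    return m + np.log(np.sum(np.exp(v - m)))

def ell(v, i):
    """Per-sample loss of Eq. (4): l(v_i) = -v_i + LSE(v)."""
    return -v[i] + log_sum_exp(v)

def empirical_loss(scores, labels):
    """Eq. (3): sum of per-sample losses over a labeled set.

    scores: (M, K) array; row m is the score vector v of sample m.
    labels: (M,) integer array of true classes in {0, ..., K-1}.
    """
    return float(sum(ell(v, y) for v, y in zip(scores, labels)))

scores = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
# Check: l(v_i) equals -log of the softmax probability of the true class.
probs = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
print(np.isclose(empirical_loss(scores, labels),
                 -np.log(probs[np.arange(2), labels]).sum()))  # True
```

Subtracting the maximum before exponentiating leaves LSE unchanged but avoids overflow when scores are large.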
During meta-training, from the dataset $\mathcal{B}$ we learn a parametric representation (feature, or embedding) of the data $\phi_w(x)$, for later use in few-shot training. During few-shot training, we use $N(k)$ samples for each new category $k > B$ to train a classifier, with $k$ potentially growing unbounded (life-long learning). First, we define "useful" and then formalize a criterion to learn the parameters $w$, both during meta- and few-shot training.
Footnote: The number of ways $K$ is a priori unknown and potentially unbounded. It typically ranges from a few to a few hundred, while $N(k)$ is anywhere from one to a few thousand. The meta-training set typically has $M$ in the millions and $B$ in the thousands. Most benchmarks assume the same number of shots for each way, so there is a single number $N$, an artificial and unnecessary restriction.
Footnote: There is no loss of generality in assuming the classes are disjoint, as few-shot classes that are shared with the meta-training set can just be incorporated into the latter.
Unlike standard classification, discussed in the previous section, here we do not know the number of classes ahead of time, so we need a representation that is more general than a $K$-dimensional vector $\phi_w$. To this end, consider two additional ingredients: a representation of the classes $c_k$ (identities, prototypes, proxies), and a mechanism to associate a datum $x_j$ to a class $k$ through its representative $c_k$. We therefore have three functions, all in principle learnable and therefore indexed by parameters $w$. The data representation $\phi_w : X \to \mathbb{R}^F$ maps each datum to a fixed-dimensional vector, possibly normalized,
$$z = \phi_w(x). \tag{5}$$
We also need a class representation, which maps the $N(k)$ features $z_j$ sharing the same identity $y_j = k$ to some representative $c_k$ through a function $\psi_w : \mathbb{R}^{F N(k)} \to \mathbb{R}^C$ that yields, for each $k = B+1, \dots, B+K$,
$$c_k = \psi_w(\{z_j \mid y_j = k\}) \tag{6}$$
where $z_j = \phi_w(x_j)$. Note that the argument of $\psi$ has variable dimension.
Finally, the class membership can be decided based on the posterior probability of a datum belonging to a class, approximated with a sufficiently rich parametric function class in the exponential family, as we did for standard classification:
$$P_w(y = k \mid x_j) := \frac{\exp(-\chi_w(z_j, c_k))}{\sum_{k'} \exp(-\chi_w(z_j, c_{k'}))} \tag{7}$$
where $\chi_w : \mathbb{R}^F \times \mathbb{R}^C \to \mathbb{R}$ plays the role of the discriminant in (1). The cross-entropy loss (2) can then be written as
$$L(w) = \sum_{k=B+1}^{B+K} \sum_{j=1}^{N(k)} \ell(\chi_w(z_j, c_k)) \tag{8}$$
with $\ell$ given by (4) and $c_k$ by (6). The loss is minimized when $\chi_{\hat w}(z_j, c_k) = \log P(y_j = k \mid x_j)$, a function of the few-shot set $\mathcal{F}$. Note, however, that this loss can also be applied to the meta-training set, by changing the outer sum to $k = 1, \dots, B$, or to any combination of the two, by selecting subsets of $\{1, \dots, B+K\}$. Different approaches to few-shot learning differ in the choice of model $\mathcal{M}$ and in the mixture of meta- and few-shot training sets used in one iteration of parameter update, or training "episode."
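The posterior (7) and the resulting cross-entropy can be sketched in NumPy as below. The metric is passed in as a precomputed array so that any χ_w can be used; the squared Euclidean distance here is only an illustrative choice, and all function names are hypothetical.

```python
import numpy as np

def posterior_from_chi(chi):
    """Eq. (7): row j holds P_w(y = k | x_j), proportional to exp(-chi[j, k]).

    chi: (n, K) array with chi[j, k] = chi_w(z_j, c_k) for whatever metric is chosen.
    """
    s = -chi
    s = s - s.max(axis=1, keepdims=True)   # shift for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(chi, labels):
    """Sum over samples of -log P_w(y_j | x_j) under the posterior above."""
    P = posterior_from_chi(chi)
    return float(-np.sum(np.log(P[np.arange(len(labels)), labels])))

def squared_euclidean_chi(Z, C):
    """Illustrative metric chi(z, c) = ||z - c||^2; assumes z and c share a dimension."""
    return np.sum((Z[:, None, :] - C[None, :, :]) ** 2, axis=-1)

Z = np.array([[0.0, 0.0], [3.0, 3.0]])   # embeddings z_j
C = np.array([[0.0, 0.0], [3.0, 3.0]])   # class representatives c_k
chi = squared_euclidean_chi(Z, C)
P = posterior_from_chi(chi)              # each z_j is assigned almost surely to c_j
```

Only the array `chi` depends on the model; swapping in another metric leaves the posterior and loss code untouched, which mirrors how (7) and (8) are written for a generic χ_w.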
2. Stratification of Few-shot Learning Models
Starting from the most general form of few-shot learning described thus far, we restrict the model until there is no few-shot learning left, to capture the modeling choices made in the literature.
In general, during meta-training for few-shot learning, one solves some form of
$$\hat w = \arg\min_w \underbrace{\sum_{(x_i, y_i) \in \mathcal{B}} \ell(\chi_w(z_i, c_{y_i}))}_{L(w, c)} \quad \text{s.t.} \quad z_i = \phi_w(x_i); \quad c_{y_i} = \psi_w(\{z_j \mid y_j = y_i\}).$$
Implicit class representation function:
Instead of the explicit form in (6), one can infer the function $\psi_w$ implicitly: let $r = \min_w L(w, \psi_w)$ be the minimum of the optimization problem above. If we consider $c = \{c_1, \dots, c_B\}$ as free parameters in $L(w, c)$, the equation $r = L(\hat w, c)$ defines $c$ implicitly as a function of $\hat w$, $\psi_{\hat w}$. One can then simply find $\hat w$ and $c$ simultaneously by solving
$$\hat w, \hat c = \arg\min_{w, c} \sum_{k=1}^{B} \sum_{i \mid y_i = k} \ell(\chi_w(\phi_w(x_i), c_k)) \tag{9}$$
which is equivalent to the previous problem, even if there is no explicit functional form for the class representation $\psi_w$. As we will see, this simplifies meta-learning, as there is no need to design a separate architecture $\psi_w$ with a variable number of inputs, but it requires solving a (simple) optimization problem during few-shot learning. This is unlike all other known few-shot learning methods, which learn or fix $\psi_w$ during meta-learning and keep it fixed thereafter.
Far from being a limitation, the implicit solution has several advantages, including bypassing the need to explicitly define a function with a variable number of inputs (or a set function) $\psi_w$. It also enables the identity representation to live in a different space than the data representation, again unlike existing work that assumes a simple functional form such as the mean.
Lifelong few-shot learning:
Once meta-training is done, one can use the same loss function in (9) for $k > B$ to achieve life-long, few-shot learning. While each new category $k > B$ is likely to have few samples $N(k)$, in the aggregate the number of samples is bound to grow beyond $M$, which we can exploit to update the embedding $\phi_w$, the metric $\chi_w$, and the class function $c_k = \psi_w$.
Metric learning:
A simpler model consists of fixing the parameters of the data representation $\hat\phi := \phi_{\hat w}$ and using the same loss function, but summed for $k > B$, to learn from few shots $N(k)$ the new class proxies $c_k$ and change the metric $\chi_w$ as the class representation space becomes crowded. If we fix the data representation, during the few-shot training phase we solve
$$\hat w, \hat c = \arg\min_{w, c} \sum_{k=B+1}^{B+K} \sum_{j \mid y_j = k} \ell(\chi_w(\hat\phi(x_j), c_k)) \tag{10}$$
where the dependency on the meta-training phase is through $\hat\phi$, and both $\hat w$ and $\hat c$ depend on the few-shot dataset $\mathcal{F}$.
New class identities:
One further simplification step is to also fix the metric $\chi$, leaving only the class representatives to be estimated:
$$\hat c = \arg\min_c \sum_{k=B+1}^{B+K} \sum_{j \mid y_j = k} \ell(\chi(\hat\phi(x_j), c_k)). \tag{11}$$
The above is the implicit form of the parametric function $\psi_w$, with parameters $w = c$, as seen previously. Thus evaluating $\hat c_k = \psi_c(\{z_j \mid y_j = k\})$ requires solving an optimization problem.
No few-shot learning:
Finally, one can fix even the function $\psi$ explicitly, forgoing few-shot learning and simply computing
$$\hat c_k = \psi(\{\hat\phi(x_j) \mid y_j = k\}), \quad k > B \tag{12}$$
which depends on $\mathcal{B}$ through $\hat\phi$, and on $\mathcal{F}$ through $Y_k := \{j \mid y_j = k\}$.
We articulate our modeling and sampling choices in the next section, after reviewing the most common approaches in the literature in light of the stratification described.
Most current approaches fall under the case (12), thus involving no few-shot learning, forgoing the possibility of lifelong learning, and imposing additional undue limitations by constraining the prototypes to live in the same space as the features. Many are variants of Prototypical Networks [15], where only one of the three components of the model is learned: $\psi$ is fixed to be the mean, so $c_k := \frac{1}{|Y_k|} \sum_{j \in Y_k} z_j$, and $\chi(z, c) = \|z - c\|^2$ is the squared Euclidean distance. The only learning occurs at meta-training, and the trainable portion of the model, $\phi_w$, is a conventional neural network. In addition, the sampling scheme used for training makes the model dependent on the number of shots, again unnecessarily.
Other work can be classified into two main categories: gradient based [11, 3, 9, 14] and metric based [15, 20, 10, 4]. In the first, a meta-learner is trained to adapt the parameters of a network to match the few-shot training set. [11] uses the base set to learn long short-term memory (LSTM) units [6] that update the base classifier with the data from the few-shot training set. MAML [3] learns an initialization for the network parameters that can be adapted by gradient descent in a few steps. LEO [14] is similar to MAML, but uses a task-specific initial condition and performs the adaptation in a lower-dimensional space.
Most of these algorithms adapt $\phi_w(x)$ and use an ordinary classifier at few-shot test time. There is a different $\phi_w(x)$ for every few-shot training set, with little re-use and no continual learning.
On the metric learning side, [20] trains a weighted classifier using an attention mechanism [22] that is applied to the output of a feature embedding trained on the base set. This method requires the shots at meta- and few-shot training to match. Prototypical Networks [15] are trained with episodic sampling and a loss function based on the performance of a nearest-mean classifier [19] applied to a few-shot training set. [4] generates classification weights for a novel class based on a feature extractor learned from the base training set. [1] incorporates ridge regression in an end-to-end manner into a deep-learning network. These methods learn a single $\phi_w(x)$, which is reused across few-shot training tasks. The class identities are then obtained through a function defined a priori, such as the sample mean in [15], an attention kernel [20], or ridge regression [1]. The form of $\psi_w$ or $\chi$ does not change at few-shot training. Finally, [10] uses task-specific adaptation networks to adapt the embedding network, whose output lives in a task-dependent metric space. In this method, the forms of $\chi$ and $\psi$ are fixed, and the output of $\phi$ is modulated based on the few-shot training set.
Next, we describe our model which, to the best of our knowledge, is the first and only one to learn each component of the model: the embedding $\phi_w$, the metric $\chi_w$, and, implicitly, the class representation $\psi_w$.
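To illustrate the implicit class representation of Eq. (9), here is a toy NumPy sketch in which embedding weights and class representatives are learned together by gradient descent, with no explicit function ψ_w. For illustration only, it assumes a linear embedding φ_w(x) = A x (A is a hypothetical name), fixes χ to the squared Euclidean distance instead of learning it, and uses hand-derived gradients; it is not the paper's training code.

```python
import numpy as np

def loss_and_grads(A, C, X, y):
    """Loss of Eq. (9) for a toy linear embedding, with its gradients.

    A: (F, D) embedding weights, so z_i = A x_i (hypothetical linear phi_w).
    C: (K, F) class representatives c_k, treated as free parameters.
    X: (n, D) inputs; y: (n,) integer labels in {0, ..., K-1}.
    chi(z, c) = ||z - c||^2, and P(y = k | x) is a softmax over -chi.
    """
    n = len(y)
    Z = X @ A.T                                                 # (n, F) embeddings
    S = -np.sum((Z[:, None, :] - C[None, :, :]) ** 2, axis=-1)  # (n, K) scores
    S = S - S.max(axis=1, keepdims=True)                        # stability shift
    log_P = S - np.log(np.exp(S).sum(axis=1, keepdims=True))
    P = np.exp(log_P)
    loss = float(-log_P[np.arange(n), y].sum())
    G = P.copy()
    G[np.arange(n), y] -= 1.0                                   # dLoss/dScores
    dZ = 2.0 * G @ C                                            # rows of G sum to zero
    dC = 2.0 * (G.T @ Z - G.sum(axis=0)[:, None] * C)
    dA = dZ.T @ X
    return loss, dA, dC

def fit_jointly(X, y, K, F, steps=500, lr=0.05, seed=0):
    """Gradient descent on A and C simultaneously; returns A, C and the loss history."""
    rng = np.random.default_rng(seed)
    A = 0.1 * rng.normal(size=(F, X.shape[1]))
    C = 0.1 * rng.normal(size=(K, F))
    history = []
    for _ in range(steps):
        loss, dA, dC = loss_and_grads(A, C, X, y)
        history.append(loss)
        A -= lr * dA / len(y)
        C -= lr * dC / len(y)
    return A, C, history
```

Because the representatives c_k are ordinary parameters of the loss, classes with different numbers of samples need no special handling in this sketch.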
3. Proposed Model
Using the formalism of Sect. 2, we describe our modeling choices. Note that there is redundancy in the model class $\mathcal{M}$, as one could fix the data representation $\phi(x) = x$ and devolve all modeling capacity to $\psi$, or vice-versa. The choice depends on the application context. We outline our choices, motivated by limitations of prior work.
Embedding $\phi_w$: In line with recent work, we choose a deep convolutional network. The details of the architecture are in Sect. 4.
Class representation function $\psi_w$: We define it implicitly by treating the class representations $c_k$ as parameters along with the weights $w$. As we saw earlier, this means that at few-shot training we have to solve a simple optimization problem (11) to find the representatives of new classes, rather than computing the mean as in Prototypical Networks and its variants:
$$c_k = \arg\min_c \sum_{j \mid y_j = k} \ell(\chi_w(\hat\phi(x_j), c)) = \psi_c(k). \tag{13}$$
Note that the class estimates depend on the parameters $w$ in $\chi$. If few-shot learning is resource constrained, one can still learn the class representations implicitly during meta-training, and approximate them with a fixed function, such as the mean, during the few-shot phase.
Metric $\chi$: We choose a discriminant induced by the Euclidean distance in the space of class representations, to which data representations are mapped by a learnable parameter matrix $W$:
$$\chi_W(z_j, c_k) = \|W\hat\phi(x_j) - c_k\|^2 \tag{14}$$
Generally, we pick the dimension of $c$ larger than the dimension of $z$, to enable capturing complex multi-modal identity representations. Note that this choice encompasses metric learning: if $Q = Q^T$ were a symmetric matrix representing a change of inner product, then $\|W\phi - c\|_Q = \phi^T W^T Q c$ would be captured by simply choosing the weights $\tilde W = QW$. Since both the weights and the class proxies are free, there is no gain in generality in adding the metric parameters $Q$.
Of course, $W$ can be replaced by any non-linear map, effectively "growing" the model via
$$\chi_w(z_j, c_k) = \|\hat f_w(\phi(x_j)) - c_k\|^2 \tag{15}$$
for some parametric family $f_w$, such as a deep neural network.
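Few-shot estimation of new-class representatives under Eqs. (13) and (14) can be sketched as below, with the embedding and the matrix W both held fixed. The representatives start at the class means of the mapped embeddings (the fixed rule a resource-constrained few-shot phase could stop at) and are then refined by gradient descent on the cross-entropy. Assumptions: the softmax runs only over the new classes, and the step size, iteration count, and function names are illustrative.

```python
import numpy as np

def class_loss(P_emb, y, C):
    """Cross-entropy with P(y = k | x_j) proportional to exp(-||W phi_hat(x_j) - c_k||^2).

    P_emb: (n, mu) mapped embeddings W phi_hat(x_j); y: (n,) labels; C: (K, mu).
    """
    S = -np.sum((P_emb[:, None, :] - C[None, :, :]) ** 2, axis=-1)
    S = S - S.max(axis=1, keepdims=True)
    log_p = S - np.log(np.exp(S).sum(axis=1, keepdims=True))
    return float(-log_p[np.arange(len(y)), y].sum())

def few_shot_class_reps(Z, y, K, W, steps=200, lr=0.05):
    """Estimate new-class representatives via Eq. (13), with phi_hat and W fixed.

    Z: (n, d) fixed embeddings phi_hat(x_j); y: (n,) labels in {0, ..., K-1};
    W: (mu, d) fixed matrix of Eq. (14) mapping embeddings into the class space.
    Returns (C_mean, C_opt): class means of W z_j, and the refined representatives.
    """
    P_emb = Z @ W.T                                          # (n, mu)
    n = len(y)
    C_mean = np.stack([P_emb[y == k].mean(axis=0) for k in range(K)])
    C = C_mean.copy()
    for _ in range(steps):
        S = -np.sum((P_emb[:, None, :] - C[None, :, :]) ** 2, axis=-1)
        S = S - S.max(axis=1, keepdims=True)
        P = np.exp(S)
        P /= P.sum(axis=1, keepdims=True)
        G = P.copy()
        G[np.arange(n), y] -= 1.0                            # dLoss/dScores
        dC = 2.0 * (G.T @ P_emb - G.sum(axis=0)[:, None] * C)
        C -= lr * dC / n
    return C_mean, C
```

The same routine accepts any number of shots per class, since each representative comes out of an optimization rather than a fixed-arity function; the class space dimension mu can be chosen larger than the embedding dimension d.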
4. Implementation
Embedding $\phi_w(x_j)$: We use two different architectures. The first [15, 20] consists of four convolution blocks, each with 64 3×3 filters followed by batch normalization, ReLU, and max-pooling with a 2×2 kernel. Following the convention in [4], we call this architecture C64. The other network is a modified ResNet [5], similar to [10]; we call this ResNet-12.
In addition, we normalize the embedding to live on the unit sphere, i.e., $\phi(x) \in S^{d-1}$, where $d$ is the dimension of the embedding. This normalization is added as a layer to ensure that the feature embeddings are on the unit sphere, as opposed to applying it post hoc. This complicates meta-training because of poor scaling of gradients [21], which we address with a single-parameter layer after the normalization, whose sole purpose is to scale the output of the normalization layer. This layer is not required at test time.
Class representation:
As noted earlier, this is implicit during meta-training. In order to show the flexibility of our framework, we increase the dimension of the class representation.
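The normalization layer, the single-parameter scale layer, and the identity behind Eq. (16) can be sketched in NumPy as follows. The initial scale value and all names are assumptions for illustration. The identity ||u - v||^2 = 2(1 - cos θ) holds for any pair of unit vectors, which is why the Euclidean metric on the sphere reduces to an angular one.

```python
import numpy as np

def l2_normalize(Z, eps=1e-12):
    """Normalization layer: project each row of Z onto the unit sphere S^{d-1}."""
    norms = np.linalg.norm(Z, axis=1, keepdims=True)
    return Z / np.maximum(norms, eps)

class ScaleLayer:
    """Single-parameter layer that rescales the normalized embedding during meta-training."""

    def __init__(self, s=10.0):
        self.s = s  # the layer's only parameter; the initial value is illustrative

    def __call__(self, U):
        return self.s * U

def sphere_sq_distance(u, v):
    """Squared Euclidean distance between unit vectors via 2 * (1 - cos(theta))."""
    return 2.0 * (1.0 - float(np.dot(u, v)))

rng = np.random.default_rng(0)
U = l2_normalize(rng.normal(size=(3, 5)))   # three embeddings on S^4
train_out = ScaleLayer(s=10.0)(U)           # meta-training path: normalize, then scale
test_out = U                                # test path: the scale layer is dropped
```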
Metric $\chi$: We choose the angular distance in feature space, which is the $d$-hypersphere:
$$\chi(z_j, c_k) = \|W z_j - c_k\|^2 = 2s(1 - \cos\theta), \tag{16}$$
where $s$ is the scaling factor used during training and $\theta$ the angle between the normalized arguments. As the representation $z = \phi_w(x)$ is normalized, the class-conditional model is a von Mises-Fisher (spherical Gaussian). Since $W\phi_w(x_i) \in S^{d-1}$, we also need $W\psi_w \in S^{d-1}$; during meta-training we therefore apply the same normalization and scale function to the implicit representation as well. The class posterior then takes the form
$$P_w(y = k \mid x) \propto \exp\langle W\phi_w(x), c_k\rangle \tag{17}$$
up to the normalization constant.
Sampling
At each iteration during meta-training, images from the training set $\mathcal{B}$ are presented to the network in the form of episodes [20, 11, 15]; each episode consists of images sampled from $K$ classes. The images are selected by first sampling $K$ classes from $\mathcal{B}$ and then sampling $N_e$ images from each of the sampled classes. The loss function is then restricted to the $K$ classes present in the episode, as opposed to the entire set of classes available at meta-training. This setting allows the network to learn a better embedding for open-set classification, as shown in [2, 20].
Unlike existing methods that use episodic sampling [11, 15], we do not split the images within an episode into a meta-train set and a meta-test set. For instance, Prototypical Networks [15] use the elements in the meta-train set to compute the mean of the class representation, and [11] learns the initial conditions for optimization. This requires a notion of training "shot," and results in multiple networks to match the shots one expects at few-shot training.
Regularization
First, we notice that the loss function (9) has a degenerate solution where all the centers and the embeddings are the same. In this case, $P_w(y = k \mid x_j) = P_w(y = k' \mid x_j)$ for all $k$ and $k'$, i.e., $P_w(y \mid x_j)$ is a uniform distribution. Since the entropy is maximal in this degenerate case, we use an entropy term to bias the solution away from the trivial one. We also use Dropout [16] on top of the embedding $\phi_w(x)$ during meta-training: even with episodic sampling, the embedding tends to over-fit on the base set in the absence of dropout. We do not use dropout at few-shot training and test time.
Figure 2 summarizes our architecture for the loss function during meta-training. It has layers that are only needed for training, such as the scale layer, Dropout and the loss. During few-shot training, we only use the learned embedding $\phi_w(x)$.
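One way to read the entropy regularizer is as a penalty added to the cross-entropy. The sketch below assumes that the penalty is the mean entropy of P_w(· | x_j) over an episode and that it is weighted by the λ reported in Sect. 5 (λ = 1); the exact form is not spelled out in this section, so treat it as an illustration with hypothetical names.

```python
import numpy as np

def softmax_rows(S):
    """Row-wise softmax with a max shift for numerical stability."""
    S = S - S.max(axis=1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=1, keepdims=True)

def entropy_regularized_loss(S, y, lam=1.0, eps=1e-12):
    """Mean cross-entropy plus lam times the mean entropy of P_w(. | x_j).

    S: (n, K) scores, e.g. S[j, k] = -chi_w(z_j, c_k); y: (n,) labels.
    The entropy term peaks at the degenerate solution where every posterior is
    uniform, so adding it to the loss biases optimization away from that solution.
    """
    P = softmax_rows(S)
    n = len(y)
    cross_entropy = -np.mean(np.log(P[np.arange(n), y] + eps))
    entropy = -np.mean(np.sum(P * np.log(P + eps), axis=1))
    return float(cross_entropy + lam * entropy)
```

At the degenerate solution with K classes both terms equal log K, so that point is penalized twice as much as by the cross-entropy alone.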
5. Experimental Results
We test our algorithm on three datasets: miniImagenet [20], tieredImagenet [12] and CIFAR Few-Shot [1]. The miniImagenet dataset consists of images of size 84×84 sampled from 100 classes of the ILSVRC [13] dataset, with 600 images per class. We used the data split outlined in [11], where 64 classes are used for training, 16 classes are used for validation, and 20 classes are used for testing.
We also use tieredImagenet [12]. This is a larger subset of ILSVRC, and consists of 779,165 images of size 84×84 representing 608 classes hierarchically grouped into 34 high-level classes. The split of this dataset ensures that sub-classes of the 34 high-level classes are not spread over the training, validation and testing sets, minimizing the semantic overlap between training and test sets. The result is 448,695 images in 351 classes for training, 124,261 images in 97 classes for validation, and 206,209 images in 160 classes for testing. For a fair comparison, we use the same training, validation and testing splits as in [12], and use the classes at the lowest level of the hierarchy.
Finally, we use CIFAR Few-Shot (CIFAR-FS) [1], a reorganized version of the CIFAR-100 [8] dataset containing images of size 32×32. We use the same data split as in [1], dividing the 100 classes into 64 for training, 16 for validation, and 20 for testing.
Many recent methods are variants of Prototypical Networks, so we perform a detailed comparison with it. We keep the training procedure, network architecture, batch size, and data augmentation the same, so the performance differences can be attributed to the changes in our method.
We use ADAM [7] for training with an initial learning rate of $10^{-3}$ and a decay factor of 0.5 every 2,000 iterations. We use the validation set to determine the best model. Our data augmentation consists of mean subtraction, standard-deviation normalization, random cropping and random flipping during training. Each episode contains 15 query samples per class during training. In all our experiments, we set $\lambda = 1$ and did not tune this parameter.
Unless otherwise noted, we always test few-shot algorithms on 2,000 episodes, with 30 query points per class per episode. At few-shot training, we experimented with setting the class identity to be implicit (optimized) or the average prototype (fixed). The latter may be warranted when the few-shot phase is resource-constrained, and it yields similar performance. To compare computation time, we use the fixed mean. Note that, in all cases, the class prototypes are learned implicitly during meta-training.
The results of this comparison are shown in Table 1. For the 5-shot 5-way case we perform similarly to Prototypical Networks. However, for the 1-shot case we see significant improvements across all three datasets. Also, the performance of Prototypical Networks drops when the training and testing shots differ: Table 1 shows a significant drop in performance when models trained with 1-shot are tested in a 5-shot setting. Our method, in contrast, maintains the same performance. Consequently, we train only one model and test it across the different shot scenarios, hence the moniker "shot-free."
Class identities $c_k$ can live in a space of different dimension than the feature embedding. This can be done in two ways: by lifting the embedding into a higher-dimensional space, or by projecting the class identity into the embedding dimension. If the dimension of the class identity changes, we also need to modify $\chi$ according to (14). The weight matrix $W \in \mathbb{R}^{d \times \mu}$, where $d$ is the dimension of the embedding and $\mu$ is the dimension of the class identities, can be learned during meta-training. This is equivalent to adding a fully connected layer through which the class identities are passed before normalization. Thus, we now learn $\phi_w$, $\psi_k$ and $\chi_W$. We show experimental results with the C64 architecture on the miniImagenet dataset in Table 2. Here, we tested class identity dimensions of 2x, 5x and 10x the dimension of the embedding. From this table we see that increasing the dimension gives a performance boost.
However, this increase saturates at a dimension of 2x the dimension of the embedding space.
In order to compare with the state-of-the-art, we use the ResNet-12 base architecture and train our approach using SGD with Nesterov momentum, with an initial learning rate of …, weight decay of …, momentum of … and eight episodes per batch. Our learning rate was decreased by a factor of … every time the validation error did not improve for 1,000 iterations. We did not tune these parameters based on the dataset. As mentioned earlier, we train one model and test across various shots. We also compare our method with class identities in a space with twice the dimension of the embedding. Lastly, we compare our method with a variant of ResNet where we change the filter sizes from (64,128,256,512) to (64,160,320,640).

Table 1. Comparison of results from our method to those of our implementation of Prototypical Networks [15] using the C64 network architecture. The table shows the accuracy and 95% confidence interval of our method averaged over 2,000 episodes on different datasets. Note that our method does not have a notion of shot; here, when we refer to training with a different shot, we mean that the batch size is the same as that of the prescribed method.

Dimension      1x      2x      5x      10x
Accuracy (%)   49.07   51.46   51.46   51.32
Table 2. Performance of our method on miniImagenet with the class identity dimension as a function of the embedding dimension, using the C64 network architecture. The table shows the accuracy averaged over 2,000 episodes.

The results of our comparison on miniImagenet are shown in Table 3. Modulo empirical fluctuations, our method performs at the state of the art and in some cases exceeds it. We wish to point out that SNAIL [9], TADAM [10, 17], LEO [14] and MTLF [17] pre-train the network for a 64-way classification task on miniImagenet and a 351-way classification task on tieredImagenet. In contrast, all the models for our method are trained from scratch and use no form of pre-training. We also do not use the meta-validation set for tuning any parameters other than selecting the best trained model using the error on this set. Furthermore, unlike all other methods, we did not have to train multiple networks and tune the training strategy for each case. Lastly, LEO [14] uses a very deep 28-layer Wide-ResNet as a base model, compared to our shallower ResNet-12. A fair comparison would involve training our method with the same base network; we include this comparison for complete transparency.

Algorithm                                   1-shot 5-way   5-shot 5-way   10-shot 5-way
Meta LSTM [11]                              43.44          60.60          -
Matching networks [20]                      44.20          57.0           -
MAML [3]                                    48.70          63.1           -
Prototypical Networks [15]                  49.40          68.2           -
Relation Net [18]                           50.40          65.3           -
R2D2 [1]                                    51.20          68.2           -
SNAIL [9]                                   55.70          68.9           -
Gidaris et al. [4]                          55.95          73.00          -
TADAM [10]                                  58.50          76.7           80.8
MTFL [17]                                   61.2           75.5           -
LEO [14]                                    61.76          77.59          -
Our Method (ResNet-12)                      59.00          77.46          82.33
Our Method (ResNet-12) 2x dims.             60.64          77.02          80.80
Our Method (ResNet-12 Variant)              59.04
Our Method (ResNet-12 Variant) 2x dims      60.71          77.26          81.34
Table 3. Performance of 4 variants of our method on miniImagenet compared to the state-of-the-art. The table shows the accuracy averaged over 2,000 episodes.

The performance of our method on tieredImagenet is shown in Table 4. This table shows that we are the top performing method for 1-shot 5-way and 5-shot 5-way. We test on this dataset as it is much larger and has no semantic overlap between meta-training and few-shot training, even though fewer baselines exist for it than for miniImagenet. Also shown in Table 4 is the performance of our method on the CIFAR Few-Shot dataset, which we include to illustrate that our method generalizes across datasets. From this table we see that our method performs best on CIFAR Few-Shot.
As a final remark, there is no consensus on the few-shot training and testing paradigm in the literature, and many variables can affect performance. To illustrate this, we show the effect of a few training choices.
Effect of optimization algorithm.
In the original implementation of Prototypical Networks [15], ADAM [7] was used as the optimization algorithm. However, most newer algorithms, such as [10, 4], use SGD. The effect of using different optimization algorithms is shown in Table 5, which reports the performance of our algorithm on the miniImagenet dataset using a ResNet-12 model. From this table we see that, while the 1-shot 5-way results are better with ADAM than with SGD, the same does not hold for the 5-shot 5-way and 10-shot 5-way scenarios. This suggests that SGD generalizes better than ADAM for our algorithm.
Algorithm | 1-shot 5-way | 5-shot 5-way | 10-shot 5-way
tieredImagenet
MAML [3] | 51.67 | 70.30 | -
Prototypical Networks [12] | 53.31 | 72.69 | -
Relation Net [18] | 54.48 | 71.32 | -
LEO [14] | 65.71 | 81.31 | -
Our Method (ResNet-12) | 63.99 | 81.97 | 85.89
Our Method (ResNet-12) 2x dims. |  |  |
Our Method (ResNet-12) Variant 2x dims |  |  |
Table 4. Performance of our method on the tieredImagenet and CIFAR Few-Shot datasets compared to the state of the art. The performance numbers for CIFAR Few-Shot are from [1]. The table shows the accuracy averaged over 2,000 episodes. Note that the training setting for the prior work is different.
Optimization algorithm | 1-shot 5-way | 5-shot 5-way | 10-shot 5-way
ADAM |  |  |
Table 5. Performance of our method on miniImagenet using the ResNet-12 model with different choices of optimization algorithm. The table shows the accuracy averaged over 2,000 episodes.
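To make the difference between the two optimizers concrete, here is a minimal sketch of one update for a single scalar parameter under each rule: SGD with momentum in one common convention (v ← μv + g, θ ← θ − ηv), and ADAM as defined in [7], with bias-corrected moment estimates. The function names and hyperparameter values are illustrative defaults, not the settings used in our experiments.

```python
import math


def sgd_momentum_step(theta, grad, velocity, lr=0.1, momentum=0.9):
    """One SGD step with heavy-ball momentum: v <- mu*v + g; theta <- theta - lr*v."""
    velocity = momentum * velocity + grad
    return theta - lr * velocity, velocity


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One ADAM step (t counts from 1) with bias-corrected moment estimates."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for zero init
    v_hat = v / (1 - beta2 ** t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


# First step from theta = 0 for two gradient scales.
for grad in (-6.0, -600.0):
    theta_sgd, _ = sgd_momentum_step(0.0, grad, velocity=0.0)
    theta_adam, _, _ = adam_step(0.0, grad, m=0.0, v=0.0, t=1)
    print(f"grad={grad}: SGD -> {theta_sgd:.4f}, ADAM -> {theta_adam:.6f}")
```

Scaling the gradient by 100 scales the SGD step by 100, while ADAM's first step stays at about the learning rate, because ADAM divides by a running estimate of the gradient magnitude. This per-parameter normalization is one of the ways the two optimizers differ in their implicit regularization. The sketch does not by itself explain the generalization gap observed in Table 5.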
Effect of number of tasks per iteration.
TADAM [10] and Gidaris et al. [4] use multiple episodes per iteration. TADAM [10] refers to these as tasks and uses 2 tasks for 5-shot, 1 task for 10-shot, and 5 tasks for 1-shot. We did not perform any such tuning and instead defaulted to 8 episodes per iteration, following Gidaris et al. [4]. We also experimented with 16 episodes per iteration; however, this led to a loss in performance across all testing scenarios. Table 6 shows the performance on the miniImagenet dataset using the ResNet-12 architecture trained with ADAM [7] as the optimization algorithm. From this table we see that 8 episodes per iteration performs better in all scenarios.
Choice | 1-shot 5-way | 5-shot 5-way | 10-shot 5-way
8 episodes per iteration |  |  |
16 episodes per iteration | 58.22 | 74.53 | 78.61
Table 6. Performance of our method on miniImagenet using a ResNet-12 model with different choices of episodes per iteration. The table shows the accuracy averaged over 2,000 episodes.
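Training with several episodes per iteration can be sketched as follows, assuming the training set is stored as a mapping from class id to a list of examples. Each iteration samples `episodes_per_iter` independent N-way K-shot episodes and averages their losses, so one optimizer step sees all of them. The names `sample_episode`, `meta_batch_loss`, and `episode_loss` are hypothetical, and the loss is a user-supplied callable, not the loss of our model.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

Labeled = List[Tuple[str, int]]  # (example, episode-local label)


def sample_episode(data: Dict[int, List[str]], n_way: int, n_shot: int,
                   n_query: int, rng: random.Random) -> Tuple[Labeled, Labeled]:
    """Sample an N-way K-shot episode; labels are re-indexed to 0..n_way-1."""
    classes = rng.sample(sorted(data), n_way)
    support: Labeled = []
    query: Labeled = []
    for label, cls in enumerate(classes):
        items = rng.sample(data[cls], n_shot + n_query)  # support and query are disjoint
        support += [(x, label) for x in items[:n_shot]]
        query += [(x, label) for x in items[n_shot:]]
    return support, query


def meta_batch_loss(data: Dict[int, List[str]],
                    episode_loss: Callable[[Labeled, Labeled], float],
                    episodes_per_iter: int = 8, n_way: int = 5, n_shot: int = 1,
                    n_query: int = 15, rng: Optional[random.Random] = None) -> float:
    """Average the per-episode loss over several episodes for one update."""
    rng = rng or random.Random()
    losses = [episode_loss(*sample_episode(data, n_way, n_shot, n_query, rng))
              for _ in range(episodes_per_iter)]
    return sum(losses) / len(losses)


# Toy usage: 64 classes with 20 placeholder examples each, and a dummy loss
# that just counts the examples in the episode.
toy = {c: [f"img_{c}_{i}" for i in range(20)] for c in range(64)}
dummy_loss = lambda support, query: float(len(support) + len(query))
print(meta_batch_loss(toy, dummy_loss, episodes_per_iter=8, rng=random.Random(0)))  # 80.0
```

In an autodiff framework, the averaged loss would be backpropagated once per iteration, which is equivalent to averaging the per-episode gradients. Changing `episodes_per_iter` from 8 to 16 halves the number of updates per pass over the same number of episodes. That is one reason the two settings need not behave alike.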
Even when all major factors, such as network architecture, training procedure, and batch size, remain the same, other factors such as the number of query points used for testing affect performance. Methods in the existing literature use anywhere between 15 and 30 query points for testing, and for some methods this choice is unclear. This calls for stricter evaluation protocols and richer benchmark datasets.
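One way to make an evaluation protocol explicit is to report it alongside the result: N-way, K-shot, queries per class, number of episodes, and the mean accuracy with a 95% confidence interval. The sketch below does this under a deliberately simple toy assumption: each query is classified correctly, independently, with a fixed probability. `simulate_eval` and `mean_and_ci95` are hypothetical helpers. Under this toy model the query count leaves the expected accuracy unchanged and changes only the width of the interval. Any effect of the query count on the accuracy itself comes from the model and evaluation pipeline and is not captured here.

```python
import math
import random
import statistics


def mean_and_ci95(per_episode_acc):
    """Mean accuracy and 95% confidence half-width (normal approximation)."""
    mean = statistics.fmean(per_episode_acc)
    half_width = 1.96 * statistics.stdev(per_episode_acc) / math.sqrt(len(per_episode_acc))
    return mean, half_width


def simulate_eval(p_correct, n_way=5, n_query=15, n_episodes=2000, seed=0):
    """Toy model: every query is classified correctly with probability p_correct."""
    rng = random.Random(seed)
    total = n_way * n_query  # query points per episode
    accs = []
    for _ in range(n_episodes):
        correct = sum(rng.random() < p_correct for _ in range(total))
        accs.append(correct / total)
    return mean_and_ci95(accs)


for n_query in (15, 30):
    mean, hw = simulate_eval(0.6, n_query=n_query)
    print(f"5-way, {n_query} queries/class, 2000 episodes: "
          f"{100 * mean:.2f} +/- {100 * hw:.2f}")
```

With more queries per class, each per-episode accuracy is less noisy, so the reported interval narrows even though the underlying accuracy in this toy model is the same.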
6. Discussion
We have presented a method for meta-learning for few-shot learning in which all three ingredients of the problem are learned: the representation of the data φ_w, the representation of the classes ψ_c, and the metric or membership function χ_W. The method has several advantages over prior approaches. First, by allowing the class representation and the data representation spaces to be different, we can allocate more representational power to the class prototypes. Second, by learning the class models implicitly, we can handle a variable number of shots without having to resort to complex architectures or, worse, training different architectures, one for each number of shots. Finally, by learning the membership function we implicitly learn the metric, which allows class prototypes to redistribute during few-shot learning.
While some of these benefits are not immediately evident due to limited benchmarks, the improved generality allows our model to extend to a continual learning setting where the number of new classes grows over time, and it is flexible in allowing each new class to come with its own number of shots. Our model is simpler than some of the top-performing ones in the benchmarks. A single model performs on par or better in the few-shot setting and offers added generality.
References
[1] Luca Bertinetto, João F. Henriques, Philip H. S. Torr, and Andrea Vedaldi. Meta-learning with differentiable closed-form solvers. CoRR, abs/1805.08136, 2018.
[2] Wei-Yu Chen, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, and Jia-Bin Huang. A closer look at few-shot classification. In International Conference on Learning Representations, 2019.
[3] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In ICML, 2017.
[4] Spyros Gidaris and Nikos Komodakis. Dynamic few-shot visual learning without forgetting. In CVPR, 2018.
[5] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770–778. IEEE Computer Society, 2016.
[6] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Comput., 9(8):1735–1780, Nov. 1997.
[7] Diederik P. Kingma and Jimmy Lei Ba. ADAM: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
[8] Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.
[9] Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A simple neural attentive meta-learner. In ICLR, 2018.
[10] Boris N. Oreshkin, Pau Rodríguez, and Alexandre Lacoste. Improved few-shot learning with task conditioning and metric scaling. In NIPS, 2018.
[11] Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. In ICLR, 2017.
[12] Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, and Richard S. Zemel. Meta-learning for semi-supervised few-shot classification. CoRR, abs/1803.00676, 2018.
[13] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. Imagenet large scale visual recognition challenge. Int. J. Comput. Vision, 115(3):211–252, Dec. 2015.
[14] Andrei A. Rusu, Dushyant Rao, Jakub Sygnowski, Oriol Vinyals, Razvan Pascanu, Simon Osindero, and Raia Hadsell. Meta-learning with latent embedding optimization. CoRR, abs/1807.05960, 2018.
[15] Jake Snell, Kevin Swersky, and Richard S. Zemel. Prototypical networks for few-shot learning. In NIPS, pages 4080–4090, 2017.
[16] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. J. Mach. Learn. Res., 15(1):1929–1958, Jan. 2014.
[17] Qianru Sun, Yaoyao Liu, Tat-Seng Chua, and Bernt Schiele. Meta-transfer learning for few-shot learning. CoRR, abs/1812.02391, 2018.
[18] Flood Sung, Yongxin Yang, Li Zhang, Tao Xiang, Philip H. S. Torr, and Timothy M. Hospedales. Learning to compare: Relation network for few-shot learning. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2018.
[19] Robert Tibshirani, Trevor Hastie, Balasubramanian Narasimhan, and Gilbert Chu. Diagnosis of multiple cancer types by shrunken centroids of gene expression. Proceedings of the National Academy of Sciences, 99(10):6567–6572, 2002.
[20] Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Koray Kavukcuoglu, and Daan Wierstra. Matching networks for one shot learning. In NIPS, 2016.
[21] Feng Wang, Xiang Xiang, Jian Cheng, and Alan Loddon Yuille. Normface: L2 hypersphere embedding for face verification. In Proceedings of the 25th ACM International Conference on Multimedia, MM ’17, pages 1041–1049, New York, NY, USA, 2017. ACM.
[22] Kelvin Xu, Jimmy Lei Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhutdinov, Richard S. Zemel, and Yoshua Bengio. Show, attend and tell: Neural image caption generation with visual attention. In