Learning Disentangled Representations with Semi-Supervised Deep Generative Models
N. Siddharth, Brooks Paige, Jan-Willem van de Meent, Alban Desmaison, Noah D. Goodman, Pushmeet Kohli, Frank Wood, Philip H.S. Torr
N. Siddharth
University of Oxford [email protected]
Brooks Paige
Alan Turing Institute, University of Cambridge [email protected]
Jan-Willem van de Meent
Northeastern University [email protected]
Alban Desmaison
University of Oxford [email protected]
Noah D. Goodman
Stanford University [email protected]
Pushmeet Kohli ∗ Deepmind [email protected]
Frank Wood
University of Oxford [email protected]
Philip H.S. Torr
University of Oxford [email protected]
Abstract
Variational autoencoders (VAEs) learn representations of data by jointly training a probabilistic encoder and decoder network. Typically these models encode all features of the data into a single variable. Here we are interested in learning disentangled representations that encode distinct aspects of the data into separate variables. We propose to learn such representations using model architectures that generalise from standard VAEs, employing a general graphical model structure in the encoder and decoder. This allows us to train partially-specified models that make relatively strong assumptions about a subset of interpretable variables and rely on the flexibility of neural networks to learn representations for the remaining variables. We further define a general objective for semi-supervised learning in this model class, which can be approximated using an importance sampling procedure. We evaluate our framework's ability to learn disentangled representations, both by qualitative exploration of its generative capacity, and quantitative evaluation of its discriminative ability on a variety of models and datasets.
∗ Author was at Microsoft Research during this project.

1 Introduction

Learning representations from data is one of the fundamental challenges in machine learning and artificial intelligence. Characteristics of learned representations can depend on their intended use. For the purposes of solving a single task, the primary characteristic required is suitability for that task. However, learning separate representations for each and every such task involves a large amount of wasteful repetitive effort. A representation that has some factorisable structure, and consistent semantics associated to different parts, is more likely to generalise to a new task.

Probabilistic generative models provide a general framework for learning representations: a model is specified by a joint probability distribution both over the data and over latent random variables, and a representation can be found by considering the posterior on latent variables given specific data. The learned representation — that is, inferred values of latent variables — depends then not just on the data, but also on the generative model in its choice of latent variables and the relationships between the latent variables and the data. There are two extremes of approaches to constructing generative models. At one end are fully-specified probabilistic graphical models [18, 21], in which a practitioner decides on all latent variables present in the joint distribution, the relationships between them, and the functional form of the conditional distributions which define the model. At the other end are deep generative models [7, 16, 19, 20], which impose very few assumptions on the structure of the model, instead employing neural networks as flexible function approximators that can be used to train a conditional distribution on the data, rather than specify it by hand.

The tradeoffs are clear. In an explicitly constructed graphical model, the structure and form of the joint distribution ensures that latent variables will have particular semantics, yielding a disentangled representation. Unfortunately, defining a good probabilistic model is hard: in complex perceptual domains such as vision, extensive feature engineering (e.g. Berant et al. [1], Siddharth et al. [30]) may be necessary to define a suitable likelihood function. Deep generative models completely sidestep the difficulties of feature engineering. Although they address learning representations which then enable them to better reconstruct data, the representations themselves do not always exhibit consistent meaning along axes of variation: they produce entangled representations. While such approaches have considerable merit, particularly when faced with the absence of any side information about data, there are often situations when aspects of variation in data can be, or are desired to be, characterised.

Bridging this gap is challenging. One way to enforce a disentangled representation is to hold different axes of variation fixed during training [20]. Johnson et al. [13] combine a neural net likelihood with a conjugate exponential family model for the latent variables. In this class of models, efficient marginalisation over the latent variables can be performed by learning a projection onto the same conjugate exponential family in the encoder.
Here we propose a more general class of partially-specified graphical models: probabilistic graphical models in which the modeller only needs to specify the exact relationship for some subset of the random variables in the model. Factors left undefined in the model definition are then learned, parametrised by flexible neural networks. This provides the ability to situate oneself at a particular point on a spectrum, by specifying precisely those axes of variation (and their dependencies) we have information about or would like to extract, and learning disentangled representations for them, while leaving the rest to be learned in an entangled manner.

A subclass of partially-specified models that is particularly common is that where we can obtain supervision data for some subset of the variables. In practice, there is often variation in the data which is (at least conceptually) easy to explain, and therefore annotate, whereas other variation is less clear. For example, consider the MNIST dataset of handwritten digits: the images vary both in terms of content (which digit is present), and style (how the digit is written), as is visible in the right-hand side of Fig. 1. Having an explicit "digit" latent variable captures a meaningful and consistent axis of variation, independent of style; using a partially-specified graphical model means we can define a "digit" variable even while leaving unspecified the semantics of the different styles, and the process of rendering a digit to an image. In a fully unsupervised learning procedure there is generally no guarantee that inference on a model with 10 classes will in fact recover the 10 digits. However, given a small amount of labelled examples, this task becomes significantly easier. Beyond the ability to encode variation along some particular axes, we may also want to interpret the same data in different ways. For example, when considering images of people's faces, we might wish to capture the person's identity in one context, and the lighting conditions on the faces in another.

In this paper we introduce a recipe for learning and inference in partially-specified models, a flexible framework that learns disentangled representations of data by using graphical model structures to encode constraints to interpret the data. We present this framework in the context of variational autoencoders (VAEs), developing a generalised formulation of semi-supervised learning with DGMs that enables our framework to automatically employ the correct factorisation of the objective for any given choice of model and set of latents taken to be observed. In this respect our work extends previous efforts to introduce supervision into variational autoencoders [17, 23, 31]. We introduce a variational objective which is applicable to a more general class of models, allowing us to consider graphical-model structures with arbitrary dependencies between latents, continuous-domain latents, and those with dynamically changing dependencies. We provide a characterisation of how to compile partially-supervised generative models into stochastic computation graphs, suitable for end-to-end training. This approach also allows us to amortise inference [6, 22, 28, 33], simultaneously learning a network that performs approximate inference over representations at the same time as we learn the unknown factors of the model itself. We demonstrate the efficacy of our framework on a variety of tasks, involving classification, regression, and predictive synthesis, including its ability to encode latents of variable dimensionality.
To train deep generative models in a semi-supervised manner, we need to incorporate labelled data into the variational bound. In a fully unsupervised setting, the contribution of a particular data point $x_i$ to the ELBO can be expressed, with minor adjustments of Equation (1), as
$$\mathcal{L}(\theta, \phi; x_i) = \mathbb{E}_{q_\phi(z, y \mid x_i)}\left[\log \frac{p_\theta(x_i \mid z, y)\, p(z, y)}{q_\phi(z, y \mid x_i)}\right], \tag{2}$$
whose Monte Carlo approximation samples the latents $z$ and $y$ from the recognition distribution $q_\phi(z, y \mid x_i)$.

By contrast, in the fully supervised setting the values $y$ are treated as observed and become fixed inputs into the computation graph, instead of being sampled from $q_\phi$. When the label $y$ is observed along with the data, for fixed $(x_i, y_i)$ pairs, the lower bound on the conditional log-marginal likelihood $\log p_\theta(x \mid y)$ is
$$\mathcal{L}^{x \mid y}(\theta, \phi_z; x_i, y_i) = \mathbb{E}_{q_{\phi_z}(z \mid x_i, y_i)}\left[\log \frac{p_\theta(x_i \mid z, y_i)\, p(z \mid y_i)}{q_{\phi_z}(z \mid x_i, y_i)}\right]. \tag{3}$$
This quantity can be optimised directly to learn the model parameters $\theta$ and $\phi_z$ simultaneously via SGD. However, it does not contain the encoder parameters $\phi_y$. This difficulty was also encountered in a related context by Kingma et al. [17]. Their solution was to augment the loss function by including an explicit additional term for learning a classifier directly on the supervised points.

Here we propose an alternative approach. We extend the model with an auxiliary variable $\tilde{y}$ with likelihood $p(\tilde{y} \mid y) = \delta_{\tilde{y}}(y)$ to define the densities
$$p(\tilde{y}, y, z, x) = p(\tilde{y} \mid y)\, p_\theta(x \mid y, z)\, p(y, z), \qquad q(\tilde{y}, y, z \mid x) = p(\tilde{y} \mid y)\, q_\phi(y, z \mid x).$$
When we marginalise the ELBO for this model over $\tilde{y}$, we recover the expression in Equation (2). Treating $\tilde{y} = y_i$ as observed results in the supervised objective
$$\mathcal{L}(\theta, \phi; x_i, \tilde{y} = y_i) = \mathbb{E}_{q_\phi(z, y \mid x_i)}\left[\delta_{y_i}(y)\, \log \frac{p_\theta(x_i \mid z, y)\, p(z, y)}{q_\phi(z, y \mid x_i)}\right]. \tag{4}$$
Integration over an observed $y$ is then replaced with evaluation of the ELBO and the density $q_{\phi_y}$ at $y_i$. A Monte Carlo estimator of Equation (4) can be constructed automatically for any factorisation of $q_\phi$ by sampling the latent variables $z$ and weighting the resulting ELBO estimate by the conditional density terms $q_{\phi_y}(y \mid \cdot)$. Note that the exact functional form of the Monte Carlo estimator will vary depending on the dependency structure of $q_\phi(z, y \mid x_i)$.

For example, for discrete $y$, choosing $q_\phi(z, y \mid x) = q_{\phi_z}(z \mid y, x)\, q_{\phi_y}(y \mid x)$ decomposes the problem into simultaneously learning a classifier $q_{\phi_y}(y \mid x)$ alongside the generative model parameters $\theta$ and an encoder $q_{\phi_z}(z \mid x, y)$ which is conditioned on the selected class. The computation graph for a model with this factorisation is shown in Figure 1. In it, the value $y$ of the distribution $q_{\phi_y}(\cdot \mid x)$ is observed, while the distribution $q_{\phi_z}(\cdot \mid x, y)$ is sampled.
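To make the missing gradient concrete, the following toy sketch (not the paper's implementation; `q_y`, `q_z`, and `log_joint` are hypothetical stand-ins for the recognition networks and the generative-model joint density, and `q_z` is assumed to be a fully factorised Gaussian) contrasts single-sample estimates of the unsupervised bound in Equation (2) and the conditional bound in Equation (3):

```python
import torch

def unsupervised_elbo(x, q_y, q_z, log_joint):
    # Eq. (2): both the label y and the style z are sampled from the recognition model.
    y_dist = q_y(x)
    y = y_dist.sample()
    z_dist = q_z(x, y)
    z = z_dist.rsample()  # reparametrised sample
    log_q = y_dist.log_prob(y) + z_dist.log_prob(z).sum(-1)
    return log_joint(x, y, z) - log_q

def supervised_conditional_elbo(x, y_obs, q_z, log_joint):
    # Eq. (3): y is observed, so only z is sampled, conditioned on the label.
    # (The numerator p(x, z | y_obs) differs from the joint only by the constant log p(y_obs).)
    z_dist = q_z(x, y_obs)
    z = z_dist.rsample()
    # q_y never appears here, so the classifier parameters phi_y receive no gradient --
    # the gap that the auxiliary-variable construction above is designed to close.
    return log_joint(x, y_obs, z) - z_dist.log_prob(z).sum(-1)
```

The weighting by $q_{\phi_y}$ in Equation (4) restores a training signal for $\phi_y$ on supervised data points without requiring a separately bolted-on classifier loss.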
Figure 1: Semi-supervised learning in structured variational autoencoders, illustrated on MNIST digits. Top-Left: Generative model. Bottom-Left: Recognition model. Middle: Stochastic computation graph, showing expansion of each node to its corresponding sub-graph. Generative-model dependencies are shown in blue and recognition-model dependencies are shown in orange. See Section 2.2 for a detailed explanation. Right: Learned representation.
VAEs [16, 27] are a class of deep generative models that simultaneously train both a probabilistic encoder and decoder for elements of a data set $\mathcal{D} = \{x_1, \ldots, x_N\}$. The central analogy is that an encoding $z$ can be considered a latent variable, casting the decoder as a conditional probability density $p_\theta(x \mid z)$. The parameters $\eta_\theta(z)$ of this distribution are the output of a deterministic neural network with parameters $\theta$ (most commonly MLPs or CNNs) which takes $z$ as input. By placing a weak prior over $z$, the decoder defines a posterior and joint distribution $p_\theta(z \mid x) \propto p_\theta(x \mid z)\, p(z)$.

Inference in VAEs can be performed using a variational method that approximates the posterior distribution $p_\theta(z \mid x)$ using an encoder $q_\phi(z \mid x)$, whose parameters $\lambda_\phi(x)$ are the output of a network (with parameters $\phi$) that is referred to as an "inference network" or a "recognition network". The generative and inference networks, denoted by solid and dashed lines respectively in the graphical model, are trained jointly by performing stochastic gradient ascent on the evidence lower bound (ELBO) $\mathcal{L}(\phi, \theta; \mathcal{D}) \leq \log p_\theta(\mathcal{D})$,
$$\mathcal{L}(\phi, \theta; \mathcal{D}) = \sum_{n=1}^N \mathcal{L}(\phi, \theta; x_n) = \sum_{n=1}^N \mathbb{E}_{q_\phi(z \mid x_n)}\left[\log p_\theta(x_n \mid z) + \log p(z) - \log q_\phi(z \mid x_n)\right]. \tag{1}$$
Typically, the first term $\mathbb{E}_{q_\phi(z \mid x_n)}[\log p_\theta(x_n \mid z)]$ is approximated by a Monte Carlo estimate and the remaining two terms are expressed as a divergence $-\mathrm{KL}(q_\phi(z \mid x_n) \,\|\, p(z))$, which can be computed analytically when the encoder model and prior are Gaussian.
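As a point of reference, the single-datapoint term of Eq. (1) is typically computed as in the following minimal sketch (illustrative only, not the paper's released code), assuming a fully factorised Gaussian encoder, a Bernoulli decoder over pixel values in $[0, 1]$, a single Monte Carlo sample for the reconstruction term, and the analytic KL divergence mentioned above; the layer sizes are arbitrary choices for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=10, h_dim=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.enc_mu = nn.Linear(h_dim, z_dim)       # lambda_phi(x): posterior mean
        self.enc_logvar = nn.Linear(h_dim, z_dim)   # lambda_phi(x): posterior log-variance
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, x_dim))  # eta_theta(z): Bernoulli logits

    def elbo(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparametrised sample z = g(eps, lambda_phi(x)) with eps ~ N(0, I).
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        logits = self.dec(z)
        # Single-sample Monte Carlo estimate of E_q[log p_theta(x | z)].
        log_px_z = -F.binary_cross_entropy_with_logits(logits, x, reduction='none').sum(-1)
        # Analytic KL(q_phi(z | x) || p(z)) for a standard-normal prior.
        kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(-1)
        return (log_px_z - kl).mean()
```

Maximising this quantity by stochastic gradient ascent trains $\theta$ and $\phi$ jointly, which is the amortised-inference scheme referred to above.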
In this paper, we will consider models in which both the generative model $p_\theta(x, y, z)$ and the approximate posterior $q_\phi(y, z \mid x)$ can have arbitrary conditional dependency structures involving random variables defined over a number of different distribution types. We are interested in defining VAE architectures in which a subset of variables $y$ are interpretable. For these variables, we assume that supervision labels are available for some fraction of the data. The VAE will additionally retain some set of variables $z$ for which inference is performed in a fully unsupervised manner. This is in keeping with our central goal of defining and learning in partially-specified models. In the running example for MNIST, $y$ corresponds to the classification label, whereas $z$ captures all other implicit features, such as the pen type and handwriting style.

This class of models is more general than the models in the work by Kingma et al. [17], who consider three model designs with a specific conditional dependence structure. We also do not require $p(y, z)$ to be a conjugate exponential family model, as in the work by Johnson et al. [14]. To perform semi-supervised learning in this class of models, we need to i) define an objective that is suitable to general dependency graphs, and ii) define a method for constructing a stochastic computation graph [29] that incorporates both the conditional dependence structure in the generative model and that of the recognition model into this objective.

2.1 Objective Function

Previous work on semi-supervised learning for deep generative models [17] defines an objective over $N$ unsupervised data points $\mathcal{D} = \{x_1, \ldots, x_N\}$ and $M$ supervised data points $\mathcal{D}^{\mathrm{sup}} = \{(x_1, y_1), \ldots, (x_M, y_M)\}$,
$$\mathcal{L}(\theta, \phi; \mathcal{D}, \mathcal{D}^{\mathrm{sup}}) = \sum_{n=1}^N \mathcal{L}(\theta, \phi; x_n) + \gamma \sum_{m=1}^M \mathcal{L}^{\mathrm{sup}}(\theta, \phi; x_m, y_m). \tag{2}$$
Our model's joint distribution factorises into unsupervised and supervised collections of terms over $\mathcal{D}$ and $\mathcal{D}^{\mathrm{sup}}$ as shown in the graphical model. The standard variational bound on the joint evidence of all observed data (including supervision) also factorises as shown in Eq. (2). As the factor corresponding to the unsupervised part of the graphical model is exactly that of Eq. (1), we focus on the supervised term in Eq. (2), expanded below, incorporating an additional weighted component as in Kingma et al. [17],
$$\mathcal{L}^{\mathrm{sup}}(\theta, \phi; x_m, y_m) = \mathbb{E}_{q_\phi(z \mid x_m, y_m)}\left[\log \frac{p_\theta(x_m, y_m, z)}{q_\phi(z \mid x_m, y_m)}\right] + \alpha \log q_\phi(y_m \mid x_m). \tag{3}$$
Note that the formulation in Eq. (2) introduces a constant $\gamma$ that controls the relative strength of the supervised term. While the joint distribution in our model implicitly weights the two terms, in situations where the relative sizes of $\mathcal{D}$ and $\mathcal{D}^{\mathrm{sup}}$ are vastly different, having control over the relative weights of the terms can help ameliorate such discrepancies.

The definition in Eq. (3) implicitly assumes that we can evaluate the conditional probability $q_\phi(z \mid x, y)$ and the marginal $q_\phi(y \mid x) = \int \mathrm{d}z\, q_\phi(y, z \mid x)$. This was indeed the case for the models considered by Kingma et al. [17], which have a factorisation $q_\phi(y, z \mid x) = q_\phi(z \mid x, y)\, q_\phi(y \mid x)$. Here we will derive an estimator for $\mathcal{L}^{\mathrm{sup}}$ that generalises to models in which $q_\phi(y, z \mid x)$ can have an arbitrary conditional dependence structure. For purposes of exposition, we will for the moment consider the case where $q_\phi(y, z \mid x) = q_\phi(y \mid x, z)\, q_\phi(z \mid x)$. For this factorisation, generating samples $z_{m,s} \sim q_\phi(z \mid x_m, y_m)$ requires inference, which means we can no longer compute a simple Monte Carlo estimator by sampling from the unconditioned distribution $q_\phi(z \mid x_m)$. Moreover, we also cannot evaluate the density $q_\phi(z \mid x_m, y_m)$.

In order to address these difficulties, we re-express the supervised terms in the objective as
$$\mathcal{L}^{\mathrm{sup}}(\theta, \phi; x_m, y_m) = \mathbb{E}_{q_\phi(z \mid x_m, y_m)}\left[\log \frac{p_\theta(x_m, y_m, z)}{q_\phi(y_m, z \mid x_m)}\right] + (1 + \alpha) \log q_\phi(y_m \mid x_m), \tag{4}$$
which removes the need to evaluate $q_\phi(z \mid x_m, y_m)$. We can then use (self-normalised) importance sampling to approximate the expectation. To do so, we sample proposals $z_{m,s} \sim q_\phi(z \mid x_m)$ from the unconditioned encoder distribution, and define the estimator
$$\mathbb{E}_{q_\phi(z \mid x_m, y_m)}\left[\log \frac{p_\theta(x_m, y_m, z)}{q_\phi(y_m, z \mid x_m)}\right] \simeq \frac{1}{S} \sum_{s=1}^S \frac{w_{m,s}}{Z_m} \log \frac{p_\theta(x_m, y_m, z_{m,s})}{q_\phi(y_m, z_{m,s} \mid x_m)}, \tag{5}$$
where the unnormalised importance weights $w_{m,s}$ and the normaliser $Z_m$ are defined as
$$w_{m,s} := \frac{q_\phi(y_m, z_{m,s} \mid x_m)}{q_\phi(z_{m,s} \mid x_m)}, \qquad Z_m = \frac{1}{S} \sum_{s=1}^S w_{m,s}. \tag{6}$$
To approximate $\log q_\phi(y_m \mid x_m)$, we use a Monte Carlo estimator of the lower bound that is normally used in maximum likelihood estimation,
$$\log q_\phi(y_m \mid x_m) \geq \mathbb{E}_{q_\phi(z \mid x_m)}\left[\log \frac{q_\phi(y_m, z \mid x_m)}{q_\phi(z \mid x_m)}\right] \simeq \frac{1}{S} \sum_{s=1}^S \log w_{m,s}, \tag{7}$$
using the same samples $z_{m,s}$ and weights $w_{m,s}$ as in Eq. (5).
When we combine the terms in Eqs. (5) and (7), we obtain the estimator
$$\hat{\mathcal{L}}^{\mathrm{sup}}(\theta, \phi; x_m, y_m) := \frac{1}{S} \sum_{s=1}^S \left[\frac{w_{m,s}}{Z_m} \log \frac{p_\theta(x_m, y_m, z_{m,s})}{q_\phi(y_m, z_{m,s} \mid x_m)} + (1 + \alpha) \log w_{m,s}\right]. \tag{8}$$
We note that this estimator applies to any conditional dependence structure. Suppose that we were to define an encoder $q_\phi(z_1, y, z_2 \mid x)$ with factorisation $q_\phi(z_1 \mid y, z_2, x)\, q_\phi(y \mid z_2, x)\, q_\phi(z_2 \mid x)$. If we propose $z_1 \sim q_\phi(z_1 \mid y, z_2, x)$ and $z_2 \sim q_\phi(z_2 \mid x)$, then the importance weights $w_{m,s}$ for the estimator in Eq. (8) are defined as
$$w_{m,s} := \frac{q_\phi(z_{1,m,s}, y_m, z_{2,m,s} \mid x_m)}{q_\phi(z_{1,m,s} \mid y_m, z_{2,m,s}, x_m)\, q_\phi(z_{2,m,s} \mid x_m)} = q_\phi(y_m \mid z_{2,m,s}, x_m).$$
In general, the importance weights are simply the product of conditional probabilities of the supervised variables $y$ in the model. Note that this also applies to the models in Kingma et al. [17], whose objective we can recover by taking the weights to be constants, $w_{m,s} = q_\phi(y_m \mid x_m)$.

We can also define an objective analogous to the one used in importance-weighted autoencoders [2], in which we compute the logarithm of a Monte Carlo estimate, rather than the Monte Carlo estimate of a logarithm. This objective takes the form
$$\hat{\mathcal{L}}^{\mathrm{sup,iw}}(\theta, \phi; x_m, y_m) := \log\left[\frac{1}{S} \sum_{s=1}^S \frac{p_\theta(x_m, y_m, z_{m,s})}{q_\phi(z_{m,s} \mid x_m)}\right] + \alpha \log\left[\frac{1}{S} \sum_{s=1}^S w_{m,s}\right], \tag{9}$$
which can be derived by moving the sums in Eq. (8) into the logarithms and applying the substitution $w_{m,s} / q_\phi(y_m, z_{m,s} \mid x_m) = 1 / q_\phi(z_{m,s} \mid x_m)$.
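As a concrete illustration of the estimator in Eq. (8), consider the factorisation $q_\phi(y, z \mid x) = q_\phi(y \mid x, z)\, q_\phi(z \mid x)$ used above. The sketch below (a sketch only, not the released implementation; `q_z_given_x`, `q_y_given_xz`, and `log_joint` are hypothetical stand-ins for the networks that parametrise the model) computes the weights of Eq. (6), the bound of Eq. (7), and the combined objective for a single supervised pair:

```python
import torch

def supervised_objective(x_m, y_m, q_z_given_x, q_y_given_xz, log_joint, alpha=1.0, S=10):
    # Proposals z_{m,s} ~ q_phi(z | x_m), since q_phi(z | x_m, y_m) cannot be sampled directly.
    z_dist = q_z_given_x(x_m)
    z = z_dist.rsample((S,))                          # S reparametrised proposals
    log_q_z = z_dist.log_prob(z).sum(-1)              # log q_phi(z | x), assuming a factorised q
    log_q_y = q_y_given_xz(x_m, z).log_prob(y_m)      # log q_phi(y | x, z)
    # Eq. (6): w_{m,s} = q_phi(y, z | x) / q_phi(z | x) = q_phi(y | x, z) for this factorisation.
    log_w = log_q_y
    w = log_w.exp()
    Z = w.mean()                                      # normaliser Z_m
    log_p = log_joint(x_m, y_m, z)                    # log p_theta(x_m, y_m, z_{m,s})
    elbo_term = (w / Z * (log_p - (log_q_y + log_q_z))).mean()   # first term of Eq. (8)
    log_qy_bound = log_w.mean()                       # Eq. (7): lower bound on log q_phi(y_m | x_m)
    return elbo_term + (1.0 + alpha) * log_qy_bound
```

Batching over supervised pairs, whether to stop gradients through the self-normalised weights, and the relaxation used for sampling discrete variables are implementation details that the estimator itself leaves open.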
To perform gradient ascent on the objective in Eq. (8), we map the graphical models for $p_\theta(x, y, z)$ and $q_\phi(y, z \mid x)$ onto a stochastic computation graph in which each stochastic node forms a sub-graph. Figure 1 shows this expansion for the simple VAE for MNIST digits from [16]. In this model, $y$ is a discrete variable that represents the underlying digit, our latent variable of interest, for which we have partial supervision data. An unobserved Gaussian-distributed variable $z$ captures the remainder of the latent information. This includes features such as the handwriting style and stroke thickness. In the generative model (Fig. 1 top-left), we assume a factorisation $p_\theta(x, y, z) = p_\theta(x \mid y, z)\, p(y)\, p(z)$ in which $y$ and $z$ are independent under the prior. In the recognition model (Fig. 1 bottom-left), we use a conditional dependency structure $q_\phi(y, z \mid x) = q_{\phi_z}(z \mid y, x)\, q_{\phi_y}(y \mid x)$ to disentangle the digit label $y$ from the handwriting style $z$ (Fig. 1 right).

The generative and recognition models jointly form a stochastic computation graph (Fig. 1 centre) containing a sub-graph for each stochastic variable. These can correspond to fully supervised, partially supervised and unsupervised variables. This example graph contains three types of sub-graphs, corresponding to the three possibilities for supervision and gradient estimation:

• For the fully supervised variable $x$, we compute the likelihood $p$ under the generative model, that is $p_\theta(x \mid y, z) = \mathcal{N}(x; \eta_\theta(y, z))$. Here $\eta_\theta(y, z)$ is a neural net with parameters $\theta$ that returns the parameters of a normal distribution (i.e. a mean vector and a diagonal covariance).

• For the unobserved variable $z$, we compute both the prior probability $p(z) = \mathcal{N}(z; \eta_z)$ and the conditional probability $q_\phi(z \mid x, y) = \mathcal{N}(z; \lambda_{\phi_z}(x, y))$. To sample $z$ from $q_\phi(z \mid x, y)$ we use the usual reparametrisation trick, first sampling $\epsilon \sim \mathcal{N}(0, I)$ and then setting $z = g(\epsilon, \lambda_\phi(x, y))$.

• For the partially observed variable $y$, we also compute the probabilities $p(y) = \mathrm{Discrete}(y; \eta_y)$ and $q_{\phi_y}(y \mid x) = \mathrm{Discrete}(y; \lambda_{\phi_y}(x))$. The value $y$ is treated as observed when available, and sampled otherwise. In this particular example, we sample $y$ from $q_{\phi_y}(y \mid x)$ using a Gumbel-softmax [12, 24] relaxation of the discrete distribution (a minimal sketch of this relaxation is given below).

The example in Fig. 1 illustrates a general framework for defining VAEs with arbitrary dependency structures. We begin by defining a node for each random variable. For each node we then specify a distribution type and parameter function $\eta$, which determines how the probability under the generative model depends on the other variables in the network. This function can be a constant, fully deterministic, or a neural network whose parameters are learned from the data. For each unsupervised and semi-supervised variable we must additionally specify a function $\lambda$ that returns the parameter values in the recognition model, along with a (reparametrised) sampling procedure.

Given this specification of a computation graph, we can now compute the importance sampling estimate in Eq. (8) by simply running the network forward repeatedly to obtain samples from $q_\phi(\cdot \mid \lambda)$ for all unobserved variables. We then calculate $p_\theta(x, y, z)$, $q_\phi(y \mid x)$, $q_\phi(y, z \mid x)$, and the importance weight $w$, which is the joint probability of all semi-supervised variables for which labels are available. This estimate can then be optimised with respect to the parameters $\theta$ and $\phi$ to train the autoencoder.
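The Gumbel-softmax relaxation mentioned for the partially observed variable $y$ can be sketched as follows (an illustrative implementation of the standard relaxation [12, 24], not necessarily the one used in the released library; the temperature is an arbitrary choice for the example, and the logits would come from the recognition network $\lambda_{\phi_y}(x)$):

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature=0.66):
    # Gumbel(0, 1) noise via -log(-log(U)) with U ~ Uniform(0, 1).
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    # Softmax of the perturbed, tempered logits: a differentiable relaxation of a
    # one-hot sample from Discrete(y; softmax(logits)); it sharpens as temperature -> 0.
    return F.softmax((logits + gumbel) / temperature, dim=-1)
```

Because the relaxed sample is a differentiable function of the logits and of independent noise, it plays the same role for the discrete node that the reparametrisation trick plays for the Gaussian node.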
Figure 2: (a) Visual analogies for the MNIST data, partially supervised with just 100 labels (out of 50000). We infer the style variable $z$ and then vary the label $y$. (b) Exploration in "style" space with the label $y$ held fixed and the (2D) style $z$ varied. Visual analogies for the SVHN data when (c) partially supervised with just 1000 labels, and (d) fully supervised.

We evaluate our framework along a number of different axes pertaining to its ability to learn disentangled representations through the provision of partial graphical-model structures for the latents and weak supervision. In particular, we evaluate its ability to (i) function as a classifier/regressor for particular latents under the given dataset, (ii) learn the generative model in a manner that preserves the semantics of the latents with respect to the data generated, and (iii) perform these tasks, in a flexible manner, for a variety of different models and data.

For all the experiments run, we choose architectures and parameters that are considered standard for the type and size of the respective datasets. Where images are concerned (with the exception of MNIST), we employ (de)convolutional architectures, and we employ a standard GRU recurrence in the Multi-MNIST case. For learning, we used Adam [15] with the learning rate and momentum-correction terms set to their default values. The mini-batch sizes varied from 100 to 700 depending on the dataset being used and the size of the labelled subset $\mathcal{D}^{\mathrm{sup}}$. All of the above, including further details of precise parameter values and the source code, including our PyTorch-based library for specifying arbitrary graphical models in the VAE framework, is available at https://github.com/probtorch/probtorch.

We begin with an experiment involving a simple dependency structure, in fact the very same as that in Kingma et al. [17], to validate the performance of our importance-sampled objective in the special case where the recognition network and generative model factorise as indicated in Fig. 1 (left), giving us importance weights that are constant, $w_{m,s} = q_\phi(y_m \mid x_m)$. The model is tested on its ability to classify digits and perform conditional generation on the MNIST and Google Street-View House Numbers (SVHN) datasets. As Fig. 1 (left) shows, the generative and recognition models have the "digit" label, denoted $y$, partially specified (and partially supervised), and the "style" factor, denoted $z$, assumed to be an unobserved (and unsupervised) variable.

Figures 2(a) and (c) illustrate the conditional generation capabilities of the learned model, where we show the effect of first transforming a given input (leftmost column) into the disentangled latent space and, with the style latent variable fixed, manipulating the digit through the generative model to generate data with the expected visual characteristics. Note that both these results were obtained with partial supervision: 100 (out of 50000) labelled data points in the case of MNIST and 1000 (out of 70000) labelled data points in the case of SVHN. The style latent variable $z$ was taken to be a diagonal-covariance Gaussian of 10 and 15 dimensions respectively. Figure 2(d) shows the same for SVHN with full supervision. Figure 2(b) illustrates the alternate mode of conditional generation, where the style latent, here taken to be a 2D Gaussian, is varied with the digit held fixed.

Next, we evaluate our model's ability to effectively learn a classifier from partial supervision. We compute the classification error on the label-prediction task on both datasets, and the results are reported in the table in Fig. 3. Note that there are a few minor points of difference in the setup between our method and those we compare against [17]. We always run our models directly on the data, with no pre-processing or pre-learning on the data. Thus, for MNIST, we compare against model M2 from the baseline, which does just the same. However, for SVHN, the baseline method does not report errors for the M2 model, only for the two-stage M1 + M2 model, which involves a separate feature-extraction step on the data before learning a semi-supervised classifier.
Figure 3: Left: Classification error over different labelled-set sizes and supervision rates $\rho$ for MNIST (top) and SVHN (bottom). Here, the scaling of the classification objective is held fixed at $\alpha = 50$ (MNIST) and $\alpha = 70$ (SVHN). Note that for sparsely labelled data ($M \ll N$), a modicum of over-representation ($\gamma > 1$) helps improve generalisation, with better performance on the test set. Conversely, too much over-representation leads to overfitting. Right: Classification error rates (%) for different labelled-set sizes $M$ over multiple runs, with supervision rate $\rho = \gamma M / (N + \gamma M)$, $\gamma = 1$; for example, our model obtains 9.71 for MNIST with $M = 100$ and 38.91 for SVHN with $M = 1000$. For SVHN, we compare against a multi-stage process (M1 + M2) [17], where our model only uses a single stage.
As the results indicate, our model and objective do indeed perform on par with the setup considered in Kingma et al. [17], serving as basic validation of our framework. We note, however, that from the perspective of achieving the lowest possible classification error, one could adopt any number of alternate factorisations [23] and innovations in neural-network architectures [26, 32].
Supervision rate:
As discussed in Section 2.1, we formulate our objective to provide a handle on the relative weight between the supervised and unsupervised terms. For a given unsupervised set size $N$, supervised set size $M$, and scaling term $\gamma$, the relative weight is $\rho = \gamma M / (N + \gamma M)$. Figure 3 shows an exploration of this relative-weight parameter over the MNIST and SVHN datasets and over different supervised set sizes $M$. Each line in the graph measures the classification error for a given $M$ over $\rho$, starting at $\gamma = 1$, i.e. $\rho = M/(N + M)$. In line with Kingma et al. [17], we use $\alpha = 0.1/\rho$. When the labelled data is very sparse ($M \ll N$), over-representing the labelled examples during training can help aid generalisation by improving performance on the test data. In our experiments, for the most part, choosing this factor to be $\rho = M/(N + M)$ provides good results. However, as is to be expected, over-fitting occurs when $\rho$ is increased beyond a certain point.

We next move to a more complex domain involving generative models of faces. Here, we use the "Yale B" dataset [5] as processed by Jampani et al. [11] for the results in Fig. 4. As can be seen in the graphical models for this experiment in Fig. 5, the dependency structures employed here are more complex in comparison to those from the previous experiment. We are interested in showing that our model can learn disentangled representations of identity and lighting, and we evaluate its performance on the tasks of (i) classification of person identity, and (ii) regression for lighting direction.

Note that our generative model assumes no special structure: we simply specify a model where all latent variables are independent under the prior. Previous work [11] assumed a generative model with latent variables identity $i$, lighting $l$, shading $s$, and reflectance $r$, following the relationship $(n \cdot l) \times r + \epsilon$ for the pixel data. Here, we wish to demonstrate that our generative model still learns the correct relationship over these latent variables, by virtue of the structure in the recognition model and given (partial) supervision.

Note that in the recognition model (Fig. 5), the lighting $l$ is a latent variable with continuous domain, and one that we partially supervise. Further, we encode identity $i$ as a categorical random variable, instead of constructing a pixel-wise surface-normal map (each assumed to be independent Gaussian) as is customary. This formulation allows us to address the task of predicting identity directly, instead of applying surrogate evaluation methods (e.g. nearest-neighbour classification based on inferred reflectance). Figure 4 presents both qualitative and quantitative evaluation of the framework's ability to jointly learn both the structured recognition model and the generative model parameters.
Figure 4: Left: Exploring the generative capacity of the supervised model by manipulating identity and lighting given a fixed (inferred) value of the other latent variables. Right:
Classification and regression error rates for the identity and lighting latent variables, fully supervised (identity error 1.9%), and semi-supervised (with 6 labelled example images for each of the 38 individuals, and α = 10). Classification is a direct 1-out-of-38 choice, whereas for the comparison, error is a nearest-neighbour loss based on the inferred reflectance. Regression loss is angular distance.

Finally, we conduct an experiment that extends the complexity from the prior models even further. In particular, we explore the capacity of our framework to handle models with stochastic dimensionality – having the number of latent variables itself determined by a random variable – and models that can be composed of other smaller (sub-)models. We conduct this experiment in the domain of multi-MNIST. This is an apposite choice as it satisfies both the requirements above: each image can have a varying number of individual digits, which essentially dictates that the model must learn to count, and as each image is itself composed of (scaled and translated) exemplars from the MNIST data, we can employ the MNIST model itself within the multi-MNIST model.

The model structure that we assume for the generative and recognition networks is shown in Fig. 5. We extend the models from the MNIST experiment by composing them with a stochastic sequence generator, in which the loop length $K$ is a random variable. For each loop iteration $k = 1, \ldots, K$, the generative model iteratively samples a digit $y_k$ and a style $z_k$, and uses these to generate a digit image $x_k$ in the same manner as in the earlier MNIST example. Additionally, an affine transformation is also sampled for each digit in each iteration to transform the digit images $x_k$ into a common, combined canvas that represents the final generated image $x$, using a spatial transformer network [10].

In the recognition model, we predict the number of digits $K$ from the pixels in the image. For each loop iteration $k = 1, \ldots, K$, we define a Bernoulli-distributed digit image $x_k$. When supervision is available, we compute the probability of $x_k$ from the binary cross-entropy in the same manner as in the likelihood term for the MNIST model. When no supervision is available, we deterministically set $x_k$ to the mean of the distribution. This can be seen as akin to providing bounding boxes around the constituent digits as supervision for the labelled data, which must be taken into account when learning the affine transformations that decompose a multi-MNIST image into its constituent MNIST-like images. This model design is similar to the one used in DRAW [9], recurrent VAEs [3], and AIR [4].

In the absence of a canonical multi-MNIST dataset, we created our own from the MNIST dataset by manipulating the scale and positioning of the standard digits into a combined canvas, evenly balanced across the counts (1–3) and digits. We then conducted two experiments within this domain. In the first experiment, we seek to measure how well the stochastic sequence generator learns to count on its own, with no heed paid to disentangling the latent representations for the underlying digits.

Figure 5: Generative and recognition models for the intrinsic-faces and multi-MNIST experiments.
Figure 6: Left: Example input multi-MNIST images and reconstructions. Top-Right: Decomposition of multi-MNIST images into constituent MNIST digits. Bottom-Right: Count accuracy over different supervised set sizes $M$ for a given dataset size $M + N = 82000$.
In this first experiment, the generative model presumes the availability of individual MNIST-digit images, generating combinations under sampled affine transformations. In the second experiment, we extend the above model to also incorporate the same pre-trained MNIST model from the previous section, which allows the generative model to sample MNIST-digit images, while also being able to predict the underlying digits. This also demonstrates how we can leverage the compositionality of models: when a complex model has a known simpler model as a substructure, the simpler model and its learned weights can be dropped in directly.

The count accuracy errors across different supervised set sizes, reconstructions for a random set of inputs, and the decomposition of a given set of inputs into their constituent individual digits are shown in Fig. 6. All reconstructions and image decompositions shown correspond to the nested-model configuration. We observe that not only are we able to reliably infer the counts of the digits in the given images, we are able to simultaneously reconstruct the inputs as well as their constituent parts.
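To make the compositing step of the multi-MNIST generative model concrete, the following sketch (illustrative only, not the paper's code; the canvas size and the way the affine parameters are packed are assumptions for the example) shows how a sampled scale and translation can place a single digit image onto a larger shared canvas with a spatial-transformer-style warp [10]:

```python
import torch
import torch.nn.functional as F

def place_digit_on_canvas(digit, scale, tx, ty, canvas_size=50):
    # digit: (B, 1, 28, 28); scale, tx, ty: (B,) affine parameters, one per digit instance.
    B = digit.size(0)
    # grid_sample performs an inverse warp: each canvas coordinate c is pulled back to a
    # digit coordinate (c - t) / s, so the digit appears scaled by s and shifted by t.
    theta = digit.new_zeros(B, 2, 3)
    theta[:, 0, 0] = 1.0 / scale
    theta[:, 1, 1] = 1.0 / scale
    theta[:, 0, 2] = -tx / scale
    theta[:, 1, 2] = -ty / scale
    grid = F.affine_grid(theta, (B, 1, canvas_size, canvas_size), align_corners=False)
    return F.grid_sample(digit, grid, align_corners=False)

# The combined image x could then be, for instance, the sum or pixel-wise maximum of the
# K per-digit canvases produced by K passes through this function.
```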
In this paper we introduce a framework for learning disentangled representations of data using partially-specified graphical model structures and semi-supervised learning schemes in the domain of variational autoencoders (VAEs). This is accomplished by defining hybrid generative models which incorporate both structured graphical models and unstructured random variables in the same latent space. We demonstrate the flexibility of this approach by applying it to a variety of different tasks in the visual domain, and evaluate its efficacy at learning disentangled representations in a semi-supervised manner, showing strong performance. Such partially-specified models yield recognition networks that make predictions in an interpretable and disentangled space, constrained by the structure provided by the graphical model and the weak supervision.

The framework is implemented as a PyTorch library [25], enabling the construction of stochastic computation graphs which encode the requisite structure and computation. This provides another direction to explore in the future: the extension of the stochastic computation graph framework to probabilistic programming [8, 34, 35]. Probabilistic programs go beyond the presented framework to permit more expressive models, incorporating recursive structures and higher-order functions. The combination of such frameworks with neural networks has recently been studied in Le et al. [22] and Ritchie et al. [28], indicating a promising avenue for further exploration.
Acknowledgements
This work was supported by the EPSRC, ERC grant ERC-2012-AdG 321162-HELIOS, EPSRC grant Seebibyte EP/M013774/1, and EPSRC/MURI grant EP/N019474/1. BP & FW were supported by The Alan Turing Institute under the EPSRC grant EP/N510129/1. FW & NDG were supported under DARPA PPAML through the U.S. AFRL under Cooperative Agreement FA8750-14-2-0006. FW was additionally supported by Intel and DARPA D3M, under Cooperative Agreement FA8750-17-2-0093.

References

[1] Jonathan Berant, Vivek Srikumar, Pei-Chun Chen, Abby Vander Linden, Brittany Harding, Brad Huang, Peter Clark, and Christopher D Manning. Modeling biological processes for reading comprehension. In EMNLP, 2014.
[2] Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.
[3] Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. In Advances in Neural Information Processing Systems, pages 2980–2988, 2015.
[4] S. M. Ali Eslami, Nicolas Heess, Theophane Weber, Yuval Tassa, Koray Kavukcuoglu, and Geoffrey E. Hinton. Attend, infer, repeat: Fast scene understanding with generative models. arXiv preprint arXiv:1603.08575, 2016.
[5] A. S. Georghiades, P. N. Belhumeur, and D. J. Kriegman. From few to many: Illumination cone models for face recognition under variable lighting and pose. IEEE Trans. Pattern Anal. Mach. Intelligence, 23(6):643–660, 2001.
[6] Samuel Gershman and Noah Goodman. Amortized inference in probabilistic reasoning. In CogSci, 2014.
[7] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems, pages 2672–2680, 2014.
[8] N. D. Goodman, V. K. Mansinghka, D. Roy, K. Bonawitz, and J. B. Tenenbaum. Church: A language for generative models. In Uncertainty in Artificial Intelligence, pages 220–229, 2008.
[9] Karol Gregor, Ivo Danihelka, Alex Graves, Danilo Rezende, and Daan Wierstra. DRAW: A recurrent neural network for image generation. In Proceedings of the 32nd International Conference on Machine Learning (ICML-15), pages 1462–1471, 2015.
[10] Max Jaderberg, Karen Simonyan, Andrew Zisserman, et al. Spatial transformer networks. In Advances in Neural Information Processing Systems, pages 2017–2025, 2015.
[11] Varun Jampani, S. M. Ali Eslami, Daniel Tarlow, Pushmeet Kohli, and John Winn. Consensus message passing for layered graphical models. In International Conference on Artificial Intelligence and Statistics, pages 425–433, 2015.
[12] E. Jang, S. Gu, and B. Poole. Categorical reparameterization with Gumbel-softmax. arXiv preprint arXiv:1611.01144, 2016.
[13] Matthew Johnson, David K Duvenaud, Alex Wiltschko, Ryan P Adams, and Sandeep R Datta. Composing graphical models with neural networks for structured representations and fast inference. In Advances in Neural Information Processing Systems, pages 2946–2954, 2016.
[14] Matthew J. Johnson, David K. Duvenaud, Alex B. Wiltschko, Sandeep R. Datta, and Ryan P. Adams. Composing graphical models with neural networks for structured representations and fast inference. In Advances in Neural Information Processing Systems, 2016.
[15] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. CoRR, abs/1412.6980, 2014. URL http://arxiv.org/abs/1412.6980.
[16] Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. In Proceedings of the 2nd International Conference on Learning Representations, 2014.
[17] Diederik P Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, pages 3581–3589, 2014.
[18] Daphne Koller and Nir Friedman. Probabilistic Graphical Models: Principles and Techniques. MIT Press, 2009.
[19] Tejas D Kulkarni, Pushmeet Kohli, Joshua B Tenenbaum, and Vikash Mansinghka. Picture: A probabilistic programming language for scene perception. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4390–4399, 2015.
[20] Tejas D Kulkarni, William F Whitney, Pushmeet Kohli, and Josh Tenenbaum. Deep convolutional inverse graphics network. In Advances in Neural Information Processing Systems, pages 2530–2538, 2015.
[21] Steffen L Lauritzen and David J Spiegelhalter. Local computations with probabilities on graphical structures and their application to expert systems. Journal of the Royal Statistical Society. Series B (Methodological), pages 157–224, 1988.
[22] Tuan Anh Le, Atilim Gunes Baydin, and Frank Wood. Inference compilation and universal probabilistic programming. arXiv preprint arXiv:1610.09900, 2016.
[23] L. Maaløe, C. K. Sønderby, S. K. Sønderby, and O. Winther. Auxiliary deep generative models. arXiv preprint arXiv:1602.05473, 2016.
[24] C. J. Maddison, A. Mnih, and Y. W. Teh. The concrete distribution: A continuous relaxation of discrete random variables. arXiv preprint arXiv:1611.00712, 2016.
[25] PyTorch. http://pytorch.org/, 2017. Accessed: 2017-11-4.
[26] A. Rasmus, H. Valpola, M. Honkala, M. Berglund, and T. Raiko. Semi-supervised learning with ladder networks. In Advances in Neural Information Processing Systems, pages 3532–3540, 2015.
[27] Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of The 31st International Conference on Machine Learning, pages 1278–1286, 2014.
[28] Daniel Ritchie, Paul Horsfall, and Noah D Goodman. Deep amortized inference for probabilistic programs. arXiv preprint arXiv:1610.05735, 2016.
[29] John Schulman, Nicolas Heess, Theophane Weber, and Pieter Abbeel. Gradient estimation using stochastic computation graphs. In Advances in Neural Information Processing Systems, pages 3510–3522, 2015.
[30] N. Siddharth, A. Barbu, and J. M. Siskind. Seeing what you're told: Sentence-guided activity recognition in video. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 732–39, June 2014.
[31] Kihyuk Sohn, Honglak Lee, and Xinchen Yan. Learning structured output representation using deep conditional generative models. In Advances in Neural Information Processing Systems, pages 3465–3473, 2015.
[32] C. K. Sønderby, T. Raiko, L. Maaløe, S. K. Sønderby, and O. Winther. Ladder variational autoencoders. In Advances in Neural Information Processing Systems, 2016.
[33] Andreas Stuhlmüller, Jacob Taylor, and Noah Goodman. Learning stochastic inverses. In Advances in Neural Information Processing Systems, pages 3048–3056, 2013.
[34] David Wingate, Andreas Stuhlmueller, and Noah D Goodman. Lightweight implementations of probabilistic programming languages via transformational compilation. In International Conference on Artificial Intelligence and Statistics, pages 770–778, 2011.
[35] Frank Wood, Jan Willem van de Meent, and Vikash Mansinghka. A new approach to probabilistic programming inference. In International Conference on Artificial Intelligence and Statistics, 2014.