A Tutorial on the Mathematical Model of Single Cell Variational Inference
Songting Shi
Department of Scientific and Engineering Computing, School of Mathematical Sciences, Peking University, Beijing 300071, P. R. China
[email protected]
January 5, 2021

Abstract
As a large amount of sequencing data has accumulated over the past decades, and is still accumulating, we need to handle more and more sequencing data. With the fast development of computing technologies, we can now handle a large amount of data in a reasonable time using neural-network-based models. This tutorial introduces the mathematical model of single cell variational inference (scVI), which uses a variational auto-encoder (built on neural networks) to learn the distribution of the data and gain insights from it. It is written for beginners in a simple and intuitive way, with many deduction details, to encourage more researchers to enter this field.

As computer technology evolves rapidly, we can tackle more and more complex problems by finding a suitable function, taking millions of parameters, that models the key part of the problem. The deep neural network (Lecun et al. (2015), Goodfellow et al. (2016)) is the top representative of such functions; it has achieved success in many fields, such as natural language processing, image processing, games and so on, and it is now also entering computational biology, e.g. AlphaFold (Senior et al. (2020)). This paper introduces the single cell variational inference model scVI (Romain et al. (2018)), which uses a variational auto-encoder equipped with deep neural networks to tackle the data integration problem and the downstream analysis of scRNA-seq data.

The scVI model uses a variational auto-encoder to extract information from the gene expression data. The variational auto-encoder consists of a variational encoder and a probability decoder. The encoder encodes the expression data into a continuous, low-dimensional hidden space in a compact form, such that cells of the same cell type lie close to each other even when they come from different batches, while different cell types separate. The decoder decodes the "code" in the common low-dimensional space back into the original space; it is designed to separate out the "dropout" effects of sequencing and recover clean expression data, which can be used for imputation and for finding differentially expressed genes. The "code" of a cell in the low-dimensional space can be used for clustering, annotation, and visualization.

The basic idea of the variational auto-encoder is to learn the distribution of the gene expression data by assuming that the expression data are generated by a two-stage process: the first stage samples a code (which may be viewed as the identifier of the cell) from a prior distribution on the low-dimensional space, and the second stage samples a gene expression vector from a conditional distribution given the code of the cell. It uses a variational encoder to encode the gene expression of a cell into its code, and a probability decoder to decode the code back into the gene expression of the cell. The parameters of the probability distributions of the decoder and encoder are output by deep neural networks. By approximating the log-likelihood with the variational lower bound, which is calculated from the parameters of the encoder and decoder, we can maximize the variational lower bound to approach the maximum log-likelihood of the observed data, which yields approximately the best probability decoder and encoder. Having the decoder and encoder, we can do downstream analysis on the data, e.g. clustering, annotation, visualization, imputation, differentially expressed genes and so on.

To build an intuitive understanding of the scVI model, we introduce the auto-encoder in Section 1 and the variational auto-encoder in Section 2. If you are familiar with the auto-encoder and the variational auto-encoder, please go directly to Section 3 for the mathematical model of scVI.
To understand the scVI model, we should first understand how the variational auto-encoder works. To make the variational auto-encoder easier to understand, we first introduce the auto-encoder (Bengio (2009)), which is a similar but simpler model. Let us work through a simple example to get the ideas of the auto-encoder. Suppose that the hidden code $z \sim \mathrm{Normal}(0, 1)$ and the data is generated by $x = g(z) = [z, 2z]$. Formally, we have

$$x \sim \mathrm{Normal}\left( \begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1 & 2 \\ 2 & 4 \end{bmatrix} \right), \qquad (1)$$

where the variance matrix of $x$,

$$\begin{bmatrix} 1 & 2 \\ 2 & 4 \end{bmatrix}, \qquad (2)$$

is not invertible, so $x$ follows a degenerate normal distribution. Suppose that we observe a set of samples of $x$,

$$X := [x_1, x_2, \dots, x_N], \qquad (3)$$

drawn from the above generation process. Now suppose that we only see this set of examples and do not know the underlying generation mechanism. We want to learn an encoding (contraction) function $f(\cdot)$ such that we can represent $x$ in a compact form by $z = f(x)$, and also a decoding function $g(\cdot)$ such that we can recover $x$ from its code $z$, i.e. $x = g(z)$. How can we do this?

By a simple linear regression, we can easily find the relation $x_2 = 2x_1$. This means that $x$ lies on a one-dimensional manifold: we can take the contraction function $z = f(x) = x_1$ and recover $x$ from the code $z$ by the generation function $x = g(z) = [z, 2z]$. For this simple example, a simple guess solves the problem.

Can we turn the above process into a general method that solves this kind of problem for more complicated data? Yes! The auto-encoder is one such method: it is a framework to learn the generation function $g(\cdot)$ and the encoding (contraction) function $f(\cdot)$. We now apply the auto-encoder to this simple example to gain the basic ideas. Suppose that $g$ and $f$ are linear functions and that $x$ lies on a one-dimensional manifold; we can parametrize $f(x, a_1, a_2) = a_1 x_1 + a_2 x_2$ and $g(z) = [b_1 z, b_2 z]$. The auto-encoder then outputs $\tilde{x} := g(f(x)) = [b_1(a_1 x_1 + a_2 x_2),\; b_2(a_1 x_1 + a_2 x_2)]$. The objective function of the auto-encoder is given by

$$\frac{1}{N}\sum_{n=1}^{N}\|\tilde{x}_n - x_n\|^2 \qquad (4)$$

$$L(X, a_1, a_2, b_1, b_2) := \frac{1}{N}\sum_{n=1}^{N}\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big)^2 + \big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big)^2 \qquad (5)$$

We then use SGD or one of its variants to train the model on the training data by minimizing the objective function (5).

Question: can we learn the optimal solution $a_1 = 1, a_2 = 0, b_1 = 1, b_2 = 2$, or some other reasonable solution? Yes!

Since this objective is differentiable, we can set its first-order derivatives to zero to obtain the stationarity conditions:

$$\begin{aligned}
\frac{\partial L}{\partial a_1} &= \frac{2}{N}\sum_{n=1}^{N} -b_1 x_{n,1}\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big) - b_2 x_{n,1}\big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big) = 0 \\
\frac{\partial L}{\partial a_2} &= \frac{2}{N}\sum_{n=1}^{N} -b_1 x_{n,2}\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big) - b_2 x_{n,2}\big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big) = 0 \\
\frac{\partial L}{\partial b_1} &= \frac{2}{N}\sum_{n=1}^{N} -(a_1 x_{n,1} + a_2 x_{n,2})\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big) = 0 \\
\frac{\partial L}{\partial b_2} &= \frac{2}{N}\sum_{n=1}^{N} -(a_1 x_{n,1} + a_2 x_{n,2})\big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big) = 0
\end{aligned} \qquad (6)$$

When we fix $b_1, b_2$, we get the following linear system for $a_1, a_2$:

$$\begin{bmatrix} \sum_n (b_1^2 + b_2^2) x_{n,1}^2 & \sum_n (b_1^2 + b_2^2) x_{n,1}x_{n,2} \\ \sum_n (b_1^2 + b_2^2) x_{n,1}x_{n,2} & \sum_n (b_1^2 + b_2^2) x_{n,2}^2 \end{bmatrix}\begin{bmatrix} a_1 \\ a_2 \end{bmatrix} = \begin{bmatrix} \sum_n b_1 x_{n,1}^2 + b_2 x_{n,1}x_{n,2} \\ \sum_n b_1 x_{n,1}x_{n,2} + b_2 x_{n,2}^2 \end{bmatrix} \qquad (7)$$

Noting that $x_{n,2} = 2x_{n,1}$, and dividing both sides by $\sum_n x_{n,1}^2$, we can simplify it to

$$\begin{bmatrix} (b_1^2 + b_2^2) & 2(b_1^2 + b_2^2) \\ 2(b_1^2 + b_2^2) & 4(b_1^2 + b_2^2) \end{bmatrix}\begin{bmatrix} a_1 \\ a_2 \end{bmatrix} = \begin{bmatrix} b_1 + 2b_2 \\ 2b_1 + 4b_2 \end{bmatrix} \qquad (8)$$

Note that in the above equation the second row is twice the first row, so we simply get

$$(b_1^2 + b_2^2)(a_1 + 2a_2) = b_1 + 2b_2. \qquad (9)$$

When we fix $a_1, a_2$, we get the following solution for $b_1, b_2$:

$$b_1 = \frac{\sum_n x_{n,1}(a_1 x_{n,1} + a_2 x_{n,2})}{\sum_n (a_1 x_{n,1} + a_2 x_{n,2})^2}, \qquad b_2 = \frac{\sum_n x_{n,2}(a_1 x_{n,1} + a_2 x_{n,2})}{\sum_n (a_1 x_{n,1} + a_2 x_{n,2})^2}. \qquad (10)$$

Noting that $x_{n,2} = 2x_{n,1}$, this simplifies to

$$b_1 = \frac{a_1 + 2a_2}{(a_1 + 2a_2)^2}, \qquad b_2 = \frac{2(a_1 + 2a_2)}{(a_1 + 2a_2)^2} = 2b_1. \qquad (11)$$

Bringing them together, we get the following necessary condition for a stationary point:

$$(b_1^2 + b_2^2)(a_1 + 2a_2) = b_1 + 2b_2, \qquad b_1 = \frac{a_1 + 2a_2}{(a_1 + 2a_2)^2}, \qquad b_2 = \frac{2(a_1 + 2a_2)}{(a_1 + 2a_2)^2} = 2b_1. \qquad (12)$$

Substituting $b_2 = 2b_1$, it can be simplified to

$$5 b_1^2 (a_1 + 2a_2) = 5 b_1, \qquad b_1 = \frac{a_1 + 2a_2}{(a_1 + 2a_2)^2}, \qquad b_2 = 2b_1. \qquad (13)$$

When we restrict $a_1 + 2a_2 \neq 0$ and $b_1 \neq 0$, we get

$$b_1(a_1 + 2a_2) = 1, \qquad b_2 = 2b_1. \qquad (14)$$

Obviously, $a_1 = 1, a_2 = 0, b_1 = 1, b_2 = 2$ satisfies the stationarity condition (14).
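To make the toy example concrete, here is a minimal NumPy sketch (not part of the original text) that fits the linear auto-encoder above by gradient descent on simulated data; the learning rate, initialization and number of steps are arbitrary illustrative choices. At convergence the parameters should approximately satisfy the stationarity condition (14).

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulate data from the generative process: z ~ N(0, 1), x = [z, 2z].
N = 1000
z_true = rng.normal(size=N)
X = np.stack([z_true, 2.0 * z_true], axis=1)          # shape (N, 2)

# Linear auto-encoder: z = a1*x1 + a2*x2,  x_hat = [b1*z, b2*z].
a = np.array([0.3, 0.1])      # encoder weights (a1, a2)
b = np.array([0.5, 0.5])      # decoder weights (b1, b2)
lr = 0.01

for step in range(2000):
    z = X @ a                                  # codes
    x_hat = np.outer(z, b)                     # reconstructions
    err = x_hat - X
    # Gradients of the loss (5) with respect to encoder and decoder weights.
    grad_a = 2.0 * np.mean((err @ b)[:, None] * X, axis=0)
    grad_b = 2.0 * np.mean(err * z[:, None], axis=0)
    a -= lr * grad_a
    b -= lr * grad_b

loss = np.mean(np.sum(err ** 2, axis=1))
print("a1 + 2*a2 =", a[0] + 2 * a[1])          # at a stationary point this equals 1/b1 by (14)
print("b2 / b1   =", b[1] / b[0])              # expected to be close to 2
print("final reconstruction loss =", loss)
```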
Also note that there are infinitely many solutions of equation (14), e.g. $a_1 = 1, a_2 = 0.5, b_1 = 1/2, b_2 = 1$; $a_1 = 2, a_2 = 0.5, b_1 = 1/3, b_2 = 2/3$; and so on. Which solution is reached depends on which algorithm is used. As in the variational auto-encoder, if we restrict $z$ to approach the standard normal distribution, i.e. require $z$ to have zero mean and unit variance, then we get $a_1 + 2a_2 = 1$, $b_1 = 1$, $b_2 = 2$. Note that even in this case there is still freedom in $a_1 + 2a_2 = 1$, but it does not influence the output $z = a_1 x_1 + a_2 x_2 = (a_1 + 2a_2)x_1 = x_1$.

When the optimal parameters of the function have such freedom, the optimization algorithm usually becomes unstable, since it can jump between the many optima. If the function has more parameters than are needed to fit the true solution, it will overfit the training data: the function learns the noise in the training data and deviates from the true solution. A general principle to avoid this is to add a penalty to the objective function, and the penalty can be the l2 or l1 norm of the parameters of the function. We now add the l2 norm penalty on the parameters $a_1, a_2, b_1, b_2$ with multiplier $\lambda$, which gives the following loss function:

$$\frac{1}{N}\sum_{n=1}^{N}\|\tilde{x}_n - x_n\|^2 + \lambda(a_1^2 + a_2^2 + b_1^2 + b_2^2) \qquad (15)$$

$$L(X, a_1, a_2, b_1, b_2, \lambda) := \frac{1}{N}\sum_{n=1}^{N}\Big\{\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big)^2 + \big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big)^2\Big\} + \lambda(a_1^2 + a_2^2 + b_1^2 + b_2^2) \qquad (16)$$

We carry out the same analysis as above. First, we use the first-order condition to get the equations that the parameters must obey at a local minimum of the objective function:

$$\begin{aligned}
\frac{\partial L}{\partial a_1} &= 2\Big\{\frac{1}{N}\sum_{n=1}^{N} -b_1 x_{n,1}\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big) - b_2 x_{n,1}\big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big)\Big\} + 2\lambda a_1 = 0 \\
\frac{\partial L}{\partial a_2} &= 2\Big\{\frac{1}{N}\sum_{n=1}^{N} -b_1 x_{n,2}\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big) - b_2 x_{n,2}\big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big)\Big\} + 2\lambda a_2 = 0 \\
\frac{\partial L}{\partial b_1} &= 2\Big\{\frac{1}{N}\sum_{n=1}^{N} -(a_1 x_{n,1} + a_2 x_{n,2})\big(x_{n,1} - b_1(a_1 x_{n,1} + a_2 x_{n,2})\big)\Big\} + 2\lambda b_1 = 0 \\
\frac{\partial L}{\partial b_2} &= 2\Big\{\frac{1}{N}\sum_{n=1}^{N} -(a_1 x_{n,1} + a_2 x_{n,2})\big(x_{n,2} - b_2(a_1 x_{n,1} + a_2 x_{n,2})\big)\Big\} + 2\lambda b_2 = 0
\end{aligned} \qquad (17)$$

When we fix $b_1, b_2$, we get the following linear system for $a_1, a_2$:

$$\begin{bmatrix} \frac{1}{N}\sum_n (b_1^2 + b_2^2)x_{n,1}^2 + \lambda & \frac{1}{N}\sum_n (b_1^2 + b_2^2)x_{n,1}x_{n,2} \\ \frac{1}{N}\sum_n (b_1^2 + b_2^2)x_{n,1}x_{n,2} & \frac{1}{N}\sum_n (b_1^2 + b_2^2)x_{n,2}^2 + \lambda \end{bmatrix}\begin{bmatrix} a_1 \\ a_2 \end{bmatrix} = \begin{bmatrix} \frac{1}{N}\sum_n b_1 x_{n,1}^2 + b_2 x_{n,1}x_{n,2} \\ \frac{1}{N}\sum_n b_1 x_{n,1}x_{n,2} + b_2 x_{n,2}^2 \end{bmatrix} \qquad (18)$$

Noting that $x_{n,2} = 2x_{n,1}$, and denoting $\gamma := \frac{1}{N}\sum_n x_{n,1}^2$, $\delta := b_1^2 + b_2^2$, we can simplify it to

$$\begin{bmatrix} \gamma\delta + \lambda & 2\gamma\delta \\ 2\gamma\delta & 4\gamma\delta + \lambda \end{bmatrix}\begin{bmatrix} a_1 \\ a_2 \end{bmatrix} = \begin{bmatrix} (b_1 + 2b_2)\gamma \\ 2(b_1 + 2b_2)\gamma \end{bmatrix} \qquad (19)$$

We now assume $\delta := b_1^2 + b_2^2 > 0$ and $\lambda > 0$, which usually holds, so that the coefficient matrix is invertible. We have the following solution:

$$\begin{bmatrix} a_1 \\ a_2 \end{bmatrix} = \frac{(b_1 + 2b_2)\lambda}{(5\gamma\delta + \lambda)\lambda}\begin{bmatrix} \gamma \\ 2\gamma \end{bmatrix} = \frac{b_1 + 2b_2}{5\gamma\delta + \lambda}\begin{bmatrix} \gamma \\ 2\gamma \end{bmatrix} \qquad (20)$$

Note that if $\gamma$ is the sample estimate of $\mathbb{E}_{x_1 \sim N(0,1)} x_1^2 = 1$, and if $b_1 = 1$, $b_2 = 2$ and $\lambda \approx 0$, then we have

$$\begin{bmatrix} a_1 \\ a_2 \end{bmatrix} \approx \begin{bmatrix} 1/5 \\ 2/5 \end{bmatrix}, \qquad (21)$$

which results in $a_1 + 2a_2 \approx 1$; this is what we needed.
Note that when we add the l2 norm, we shrink the optimum from a line in the original problem down to a point; the reason is that the l2 norm adds local convexity to the loss landscape.

When we fix $a_1, a_2$, we get the following solution for $b_1, b_2$:

$$b_1 = \frac{\frac{1}{N}\sum_n x_{n,1}(a_1 x_{n,1} + a_2 x_{n,2})}{\frac{1}{N}\sum_n (a_1 x_{n,1} + a_2 x_{n,2})^2 + \lambda}, \qquad b_2 = \frac{\frac{1}{N}\sum_n x_{n,2}(a_1 x_{n,1} + a_2 x_{n,2})}{\frac{1}{N}\sum_n (a_1 x_{n,1} + a_2 x_{n,2})^2 + \lambda} \qquad (22)$$

Noting that $x_{n,2} = 2x_{n,1}$, this simplifies to

$$b_1 = \frac{(a_1 + 2a_2)\gamma}{(a_1 + 2a_2)^2\gamma + \lambda}, \qquad b_2 = \frac{2(a_1 + 2a_2)\gamma}{(a_1 + 2a_2)^2\gamma + \lambda} = 2b_1 \qquad (23)$$

Note that $\gamma \approx 1$, and if $\lambda \approx 0$ and $a_1 \approx 1/5$, $a_2 \approx 2/5$, then we have $b_1 \approx 1$, $b_2 \approx 2$.

Bringing them together, we get the following necessary condition for a stationary point:

$$a_1 = \frac{(b_1 + 2b_2)\gamma}{5(b_1^2 + b_2^2)\gamma + \lambda}, \qquad a_2 = 2a_1, \qquad b_1 = \frac{(a_1 + 2a_2)\gamma}{(a_1 + 2a_2)^2\gamma + \lambda}, \qquad b_2 = 2b_1 \qquad (24)$$
Under the condition $\lambda > 0$, we can solve the above equations to get the solutions

$$a_1 = \frac{\sqrt{5\gamma - \lambda}}{5\sqrt{\gamma}}, \qquad a_2 = 2a_1, \qquad b_1 = \frac{\sqrt{5\gamma - \lambda}}{5\sqrt{\gamma}}, \qquad b_2 = 2b_1 \qquad (25)$$

or

$$a_1 = -\frac{\sqrt{5\gamma - \lambda}}{5\sqrt{\gamma}}, \qquad a_2 = 2a_1, \qquad b_1 = -\frac{\sqrt{5\gamma - \lambda}}{5\sqrt{\gamma}}, \qquad b_2 = 2b_1 \qquad (26)$$

Now we focus on the positive solution. When $\lambda = 0$, $\gamma = 1$, we have

$$a_1 = \frac{1}{\sqrt{5}}, \qquad a_2 = \frac{2}{\sqrt{5}}, \qquad b_1 = \frac{1}{\sqrt{5}}, \qquad b_2 = \frac{2}{\sqrt{5}} \qquad (27)$$

For $x = [1, 2]$, we have $z = f(x) = a_1 x_1 + a_2 x_2 = (a_1 + 2a_2)x_1 = \sqrt{5}\,x_1 = \sqrt{5}$, and $g(z) = [b_1 z, b_2 z] = [\frac{1}{\sqrt{5}}z, \frac{2}{\sqrt{5}}z] = [1, 2] = x$. This verifies the correctness of the solution. But in this case we have $z = \sqrt{5}\,x_1 \sim N(0, 5)$. In general, if we only have $x = [x_1, 2x_1]$, then $g(f(x)) = [b_1(a_1 x_1 + a_2 x_2), b_2(a_1 x_1 + a_2 x_2)] = [\frac{5\gamma - \lambda}{5\gamma}x_1, \frac{2(5\gamma - \lambda)}{5\gamma}x_1]$, and $z = a_1 x_1 + a_2 x_2 = \frac{\sqrt{5\gamma - \lambda}}{\sqrt{\gamma}}x_1 \sim N(0, \frac{5\gamma - \lambda}{\gamma})$. In this case, when $\gamma \approx 1$ and $\lambda \approx 0$, we recover $x$ correctly.

Under the l2 norm penalty, we reduce the infinitely many solutions of the original encoder and decoder to two solutions, which makes the algorithm more stable; if we choose a small $\lambda$, the optimum of the objective with the l2 penalty approximates one of the optima of the original problem. What happens if we apply the l1 penalty to the original objective? We leave that exploration to you.

Note that if we want to constrain the distribution of $z$ to be a standard normal distribution, we meet an obstacle, since the distribution of $z = f(x)$ depends on the distribution of $x$: if we know that the distribution of $x$ is $p(x)$, then the distribution of $z$ is $p(f^{-1}(z))\,|\frac{\partial f^{-1}(z)}{\partial z}|$ when $f(\cdot)$ is invertible and the determinant $|\frac{\partial f^{-1}(z)}{\partial z}|$ is nonzero almost surely. But we do not know the probability distribution of $x$; and even if we knew $p(x)$, the density $p(f^{-1}(z))\,|\frac{\partial f^{-1}(z)}{\partial z}|$ is hard to compute when $f(\cdot)$ is complex, so we cannot use the KL divergence between the distribution of $z$ and the normal distribution as a penalty. This yields the need for the variational auto-encoder. Before we tell that story, we first generalize the above simple formulation of the auto-encoder to the general auto-encoder.

In the general form of the auto-encoder, such as used in image processing, it consists of an encoder function $z = f(x, \theta_f)$ and a decoder function $x = g(z, \theta_g)$, represented by neural networks with parameters $\theta_f$ and $\theta_g$, respectively. A general form of the neural network can be written as $f(x, \theta_f) = \sigma_m(A_m(\sigma_{m-1}(A_{m-1}(\cdots\sigma_2(A_2\,\sigma_1(A_1 x + b_1) + b_2)\cdots) + b_{m-1})) + b_m)$, where $\sigma_i(\cdot)$ is an element-wise nonlinear activation function (e.g. sigmoid, ReLU), $A_i \in \mathbb{R}^{d_i \times d_{i-1}}$ is the linear projection matrix and $b_i \in \mathbb{R}^{d_i \times 1}$ is the intercept term; it has $m - 1$ hidden layers, the final $m$-th layer is the output layer, and the parameters are $\theta_f := \{A_1, \dots, A_m, b_1, \dots, b_m\}$. The loss function is given by

$$\frac{1}{N}\sum_{n=1}^{N}\|g(f(x_n, \theta_f), \theta_g) - x_n\|^2 + \lambda(\|\theta_f\|^2 + \|\theta_g\|^2) \qquad (28)$$

It is optimized by the SGD algorithm or its variants; these methods essentially only need the gradient of the loss function on a mini-batch of samples.
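As an illustration of the general auto-encoder of equation (28), here is a minimal PyTorch sketch (layer sizes, activation and hyperparameters are arbitrary choices, not from the text); the weight_decay option plays the role of the l2 penalty $\lambda$ on the network weights.

```python
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Encoder z = f(x, theta_f) and decoder x_hat = g(z, theta_g), both small MLPs."""
    def __init__(self, d_in=2, d_hidden=16, d_code=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU(),
                                     nn.Linear(d_hidden, d_code))
        self.decoder = nn.Sequential(nn.Linear(d_code, d_hidden), nn.ReLU(),
                                     nn.Linear(d_hidden, d_in))

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z), z

# Toy data from the running example: x = [z, 2z], z ~ N(0, 1).
z_true = torch.randn(1024, 1)
x = torch.cat([z_true, 2 * z_true], dim=1)

model = AutoEncoder()
# weight_decay adds an l2-style penalty on the weights, in the spirit of lambda in (28).
optim = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4)

for step in range(500):
    x_hat, z = model(x)
    loss = ((x_hat - x) ** 2).sum(dim=1).mean()   # reconstruction error of (28)
    optim.zero_grad()
    loss.backward()
    optim.step()

print("reconstruction loss:", loss.item())
```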
Fig 1. The type of directed graphical model under consideration. Solid lines denote the generative model $p_\theta(z)p_\theta(x|z)$; dashed lines denote the variational approximation $q_\phi(z|x)$ to the intractable posterior $p_\theta(z|x)$. The variational parameters $\phi$ are learned jointly with the generative model parameters $\theta$.

Now we begin to tell the general story of the variational auto-encoder (Kingma and Welling (2014), Doersch (2016)) with general notation. After that, we introduce the scVI model, which is a variational auto-encoder designed for scRNA-seq data.

We use the same symbols as in Kingma and Welling (2014) to make it easier to understand. To tackle the uncomputable probability distribution of $z$, the variational auto-encoder assumes that each data point $x$ comes from a hidden continuous variable $z$: $z$ is generated from the probability distribution $p_{\theta^\star}(z)$, and then $x$ comes from the conditional distribution $p_{\theta^\star}(x|z)$. This is represented in Fig 1 (Kingma and Welling (2014)) by the solid arrows.

The probability distribution of $x$ is given by $p_\theta(x) = \int_z p_\theta(z)p_\theta(x|z)\,dz$. We hope to find computable distributions $p_\theta(x|z)$, $p_\theta(z|x)$, $p_\theta(z)$ that concisely represent the information in the data points $\{x^{(i)}, i = 1, 2, \dots, N\}$, so that we can use these probability distributions for the downstream analysis. As we know from the Bayesian approach, we can use a parametric probability class to represent the distributions $p_\theta(z)$, $p_\theta(x|z)$, but the marginal distribution $p_\theta(x)$ is hard to obtain in general, and so is the conditional distribution $p_\theta(z|x) = \frac{p_\theta(z)p_\theta(x|z)}{p_\theta(x)}$. Variational inference tackles this problem by using a computable distribution $q_\phi(z|x)$, from the same distribution class as $p_\theta(z)$, to approximate the posterior distribution $p_\theta(z|x)$; this is represented by the dashed lines in Fig 1. To achieve this goal, we need a computable algorithm that extracts information from the sample points into the parametrized distributions $p_\theta(z)$, $p_\theta(x|z)$, $q_\phi(z|x)$. This can be done by taking the maximum likelihood method and making an approximation, i.e. using the variational lower bound. We now give the fundamental deduction of the variational lower bound. Firstly, in the classical maximum likelihood method, we seek the optimal $\theta^\star$ that maximizes the log-likelihood $\log p_\theta(x^{(1)}, \dots, x^{(N)}) = \sum_{i=1}^{N}\log p_\theta(x^{(i)})$. The variational lower bound on the marginal likelihood of data point $i$ is defined by

$$\mathcal{L}(\theta, \phi, x^{(i)}) := \log p_\theta(x^{(i)}) - D_{KL}\big(q_\phi(z|x^{(i)})\,\|\,p_\theta(z|x^{(i)})\big) \qquad (29)$$

Here $D_{KL}(p(x)\,\|\,q(x)) := \int_x p(x)\log\frac{p(x)}{q(x)}\,dx$ is the KL divergence between two distributions $p(x), q(x)$, which is nonnegative. The second term on the right-hand side measures the divergence of the approximate posterior from the true posterior. Since it is non-negative, $\mathcal{L}$ is a lower bound on $\log p_\theta(x^{(i)})$. We can rewrite the variational lower bound in terms of the known quantities $p_\theta(z)$, $p_\theta(x|z)$, $q_\phi(z|x)$.
$$\begin{aligned}
\mathcal{L}(\theta, \phi, x^{(i)}) &= \log p_\theta(x^{(i)}) - \mathbb{E}_{q_\phi(z|x^{(i)})}\log\frac{q_\phi(z|x^{(i)})}{p_\theta(z|x^{(i)})} \\
&= \mathbb{E}_{q_\phi(z|x^{(i)})}\Big[\log p_\theta(x^{(i)}) - \log\frac{q_\phi(z|x^{(i)})}{p_\theta(z|x^{(i)})}\Big] \\
&= \mathbb{E}_{q_\phi(z|x^{(i)})}\log\frac{p_\theta(x^{(i)})p_\theta(z|x^{(i)})}{q_\phi(z|x^{(i)})} \\
&= \mathbb{E}_{q_\phi(z|x^{(i)})}\log\frac{p_\theta(z, x^{(i)})}{q_\phi(z|x^{(i)})} \\
&= \mathbb{E}_{q_\phi(z|x^{(i)})}\log\frac{p_\theta(z)p_\theta(x^{(i)}|z)}{q_\phi(z|x^{(i)})} \\
&= -\mathbb{E}_{q_\phi(z|x^{(i)})}\log\frac{q_\phi(z|x^{(i)})}{p_\theta(z)} + \mathbb{E}_{q_\phi(z|x^{(i)})}\log p_\theta(x^{(i)}|z) \\
&= -D_{KL}\big(q_\phi(z|x^{(i)})\,\|\,p_\theta(z)\big) + \mathbb{E}_{q_\phi(z|x^{(i)})}\log p_\theta(x^{(i)}|z)
\end{aligned} \qquad (30)$$

So we get the classical representation of the variational lower bound:

$$\mathcal{L}(\theta, \phi, x^{(i)}) = -D_{KL}\big(q_\phi(z|x^{(i)})\,\|\,p_\theta(z)\big) + \mathbb{E}_{q_\phi(z|x^{(i)})}\log p_\theta(x^{(i)}|z) \qquad (31)$$

The first term on the right-hand side is the KL divergence between the approximate posterior $q_\phi(z|x^{(i)})$ and the prior distribution $p_\theta(z)$ of the hidden continuous variable $z$. When $q_\phi(z|x^{(i)}) = p_\theta(z|x^{(i)})$, we have a tight bound:

$$\log p_\theta(x^{(i)}) = \mathcal{L}(\theta, \phi, x^{(i)}) = -D_{KL}\big(p_\theta(z|x^{(i)})\,\|\,p_\theta(z)\big) + \mathbb{E}_{p_\theta(z|x^{(i)})}\log p_\theta(x^{(i)}|z) \qquad (32)$$

So if we fix the parameter $\theta$, the maximum of the variational lower bound equals the log-likelihood $\log p_\theta(x^{(i)})$, which is achieved when $q_\phi(z|x^{(i)}) = p_\theta(z|x^{(i)})$. Now suppose that we always achieve such a state, i.e. the variational lower bound equals the marginal log-likelihood; then, by maximum likelihood optimization, if we have a large enough number of sample points, the maximum of the log-likelihood is achieved at the optimum $\theta^\star$. The above argument roughly gives us the belief that we can optimize the variational lower bound to find the optimum $\theta^\star$, and that $p_{\theta^\star}(x)$ will catch the underlying data distribution.

We next should select proper distribution classes with high representative capacity for the distributions in the variational lower bound (31), to approximate the true distributions and to make the optimization of the variational lower bound easy and efficient.

Note that if $z$ is a continuous random variable in a $d$-dimensional space, e.g. normally distributed, and $x$ is a random vector in a space of dimension $d' \leq d$, then we can find a function $f: \mathbb{R}^d \to \mathbb{R}^{d'}$ such that $x = f(z)$ almost surely (Kingma and Welling (2014)), with a properly complex function $f(\cdot)$. We can conjecture that if the random vector $x \in \mathbb{R}^n$, $n \geq d$, lies on a manifold of essential dimension $d$, we can also find a function $f: \mathbb{R}^d \to \mathbb{R}^n$ such that $x = f(z)$. Now suppose $z \sim N(0, I)$ and $x$ is the random vector representing the gene expression distribution. Since there are complicated regulatory networks between genes, the function $f(z) = x$ should represent these complex regulatory networks. The distribution of $z$ can be a simple normal distribution, a log-normal distribution, or another continuous distribution. To make the KL divergence $D_{KL}(q_\phi(z|x^{(i)})\,\|\,p_\theta(z))$ small, we let the approximate posterior $q_\phi(z|x^{(i)})$ lie in the same distribution class as the distribution of $z$. For single-cell RNA-seq data, the distribution class of $p_\theta(x^{(i)}|z)$ is chosen to be the zero-inflated negative binomial distribution.

We call $q_\phi(z|x^{(i)})$ the encoder; it encodes the data point $x^{(i)}$ into its "code" $z$.
And we refer to $p_\theta(x^{(i)}|z)$ as the decoder; it decodes the "code" $z$ into the data point $x^{(i)}$.

Here we should point out that the complex regulatory networks between genes are modeled mainly by the mean of the negative binomial distribution. To get a sense of how an independent Gaussian variable, with mean and diagonal variance given as functions of the random variable $z$, can capture some dependence structure of $x$, we give a simple example. Let $z \sim N(0, 1)$ be a standard normal variable, and let $p(x|z)$ be the conditional density of $N([z, 2z], \mathrm{diag}(1, 1))$. We can get $p(x)$ in closed form:

$$p(x) = \int_z p(z)p(x|z)\,dz = \int_z \frac{1}{\sqrt{2\pi}}e^{-\frac{z^2}{2}}\cdot\frac{1}{2\pi}e^{-\frac{1}{2}\left((x_1 - z)^2 + (x_2 - 2z)^2\right)}dz = \frac{1}{2\pi\sqrt{6}}\exp\Big(-\frac{5x_1^2 + 2x_2^2 - 4x_1 x_2}{12}\Big) \qquad (33)$$
So we get that

$$x \sim N\left(\begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 2 & 2 \\ 2 & 5 \end{bmatrix}\right). \qquad (34)$$

It shows that this simple example captures the dependence of $x_1, x_2$, with $\mathrm{Cov}(x_1, x_2) = 2 = \mathrm{Cov}(z, 2z) = \mathrm{Cov}(\mu_{21}(z,\theta), \mu_{22}(z,\theta))$. So, in the general form $x \sim N(\mu_2(z,\theta), \mathrm{diag}(\sigma_2^2(z,\theta)))$, with the mean $\mu_2$ and diagonal variance $\sigma_2^2(z,\theta)$ output by a nonlinear mapping such as a neural network, the density $p(x)$ will capture complex dependence networks. If $x$ is the gene expression vector, this captures the complex gene regulatory networks, and these networks are captured by $\mu_2(z,\theta)$. This may be one reason for the success of the scVI model.

The variational auto-encoder (VAE) models the probability encoder $q_\phi(z|x^{(i)})$ by modeling the parameters (i.e. the mean and the diagonal covariance matrix) of the distribution with a nonlinear mapping (e.g. neural networks). $p_\theta(z)$ is the prior distribution, usually a basic distribution without parameters $\theta$, e.g. a standard Gaussian. And $p_\theta(z)$ comes from the same probability distribution class as $q_\phi(z|x^{(i)})$, which leads to a closed form of the KL divergence $D_{KL}(q_\phi(z|x^{(i)})\,\|\,p_\theta(z))$. The probability distribution of the probability decoder $p_\theta(x^{(i)}|z)$ should account for the real distribution of $x$; e.g. scVI chooses the zero-inflated negative binomial distribution for gene expression, while image processing usually chooses a Gaussian distribution with diagonal variance. The VAE uses a nonlinear mapping (neural networks) to model the parameters of the distribution $p_\theta(x^{(i)}|z)$.

To train the neural networks on a large dataset, one uses stochastic optimization, which needs a low-variance estimate of the gradients of the objective function (the variational lower bound). In most cases, the parametric family of $p_\theta(z)$ leads to an analytical expression of $D_{KL}(q_\phi(z|x^{(i)})\,\|\,p_\theta(z))$ which is differentiable with respect to the parameters $(\theta, \phi)$. But there is a problem with the reconstruction error term $\mathbb{E}_{q_\phi(z|x^{(i)})}\log p_\theta(x^{(i)}|z)$ of the variational lower bound. If we use

$$\frac{1}{L}\sum_{l=1}^{L}\log p_\theta(x^{(i)}|z^{(i,l)}), \qquad z^{(i,l)} \sim q_\phi(z|x^{(i)}) \qquad (35)$$

to estimate it, this causes two problems. The first is that the variance of this estimate is very high, so it will break the stochastic optimization. The second is that we cannot differentiate it with respect to the parameters $\phi$, since the backward gradient cannot pass through a sample $z^{(i,l)}$ to the parameters of the distribution $q_\phi(z|x^{(i)})$. To get around this problem, Kingma and Welling (2014) proposed the reparametrization trick. The trick uses the fact that in many cases we can express the random variable $z \sim q_\phi(z|x)$ by a deterministic function $z = f_\phi(\varepsilon, x)$, where $\varepsilon$ is an auxiliary random variable with an independent marginal distribution $p(\varepsilon)$. For example, $z \sim N(\mu, \sigma^2)$ can be expressed by $z = \mu + \sigma\varepsilon$, $\varepsilon \sim N(0, 1)$. Since $q_\phi(z|x^{(i)})\prod_i dz_i = p(\varepsilon)\prod_i d\varepsilon_i$, we have $\mathbb{E}_{q_\phi(z|x^{(i)})}\log p_\theta(x^{(i)}|z) = \mathbb{E}_{\varepsilon\sim p(\varepsilon)}\log p_\theta(x^{(i)}|f_\phi(\varepsilon, x^{(i)}))$, which can be estimated by

$$\mathbb{E}_{\varepsilon\sim p(\varepsilon)}\log p_\theta(x^{(i)}|f_\phi(\varepsilon, x^{(i)})) \approx \frac{1}{L}\sum_{l=1}^{L}\log p_\theta(x^{(i)}|f_\phi(\varepsilon^{(i,l)}, x^{(i)})), \qquad \varepsilon^{(i,l)} \sim p(\varepsilon). \qquad (36)$$

Now this estimate is differentiable with respect to the parameters $\phi$.
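As a concrete illustration of the reparametrization trick in equation (36), here is a minimal PyTorch sketch; the one-dimensional Gaussian encoder and the toy likelihood are illustrative choices, not taken from the text. Gradients of the Monte-Carlo estimate flow back to the variational parameters because the sample is a deterministic function of those parameters and the auxiliary noise.

```python
import math
import torch

# Variational parameters of q_phi(z|x) = N(mu, sigma^2) for a single data point.
mu = torch.tensor(0.5, requires_grad=True)
log_sigma = torch.tensor(-1.0, requires_grad=True)

eps = torch.randn(1000)                     # eps ~ p(eps) = N(0, 1), independent of phi
z = mu + torch.exp(log_sigma) * eps         # reparameterized samples z ~ N(mu, sigma^2)

# Toy decoder likelihood log p(x|z): Gaussian with unit variance, observed x = 1.0.
x = 1.0
log_px_given_z = -0.5 * (x - z) ** 2 - 0.5 * math.log(2 * math.pi)
estimate = log_px_given_z.mean()            # Monte-Carlo estimate of E_q log p(x|z)

estimate.backward()                         # gradients reach mu and log_sigma
print("estimate:", estimate.item())
print("d/dmu:", mu.grad.item(), " d/dlog_sigma:", log_sigma.grad.item())
```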
The variance of the reparametrized estimate (36) is lower, since $p(\varepsilon)$ is a fixed distribution that does not involve $x, \phi$, and it is much easier to draw samples from the stationary distribution $p(\varepsilon)$ that cover the relevant probability region than to draw the same number of samples from $q_\phi(z|x^{(i)})$. In many cases we can sample only one point ($L = 1$) due to the low variance of this estimate.

Summarizing the above efforts: we approximate the log-likelihood by the variational lower bound, and we reparameterize the reconstruction term of the variational lower bound to get an equivalent representation,

$$\begin{aligned}
\mathcal{L}(\theta, \phi, x^{(i)}) &= -D_{KL}\big(q_\phi(z|x^{(i)})\,\|\,p_\theta(z)\big) + \mathbb{E}_{q_\phi(z|x^{(i)})}\log p_\theta(x^{(i)}|z) \\
&= -D_{KL}\big(q_\phi(z|x^{(i)})\,\|\,p_\theta(z)\big) + \mathbb{E}_{\varepsilon\sim p(\varepsilon)}\log p_\theta(x^{(i)}|f_\phi(\varepsilon, x^{(i)})), \quad \text{where } z = f_\phi(\varepsilon, x^{(i)}) \sim q_\phi(z|x^{(i)}),\ \varepsilon \sim p(\varepsilon).
\end{aligned} \qquad (37)$$

The equivalent representation of the variational lower bound is approximated by

$$\mathcal{L}(\theta, \phi, x^{(i)}) \approx \tilde{\mathcal{L}}(\theta, \phi, x^{(i)}) := -D_{KL}\big(q_\phi(z|x^{(i)})\,\|\,p_\theta(z)\big) + \frac{1}{L}\sum_{l=1}^{L}\log p_\theta(x^{(i)}|f_\phi(\varepsilon^{(i,l)}, x^{(i)})), \qquad \varepsilon^{(i,l)} \sim p(\varepsilon). \qquad (38)$$

The variational auto-encoder trains the parameters of the neural networks with the approximate objective $\sum_{i\in\text{mini-batch}}\tilde{\mathcal{L}}(\theta, \phi, x^{(i)})$ on a mini-batch of samples at each step, maximizing the approximate log-likelihood with stochastic optimization methods, e.g. SGD, Adam (Kingma and Ba (2014)), and so on. Hopefully, the final solution output by the algorithm approaches the true optimum $(\theta^\star, \phi^\star)$.

Now we return to the simple example above to check the power of the variational auto-encoder. Suppose that $x \in \mathbb{R}^2$ comes from the following generation process:

$$z \sim N(0, 1), \qquad x \sim N\left(\begin{bmatrix} z \\ 2z \end{bmatrix}, \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\right). \qquad (39)$$

We have the following probability density functions:

$$\begin{aligned}
p_{\theta^\star}(z) &= \frac{1}{\sqrt{2\pi}}\exp\Big(-\frac{z^2}{2}\Big) \\
p_{\theta^\star}(x|z) &= \frac{1}{2\pi}\exp\Big(-\frac{(x_1 - z)^2 + (x_2 - 2z)^2}{2}\Big) \\
p_{\theta^\star}(x) &= \frac{1}{2\pi\sqrt{6}}\exp\Big(-\frac{5x_1^2 + 2x_2^2 - 4x_1 x_2}{12}\Big), \qquad \text{i.e. } x \sim N\big([0, 0]^\top, [[2, 2], [2, 5]]\big) \\
p_{\theta^\star}(z|x) &= \frac{p_{\theta^\star}(x|z)p_{\theta^\star}(z)}{p_{\theta^\star}(x)} = \sqrt{\frac{3}{\pi}}\exp\Big(-3\Big(z - \frac{x_1 + 2x_2}{6}\Big)^2\Big) \sim N\Big(\frac{x_1 + 2x_2}{6}, \frac{1}{6}\Big)
\end{aligned} \qquad (40)$$

We now choose $q_\phi(z|x) \sim N(\mu_1(x,\phi), \sigma_1^2(x,\phi))$ and $p_\theta(x|z) \sim N(\mu_2(z,\theta), \mathrm{diag}(\sigma_2^2(z,\theta)))$, where $\mu_1(x,\phi) \in \mathbb{R}$, $\sigma_1^2(x,\phi) \in \mathbb{R}_+$ are functions of $x$ with parameters $\phi$, and $\mu_2(z,\theta) \in \mathbb{R}^2$, $\sigma_2^2(z,\theta) \in \mathbb{R}^2_+$ are mappings of the variable $z$ with parameters $\theta$.

Having chosen the model in the Gaussian class, we can calculate the variational lower bound analytically:

$$\begin{aligned}
\mathcal{L}(\theta, \phi, x) &= -D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z)\big) + \mathbb{E}_{q_\phi(z|x)}\log p_\theta(x|z) \\
&= -D_{KL}\big(N(\mu_1(x,\phi), \sigma_1^2(x,\phi))\,\|\,N(0,1)\big) + \mathbb{E}_{z\sim N(\mu_1(x,\phi),\sigma_1^2(x,\phi))}\log p_\theta(x|z) \\
&= -\frac{1}{2}\big[-\log\sigma_1^2(x,\phi) - 1 + \sigma_1^2(x,\phi) + \mu_1^2(x,\phi)\big] \\
&\quad + \mathbb{E}_{z\sim N(\mu_1(x,\phi),\sigma_1^2(x,\phi))}\Big[-\log\big(2\pi\sigma_{21}(z,\theta)\sigma_{22}(z,\theta)\big) - \frac{(x_1 - \mu_{21}(z,\theta))^2}{2\sigma_{21}^2(z,\theta)} - \frac{(x_2 - \mu_{22}(z,\theta))^2}{2\sigma_{22}^2(z,\theta)}\Big] \\
&= -\frac{1}{2}\big[-\log\sigma_1^2(x,\phi) - 1 + \sigma_1^2(x,\phi) + \mu_1^2(x,\phi)\big] \\
&\quad + \mathbb{E}_{\varepsilon\sim N(0,1)}\Big[-\log\big(2\pi\sigma_{21}(z_\varepsilon,\theta)\sigma_{22}(z_\varepsilon,\theta)\big) - \frac{(x_1 - \mu_{21}(z_\varepsilon,\theta))^2}{2\sigma_{21}^2(z_\varepsilon,\theta)} - \frac{(x_2 - \mu_{22}(z_\varepsilon,\theta))^2}{2\sigma_{22}^2(z_\varepsilon,\theta)}\Big]_{z_\varepsilon = \mu_1(x,\phi) + \sigma_1(x,\phi)\varepsilon}
\end{aligned} \qquad (41)$$

To simplify the complex expression above, we suppose that $\sigma_1^2(x,\phi) = 1/6$, $\sigma_{21}(z,\theta) = 1$, $\sigma_{22}(z,\theta) = 1$, i.e. we take the variance parameters to be the same as the underlying true parameters.
Then we get

$$\begin{aligned}
\mathcal{L}(\theta,\phi,x) &= -\frac{1}{2}\Big[\log 6 - 1 + \frac{1}{6} + \mu_1^2(x,\phi)\Big] + \mathbb{E}_{z\sim N(\mu_1(x,\phi),1/6)}\Big[-\log(2\pi) - \frac{(x_1 - \mu_{21}(z,\theta))^2 + (x_2 - \mu_{22}(z,\theta))^2}{2}\Big] \\
&= -\log(2\pi\sqrt{6}) + \frac{5}{12} - \frac{\mu_1^2(x,\phi)}{2} - \frac{1}{2}\mathbb{E}_{z\sim N(\mu_1(x,\phi),1/6)}\big[(x_1 - \mu_{21}(z,\theta))^2 + (x_2 - \mu_{22}(z,\theta))^2\big]
\end{aligned} \qquad (42)$$

We use $\mu_1(x,\phi) := a_1 x_1 + a_2 x_2$ and $\mu_2(z,\theta) := [b_1 z, b_2 z]$ to parametrize the mean functions, where $\phi = (a_1, a_2)$ and $\theta = (b_1, b_2)$. The above equation simplifies to

$$\mathcal{L}(\theta,\phi,x) = -\log(2\pi\sqrt{6}) + \frac{5}{12} - \frac{(a_1 x_1 + a_2 x_2)^2}{2} - \frac{b_1^2 + b_2^2}{12} - \frac{\big(b_1(a_1 x_1 + a_2 x_2) - x_1\big)^2 + \big(b_2(a_1 x_1 + a_2 x_2) - x_2\big)^2}{2} \qquad (43)$$

We can then find the optimal $\phi = (a_1, a_2)$, $\theta = (b_1, b_2)$ with the following loss, given samples $\{x_n, n = 1, \dots, N\} \sim N\big([0, 0]^\top, [[2, 2], [2, 5]]\big)$:

$$\min_{\phi,\theta} L(\theta,\phi) = \frac{1}{N}\sum_{n=1}^{N}\Big[\frac{(a_1 x_{n,1} + a_2 x_{n,2})^2}{2} + \frac{b_1^2 + b_2^2}{12} + \frac{\big(b_1(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,1}\big)^2 + \big(b_2(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,2}\big)^2}{2}\Big] \qquad (44)$$
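As an aside (not part of the original derivation), the simplified loss (44) can be minimized numerically. The following NumPy sketch assumes plain gradient descent and data simulated from (39); the printed values are expected to be close to the stationary point $(1/6, 2/6, 1, 2)$ discussed next, up to optimization error and a possible global sign flip.

```python
import numpy as np

rng = np.random.default_rng(1)

# Simulate x ~ N([0,0], [[2,2],[2,5]]) via z ~ N(0,1), x = [z, 2z] + N(0, I).
N = 20000
z = rng.normal(size=N)
X = np.stack([z, 2 * z], axis=1) + rng.normal(size=(N, 2))

a = np.array([0.2, 0.2])      # encoder mean weights (a1, a2)
b = np.array([0.5, 0.5])      # decoder mean weights (b1, b2)
lr = 0.02

for step in range(10000):
    m = X @ a                                    # mu_1(x, phi) = a1*x1 + a2*x2
    err = np.outer(m, b) - X                     # b_j * m - x_j
    # Gradients of (44): the m*x_k term comes from the KL part, b_j/6 from E[(b_j z)^2].
    grad_a = np.mean((m + err @ b)[:, None] * X, axis=0)
    grad_b = b / 6.0 + np.mean(err * m[:, None], axis=0)
    a -= lr * grad_a
    b -= lr * grad_b

print("a =", a)    # expected to approach roughly [1/6, 1/3]
print("b =", b)    # expected to approach roughly [1, 2]
```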
Comparing with the auto-encoder (AE) loss (16), the above variational auto-encoder (VAE) loss (44) is similar. The VAE loss has one extra term, $z_n^2 = (a_1 x_{n,1} + a_2 x_{n,2})^2$, which comes from requiring the probability encoder $q_\phi(z|x)$ to be close to the standard normal distribution; this is exactly what we wanted to enforce in the AE loss. Note that this inspires us to add a penalty term $\frac{\beta}{N}\sum_n\|z_n\|^2$ to the loss (16), which biases the solution so that $p_\theta(z)$ is close to some normal distribution $N(0, F(\beta,\lambda))$, where $F(\beta,\lambda)$ is a function of $\beta, \lambda$. We might get somewhat satisfactory results with this improvement; however, we cannot prespecify the coefficients $\beta, \lambda$ such that $F(\beta,\lambda) = 1$, while this is accomplished automatically in the variational auto-encoder, since its loss is deduced from the likelihood. In this simplified case we assumed that $q_\phi(z|x)$ and $p_\theta(x|z)$ have constant diagonal variance matrices. It gives us the intuition that the mean $\mu_1(x,\phi)$ of the probability encoder $q_\phi(z|x)$ in the VAE plays a similar role to the encoder function $z = f(x,\phi)$ in the AE, and that the mean $\mu_2(z,\theta)$ of the probability decoder $p_\theta(x|z)$ in the VAE plays a similar role to the decoder function $x = g(z,\theta)$ in the AE. I conjecture that this holds whenever the probability encoder and decoder come from the Gaussian class with diagonal variance matrices. The reason behind this phenomenon is that when a Gaussian vector has a diagonal covariance, e.g. $p_\theta(x|z) \sim N(\mu_2(z,\theta), \mathrm{diag}(\sigma_2^2(z,\theta)))$, the complex correlation between the variables $x_g$, $g = 1, \dots, G$, is mainly captured by $\mu_2(z,\theta)$, since the $x_g$, $g = 1, \dots, G$, are independent given $z$; and if $z \sim N(0, I)$, then $\mu_2(z,\theta)$ is a random vector that captures the complex dependence between the variables $x_g$, $g = 1, \dots, G$, via the nonlinear mapping $\mu_2(z,\theta)$.

Now we solve the above optimization problem via the first-order conditions:

$$\begin{aligned}
\frac{\partial L}{\partial a_1} &= \frac{1}{N}\sum_{n=1}^{N}\big[x_{n,1}(a_1 x_{n,1} + a_2 x_{n,2}) + b_1 x_{n,1}\big(b_1(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,1}\big) + b_2 x_{n,1}\big(b_2(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,2}\big)\big] = 0 \\
\frac{\partial L}{\partial a_2} &= \frac{1}{N}\sum_{n=1}^{N}\big[x_{n,2}(a_1 x_{n,1} + a_2 x_{n,2}) + b_1 x_{n,2}\big(b_1(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,1}\big) + b_2 x_{n,2}\big(b_2(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,2}\big)\big] = 0 \\
\frac{\partial L}{\partial b_1} &= \frac{1}{N}\sum_{n=1}^{N}\big[\tfrac{b_1}{6} + (a_1 x_{n,1} + a_2 x_{n,2})\big(b_1(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,1}\big)\big] = 0 \\
\frac{\partial L}{\partial b_2} &= \frac{1}{N}\sum_{n=1}^{N}\big[\tfrac{b_2}{6} + (a_1 x_{n,1} + a_2 x_{n,2})\big(b_2(a_1 x_{n,1} + a_2 x_{n,2}) - x_{n,2}\big)\big] = 0
\end{aligned} \qquad (45)$$

Now, using the approximations $\frac{1}{N}\sum_n x_{n,1}^2 \approx \mathbb{E}x_1^2 = 2$, $\frac{1}{N}\sum_n x_{n,2}^2 \approx \mathbb{E}x_2^2 = 5$, $\frac{1}{N}\sum_n x_{n,1}x_{n,2} \approx \mathbb{E}x_1 x_2 = 2$ in the above equations, we get

$$\begin{aligned}
(2a_1 + 2a_2) + b_1\big(b_1(2a_1 + 2a_2) - 2\big) + b_2\big(b_2(2a_1 + 2a_2) - 2\big) &\approx 0 \\
(2a_1 + 5a_2) + b_1\big(b_1(2a_1 + 5a_2) - 2\big) + b_2\big(b_2(2a_1 + 5a_2) - 5\big) &\approx 0 \\
\frac{b_1}{6} + b_1(2a_1^2 + 4a_1 a_2 + 5a_2^2) - (2a_1 + 2a_2) &\approx 0 \\
\frac{b_2}{6} + b_2(2a_1^2 + 4a_1 a_2 + 5a_2^2) - (2a_1 + 5a_2) &\approx 0
\end{aligned} \qquad (46)$$

I cannot solve the above equations in closed form, since eliminating $b_1, b_2$ leads to a fifth-order equation in $a_1, a_2$. But we can check that the optimal solution $\phi^\star = (a_1^\star, a_2^\star) = (1/6, 2/6)$, $\theta^\star = (b_1^\star, b_2^\star) = (1, 2)$ satisfies the above stationarity conditions exactly. This shows that the variational auto-encoder has the power to find the true solution. Note that when $x \sim N\big([0, 0]^\top, [[2, 2], [2, 5]]\big)$, the mean $\mu_1(x,\phi^\star) = \frac{x_1}{6} + \frac{2x_2}{6} \sim N(0, 5/6)$ is close to the true hidden variable distribution $z \sim N(0, 1)$ but not the same; this is because $z = \mu_1(x,\phi^\star) + \varepsilon$, $\varepsilon \sim N(0, 1/6)$. The mean $\mu_2(z,\theta^\star) = [z, 2z]^\top \sim N\big([0, 0]^\top, [[1, 2], [2, 4]]\big)$ is also close to the true distribution $N\big([0, 0]^\top, [[2, 2], [2, 5]]\big)$.
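The covariance claims above are easy to check by simulation. The following short NumPy snippet (illustrative only, not from the original text) samples from the generative process (39) and prints the empirical covariance of $x$, which should be close to $[[2, 2], [2, 5]]$: a decoder with diagonal conditional variance still induces correlation between $x_1$ and $x_2$ through its mean function $\mu_2(z) = [z, 2z]$.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=200_000)
# x | z ~ N([z, 2z], I), so marginally Cov(x) = [[2, 2], [2, 5]].
x = np.stack([z, 2 * z], axis=1) + rng.normal(size=(200_000, 2))
print(np.cov(x, rowvar=False))
```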
If we set $N\big([0, 0]^\top, [[1, 2], [2, 4]]\big)$ as the true distribution of $x$, then we can re-tell the story: $\tilde{x} = x + \varepsilon$ is a measurement of $x \sim N\big([0, 0]^\top, [[1, 2], [2, 4]]\big)$ with measurement error $\varepsilon \sim N\big([0, 0]^\top, [[1, 0], [0, 1]]\big)$, and we can use the variational auto-encoder method to approximate the distribution of $x$, which is given by the distribution of the mean function $\mu_2(z,\theta^\star)$ of the decoder $p_{\theta^\star}(x|z)$, i.e. $p(x) = \int_{x = \mu_2(z,\theta^\star)} p(z)\,dz$. This interpretation matches the data generation process in practice, such as for scRNA-seq data: the gene expression has complex regulatory relations that control protein production, while the measurement error may be independent for each gene. That is, the expression data $\tilde{x} \in \mathbb{R}^G$ can be written as $\tilde{x} = x + \varepsilon$, where $x$ is the true expression, $\varepsilon$ is the measurement error, the genes $x_g$, $g = 1, 2, \dots, G$, have complex regulatory relations, but the errors $\varepsilon_g$, $g = 1, 2, \dots, G$, are independent of each other and also independent of $x$. So we can use the variational auto-encoder to model this process by assuming that the distribution of the measurements $\tilde{x}$ comes from a parametric distribution $p_\theta(x|z)$ that is characterized by its mean and diagonal variance given the hidden continuous variable $z$, where the mean $\mu(z,\theta)$ of $p_\theta(x|z)$ gives the distribution of $x$, i.e. $x = \mu(z,\theta)$, and the diagonal variance characterizes the independent random error $\varepsilon$. This goes through when $p_\theta(x|z)$ is in the Gaussian class, i.e. $\tilde{x} \sim p_\theta(x|z) \sim N(\mu(z,\theta), \mathrm{diag}(\sigma^2(\theta)))$ and $\tilde{x} = \mu(z,\theta) + \varepsilon$, where $x = \mu(z,\theta)$ and $\varepsilon \sim N(0, \mathrm{diag}(\sigma^2(\theta)))$ is the independent random error. When the distribution of $x$ is given by the negative binomial distribution $x \sim NB(\mathrm{dispersion} = d(z,\theta), \mathrm{mean} = \mu(z,\theta))$, in which each gene is independent of the others given $z$, the above intuition roughly goes through: we can use $\mu(z,\theta)$ to capture the complex gene regulatory relations. But the variance, given by $\mu(z,\theta) + \frac{\mu^2(z,\theta)}{d(z,\theta)}$, is also correlated across genes, and there is no additive independent random noise giving independent per-gene errors in this kind of representation; it may be more reasonable to find a discrete count distribution that models independent measurement errors (independent of the mean, and independent across genes).

Now we generalize the above example a little. We replace the variance of the error of $x$ by $\gamma$, so that we can observe how the error influences the VAE method. Suppose that the data comes from the following generation process:

$$z \sim N(0, 1), \qquad x = \begin{bmatrix} z \\ 2z \end{bmatrix}, \qquad \varepsilon \sim N\left(\begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} \gamma & 0 \\ 0 & \gamma \end{bmatrix}\right), \qquad \tilde{x} = x + \varepsilon, \qquad (47)$$

where $z$ is the hidden continuous variable, $x$ is the random vector we are interested in, and $\varepsilon$ is the measurement error, which is independent of $z$. Note that here the samples come from $\tilde{x}$, so we should apply the variational model to $\tilde{x}$, but we hope to denoise it to get the distribution of $x$. We have the following probability density functions:
$$\begin{aligned}
p_{\theta^\star}(z) &= \frac{1}{\sqrt{2\pi}}\exp\Big(-\frac{z^2}{2}\Big) \\
p_{\theta^\star}(x|z) &= \delta(x_1 - z)\,\delta(x_2 - 2z) \\
p_{\theta^\star}(\tilde{x}|z) &= \frac{1}{2\pi\gamma}\exp\Big(-\frac{(\tilde{x}_1 - z)^2 + (\tilde{x}_2 - 2z)^2}{2\gamma}\Big) \\
p_{\theta^\star}(x) &= \frac{1}{\sqrt{2\pi}}\exp\Big(-\frac{x_1^2}{2}\Big)\delta(x_2 - 2x_1) \\
p_{\theta^\star}(\tilde{x}) &= \frac{1}{2\pi\sqrt{(5+\gamma)\gamma}}\exp\Big(-\frac{(4+\gamma)\tilde{x}_1^2 + (1+\gamma)\tilde{x}_2^2 - 4\tilde{x}_1\tilde{x}_2}{2\gamma(5+\gamma)}\Big), \qquad \text{i.e. } \tilde{x} \sim N\left(\begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1+\gamma & 2 \\ 2 & 4+\gamma \end{bmatrix}\right) \\
p_{\theta^\star}(z|x) &= \delta(z - x_1)\,\delta(z - x_2/2) \\
p_{\theta^\star}(z|\tilde{x}) &= \frac{p_{\theta^\star}(\tilde{x}|z)p_{\theta^\star}(z)}{p_{\theta^\star}(\tilde{x})} = \sqrt{\frac{5+\gamma}{2\pi\gamma}}\exp\Big(-\frac{5+\gamma}{2\gamma}\Big(z - \frac{\tilde{x}_1 + 2\tilde{x}_2}{5+\gamma}\Big)^2\Big) \sim N\Big(\frac{\tilde{x}_1 + 2\tilde{x}_2}{5+\gamma}, \frac{\gamma}{5+\gamma}\Big)
\end{aligned} \qquad (48)$$

where $\delta(\cdot)$ is the Dirac delta function, with the property $\delta(u) := \begin{cases}\infty & \text{if } u = 0 \\ 0 & \text{otherwise}\end{cases}$ and $\int_{-a}^{a}\delta(t)\,dt = 1$ for all $a > 0$.

We now choose $q_\phi(z|\tilde{x}) \sim N(\mu_1(\tilde{x},\phi), \sigma_1^2(\tilde{x},\phi))$ and $p_\theta(\tilde{x}|z) \sim N(\mu_2(z,\theta), \mathrm{diag}(\sigma_2^2(z,\theta)))$, where $\mu_1(\tilde{x},\phi) \in \mathbb{R}$, $\sigma_1^2(\tilde{x},\phi) \in \mathbb{R}_+$ are functions of $\tilde{x}$ with parameters $\phi$, and $\mu_2(z,\theta) \in \mathbb{R}^2$, $\sigma_2^2(z,\theta) \in \mathbb{R}^2_+$ are mappings of the variable $z$ with parameters $\theta$.

Having chosen the model in the Gaussian class, we can calculate the variational lower bound analytically:

$$\begin{aligned}
\mathcal{L}(\theta,\phi,\tilde{x}) &= -D_{KL}\big(q_\phi(z|\tilde{x})\,\|\,p_\theta(z)\big) + \mathbb{E}_{q_\phi(z|\tilde{x})}\log p_\theta(\tilde{x}|z) \\
&= -D_{KL}\big(N(\mu_1(\tilde{x},\phi), \sigma_1^2(\tilde{x},\phi))\,\|\,N(0,1)\big) + \mathbb{E}_{z\sim N(\mu_1(\tilde{x},\phi),\sigma_1^2(\tilde{x},\phi))}\log p_\theta(\tilde{x}|z) \\
&= -\frac{1}{2}\big[-\log\sigma_1^2(\tilde{x},\phi) - 1 + \sigma_1^2(\tilde{x},\phi) + \mu_1^2(\tilde{x},\phi)\big] \\
&\quad + \mathbb{E}_{z\sim N(\mu_1(\tilde{x},\phi),\sigma_1^2(\tilde{x},\phi))}\Big[-\log\big(2\pi\sigma_{21}(z,\theta)\sigma_{22}(z,\theta)\big) - \frac{(\tilde{x}_1 - \mu_{21}(z,\theta))^2}{2\sigma_{21}^2(z,\theta)} - \frac{(\tilde{x}_2 - \mu_{22}(z,\theta))^2}{2\sigma_{22}^2(z,\theta)}\Big],
\end{aligned} \qquad (49)$$

where, as in (41), the expectation over $z$ can be rewritten as an expectation over $\varepsilon_z \sim N(0,1)$ with $z = \mu_1(\tilde{x},\phi) + \sigma_1(\tilde{x},\phi)\varepsilon_z$.
To simplify the complex expression above, we suppose that $\sigma_1^2(\tilde{x},\phi) = \frac{\gamma}{5+\gamma}$, $\sigma_{21}^2(z,\theta) = \gamma$, $\sigma_{22}^2(z,\theta) = \gamma$, i.e. we take the variance parameters to be the same as the underlying true parameters. Then we get

$$\begin{aligned}
\mathcal{L}(\theta,\phi,\tilde{x}) &= -\frac{1}{2}\Big[-\log\frac{\gamma}{5+\gamma} - 1 + \frac{\gamma}{5+\gamma} + \mu_1^2(\tilde{x},\phi)\Big] + \mathbb{E}_{z\sim N(\mu_1(\tilde{x},\phi),\frac{\gamma}{5+\gamma})}\Big[-\log(2\pi\gamma) - \frac{(\tilde{x}_1 - \mu_{21}(z,\theta))^2 + (\tilde{x}_2 - \mu_{22}(z,\theta))^2}{2\gamma}\Big] \\
&= -\log\big(2\pi\sqrt{\gamma(5+\gamma)}\big) + \frac{5}{2(5+\gamma)} - \frac{\mu_1^2(\tilde{x},\phi)}{2} - \frac{1}{2\gamma}\mathbb{E}_{z\sim N(\mu_1(\tilde{x},\phi),\frac{\gamma}{5+\gamma})}\big[(\tilde{x}_1 - \mu_{21}(z,\theta))^2 + (\tilde{x}_2 - \mu_{22}(z,\theta))^2\big]
\end{aligned} \qquad (50)$$

We use $\mu_1(\tilde{x},\phi) := a_1\tilde{x}_1 + a_2\tilde{x}_2$ and $\mu_2(z,\theta) := [b_1 z, b_2 z]$ to parametrize the mean functions, where $\phi = (a_1, a_2)$ and $\theta = (b_1, b_2)$. The above equation simplifies to

$$\mathcal{L}(\theta,\phi,\tilde{x}) = -\log\big(2\pi\sqrt{\gamma(5+\gamma)}\big) + \frac{5}{2(5+\gamma)} - \frac{(a_1\tilde{x}_1 + a_2\tilde{x}_2)^2}{2} - \frac{b_1^2 + b_2^2}{2(5+\gamma)} - \frac{\big(b_1(a_1\tilde{x}_1 + a_2\tilde{x}_2) - \tilde{x}_1\big)^2 + \big(b_2(a_1\tilde{x}_1 + a_2\tilde{x}_2) - \tilde{x}_2\big)^2}{2\gamma} \qquad (51)$$

We can then find the optimal $\phi = (a_1, a_2)$, $\theta = (b_1, b_2)$ with the following loss, given samples $\{\tilde{x}_n, n = 1, \dots, N\} \sim N\big([0, 0]^\top, [[1+\gamma, 2], [2, 4+\gamma]]\big)$:

$$\min_{\phi,\theta} L(\theta,\phi) = \frac{1}{N}\sum_{n=1}^{N}\Big[\frac{(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2})^2}{2} + \frac{b_1^2 + b_2^2}{2(5+\gamma)} + \frac{\big(b_1(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,1}\big)^2 + \big(b_2(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,2}\big)^2}{2\gamma}\Big] \qquad (52)$$

We can analyze the above optimization problem with the first-order conditions:

$$\begin{aligned}
\frac{\partial L}{\partial a_1} &= \frac{1}{N}\sum_{n=1}^{N}\Big[\tilde{x}_{n,1}(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) + \frac{b_1\tilde{x}_{n,1}\big(b_1(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,1}\big) + b_2\tilde{x}_{n,1}\big(b_2(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,2}\big)}{\gamma}\Big] = 0 \\
\frac{\partial L}{\partial a_2} &= \frac{1}{N}\sum_{n=1}^{N}\Big[\tilde{x}_{n,2}(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) + \frac{b_1\tilde{x}_{n,2}\big(b_1(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,1}\big) + b_2\tilde{x}_{n,2}\big(b_2(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,2}\big)}{\gamma}\Big] = 0 \\
\frac{\partial L}{\partial b_1} &= \frac{1}{N}\sum_{n=1}^{N}\Big[\frac{b_1}{5+\gamma} + \frac{(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2})\big(b_1(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,1}\big)}{\gamma}\Big] = 0 \\
\frac{\partial L}{\partial b_2} &= \frac{1}{N}\sum_{n=1}^{N}\Big[\frac{b_2}{5+\gamma} + \frac{(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2})\big(b_2(a_1\tilde{x}_{n,1} + a_2\tilde{x}_{n,2}) - \tilde{x}_{n,2}\big)}{\gamma}\Big] = 0
\end{aligned} \qquad (53)$$

Now, using the approximations $\frac{1}{N}\sum_n\tilde{x}_{n,1}^2 \approx \mathbb{E}\tilde{x}_1^2 = 1+\gamma$, $\frac{1}{N}\sum_n\tilde{x}_{n,2}^2 \approx \mathbb{E}\tilde{x}_2^2 = 4+\gamma$, $\frac{1}{N}\sum_n\tilde{x}_{n,1}\tilde{x}_{n,2} \approx \mathbb{E}\tilde{x}_1\tilde{x}_2 = 2$ in the above equations, we get

$$\begin{aligned}
\big((1+\gamma)a_1 + 2a_2\big) + \frac{b_1\big(b_1((1+\gamma)a_1 + 2a_2) - (1+\gamma)\big) + b_2\big(b_2((1+\gamma)a_1 + 2a_2) - 2\big)}{\gamma} &\approx 0 \\
\big(2a_1 + (4+\gamma)a_2\big) + \frac{b_1\big(b_1(2a_1 + (4+\gamma)a_2) - 2\big) + b_2\big(b_2(2a_1 + (4+\gamma)a_2) - (4+\gamma)\big)}{\gamma} &\approx 0 \\
\frac{b_1}{5+\gamma} + \frac{b_1\big((1+\gamma)a_1^2 + 4a_1a_2 + (4+\gamma)a_2^2\big) - \big((1+\gamma)a_1 + 2a_2\big)}{\gamma} &\approx 0 \\
\frac{b_2}{5+\gamma} + \frac{b_2\big((1+\gamma)a_1^2 + 4a_1a_2 + (4+\gamma)a_2^2\big) - \big(2a_1 + (4+\gamma)a_2\big)}{\gamma} &\approx 0
\end{aligned} \qquad (54)$$

I cannot solve the above equations in closed form, since eliminating $b_1, b_2$ leads to a fifth-order equation in $a_1, a_2$. But we can check that the optimal solution $\phi^\star = (a_1^\star, a_2^\star) = \big(\frac{1}{5+\gamma}, \frac{2}{5+\gamma}\big)$, $\theta^\star = (b_1^\star, b_2^\star) = (1, 2)$ satisfies the above stationarity conditions exactly. This shows that the variational auto-encoder has the power to find the true solution. Note that when $\tilde{x} \sim N\big([0, 0]^\top, [[1+\gamma, 2], [2, 4+\gamma]]\big)$, the mean $\mu_1(\tilde{x},\phi^\star) = \frac{\tilde{x}_1}{5+\gamma} + \frac{2\tilde{x}_2}{5+\gamma} \sim N\big(0, \frac{5}{5+\gamma}\big)$ is close to the true hidden variable distribution $z \sim N(0, 1)$ when the variance $\gamma$ of the noise tends to zero. Note that $z = \mu_1(\tilde{x},\phi^\star) + \varepsilon_z$, $\varepsilon_z \sim N\big(0, \frac{\gamma}{5+\gamma}\big)$, which means that the noise level in the $x$ space influences the hidden variable $z$ space, and the impact of the noise on the hidden variable is proportional to the noise level in the $x$ space. The mean $\mu_2(z,\theta^\star) = [z, 2z]^\top$ captures the true distribution of the data $x$. From the relation $\tilde{x} = \mu_2(z,\theta^\star) + \varepsilon = x + \varepsilon$, $\varepsilon \sim N\big([0, 0]^\top, [[\gamma, 0], [0, \gamma]]\big)$, we see that we can use the variational auto-encoder to denoise the data: we use the mean of the decoder distribution $p_{\theta^\star}(\tilde{x}|z)$ to model the distribution of the data $x$, i.e. $p_{\theta^\star}(x) = \int_z \mathbb{1}_{\{x = \mathbb{E}_{\tilde{x}\sim p_{\theta^\star}(\tilde{x}|z)}\tilde{x}\}}\,p_{\theta^\star}(z)\,dz$.
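To make the denoising interpretation concrete, here is a small NumPy illustration (the value of $\gamma$ is arbitrary and the snippet is not part of the original text): plugging the optimal parameters $a^\star = (\frac{1}{5+\gamma}, \frac{2}{5+\gamma})$ and $b^\star = (1, 2)$ into the encoder and decoder means shrinks the noisy measurements toward the clean signal $x = [z, 2z]$.

```python
import numpy as np

rng = np.random.default_rng(0)
gamma = 0.5                                    # measurement-noise variance
z = rng.normal(size=100_000)
x_clean = np.stack([z, 2 * z], axis=1)
x_noisy = x_clean + np.sqrt(gamma) * rng.normal(size=x_clean.shape)

m = (x_noisy[:, 0] + 2 * x_noisy[:, 1]) / (5 + gamma)   # mu_1(x_tilde, phi*)
x_denoised = np.stack([m, 2 * m], axis=1)                # mu_2(m, theta*)

print("MSE noisy    :", np.mean(np.sum((x_noisy - x_clean) ** 2, axis=1)))     # about 2*gamma
print("MSE denoised :", np.mean(np.sum((x_denoised - x_clean) ** 2, axis=1)))  # about 5*gamma/(5+gamma)
```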
Fig 2. Overview of scVI. Given a gene expression matrix with batch annotations as input, scVI learns a nonlinear embedding of the cells that can be used for multiple analysis tasks, together with the neural networks used to compute the embedding and the distribution of gene expression. NN, neural network. $f_w$ and $f_h$ are functional representations of NN5 and NN6, respectively.

The above arguments basically introduce the working principle of the variational auto-encoder. Now we arrive at the main goal: to introduce the mathematical model of scVI (Romain et al. (2018)). Let us first get a sense of the model with the graphical abstract in Figure 2.

We assume that the expression $x_{ng}$, where $n$ indexes the cell and $g$ indexes the gene, is generated by the following process, which characterizes the probability distribution of the expression:

$$\begin{aligned}
z_n &\sim \mathrm{Normal}(0, I) \\
l_n &\sim \mathrm{log\,normal}(l_\mu, l_\sigma^2) \\
\rho_n &= f_w(z_n, s_n) \\
w_{ng} &\sim \mathrm{Gamma}(\theta_g, \theta_g/\rho_{ng}) \\
y_{ng} &\sim \mathrm{Poisson}(l_n w_{ng}) \\
h_{ng} &\sim \mathrm{Bernoulli}\Big(\frac{1}{1 + \exp(-f_h^g(z_n, s_n))}\Big) \\
x_{ng} &= \begin{cases} y_{ng} & \text{if } h_{ng} = 0 \\ 0 & \text{otherwise} \end{cases}
\end{aligned} \qquad (55)$$

In the above generation process, $z_n \sim N(0, I)$ is a standard $d$-dimensional normal variable with pdf $\frac{1}{(2\pi)^{d/2}}\exp\big(-\frac{\|z_n\|^2}{2}\big)$. $l_n \sim \mathrm{log\,normal}(l_\mu, l_\sigma^2)$ is a log-normal random variable with pdf $\frac{1}{\sqrt{2\pi}\,l_\sigma l_n}\exp\big(-\frac{(\log l_n - l_\mu)^2}{2l_\sigma^2}\big)$, which has mean $\exp(l_\mu + l_\sigma^2/2)$ and variance $(e^{l_\sigma^2} - 1)e^{2l_\mu + l_\sigma^2}$. $w_{ng} \sim \mathrm{Gamma}(\theta_g, \theta_g/\rho_{ng})$ is a Gamma random variable with shape parameter $\theta_g$ and rate parameter $\theta_g/\rho_{ng}$; its pdf is $\frac{(\theta_g/\rho_{ng})^{\theta_g}}{\Gamma(\theta_g)}w_{ng}^{\theta_g - 1}\exp\big(-\frac{\theta_g}{\rho_{ng}}w_{ng}\big)\mathbb{1}_{w_{ng}\geq 0}$, which has mean $\rho_{ng}$ and variance $\rho_{ng}^2/\theta_g$, where the gamma function is defined by $\Gamma(\alpha) = \int_0^\infty u^{\alpha-1}e^{-u}\,du$. The Poisson variable $y_{ng} \sim \mathrm{Poisson}(l_n w_{ng})$ has the discrete distribution $P(y_{ng} = k) = \frac{(l_n w_{ng})^k}{k!}\exp(-l_n w_{ng})$, $k = 0, 1, \dots$, with equal mean and variance $l_n w_{ng}$. The Bernoulli random variable $h_{ng} \sim \mathrm{Bernoulli}\big(\frac{1}{1+\exp(-f_h^g(z_n,s_n))}\big)$ is a $\{0, 1\}$-valued discrete random variable with $P(h_{ng} = 0) = \frac{\exp(-f_h^g(z_n,s_n))}{1+\exp(-f_h^g(z_n,s_n))}$ and $P(h_{ng} = 1) = \frac{1}{1+\exp(-f_h^g(z_n,s_n))}$, which has mean $\frac{1}{1+\exp(-f_h^g(z_n,s_n))}$ and variance $\frac{\exp(-f_h^g(z_n,s_n))}{(1+\exp(-f_h^g(z_n,s_n)))^2}$.

We can get more concise distributions if we integrate out the intermediate variables: $y_{ng}$ follows a negative binomial distribution given the parameters $\rho_{ng}, \theta_g, l_n$, and $x_{ng}$ follows a zero-inflated negative binomial distribution given the parameters $\rho_{ng}, \theta_g, l_n, f_h^g(z_n, s_n)$. We now give the deduction of the negative binomial distribution of $y_{ng}$ given $\rho_{ng}, \theta_g, l_n$, i.e. we show that a Gamma-Poisson mixture leads to a negative binomial distribution:

$$\begin{aligned}
P(w_{ng}|\rho_{ng},\theta_g) &= \frac{(\theta_g/\rho_{ng})^{\theta_g}}{\Gamma(\theta_g)}w_{ng}^{\theta_g - 1}\exp\Big(-\frac{\theta_g}{\rho_{ng}}w_{ng}\Big)\mathbb{1}_{w_{ng}\geq 0} \\
P(y_{ng} = k|l_n, w_{ng}) &= \frac{(l_n w_{ng})^k}{k!}\exp(-l_n w_{ng}) \\
P(y_{ng} = k|l_n,\rho_{ng},\theta_g) &= \int_{w_{ng}}P(y_{ng} = k|l_n, w_{ng})P(w_{ng}|\rho_{ng},\theta_g)\,dw_{ng} \\
&= \int_{w_{ng}\geq 0}\frac{(l_n w_{ng})^k}{k!}\exp(-l_n w_{ng})\frac{(\theta_g/\rho_{ng})^{\theta_g}}{\Gamma(\theta_g)}w_{ng}^{\theta_g - 1}\exp\Big(-\frac{\theta_g}{\rho_{ng}}w_{ng}\Big)dw_{ng} \\
&= \frac{1}{k!\,\Gamma(\theta_g)}\Big(\frac{\theta_g}{\rho_{ng}}\Big)^{\theta_g}\frac{l_n^k}{\big(\frac{\theta_g}{\rho_{ng}} + l_n\big)^{k+\theta_g}}\int_0^\infty\Big[\Big(\frac{\theta_g}{\rho_{ng}} + l_n\Big)w_{ng}\Big]^{k+\theta_g-1}\exp\Big(-\Big(\frac{\theta_g}{\rho_{ng}} + l_n\Big)w_{ng}\Big)\,d\Big(\Big(\frac{\theta_g}{\rho_{ng}} + l_n\Big)w_{ng}\Big) \\
&= \frac{\Gamma(k+\theta_g)}{k!\,\Gamma(\theta_g)}\Big(\frac{\theta_g}{\rho_{ng}}\Big)^{\theta_g}\frac{l_n^k}{\big(\frac{\theta_g}{\rho_{ng}} + l_n\big)^{k+\theta_g}} \quad \sim NB\Big(\mathrm{size} = \theta_g,\ p = \frac{l_n\rho_{ng}}{\theta_g + l_n\rho_{ng}}\Big) \\
&= \frac{\Gamma(k+\theta_g)}{\Gamma(k+1)\Gamma(\theta_g)}\Big(\frac{\theta_g}{\theta_g + l_n\rho_{ng}}\Big)^{\theta_g}\Big(\frac{l_n\rho_{ng}}{\theta_g + l_n\rho_{ng}}\Big)^{k} \quad \sim NB\big(\mathrm{dispersion} = \theta_g,\ \mu = l_n\rho_{ng}\big)
\end{aligned} \qquad (56)$$

where $NB(n, p)$ is the negative binomial distribution with $\mathrm{size} = n$ and $\mathrm{prob} = p$, with density

$$\frac{\Gamma(x+n)}{\Gamma(n)\,x!}(1-p)^n p^x \qquad (57)$$

Here $p \in (0, 1]$ is the probability of success in the Bernoulli distribution, and $x = 0, 1, \dots$ represents the number of successes which occur in a sequence of Bernoulli trials (with success probability $p$) before a target number $n > 0$ of failures is reached, with the probability given in equation (57). The mean is $\mu = np/(1-p)$ and the variance is $np/(1-p)^2$. An alternative parametrization (often used in ecology) uses the mean $\mu$ and the size, the dispersion parameter $\theta$, where $p = \mu/(\theta + \mu)$. The variance is $\mu + \mu^2/\theta$ in this parametrization.

This gives us that $y_{ng}$ is negative binomial with $\mathrm{size} = \mathrm{dispersion} = \theta_g$, $p = \frac{l_n\rho_{ng}}{\theta_g + l_n\rho_{ng}}$, mean $\mu = l_n\rho_{ng}$ and variance $l_n\rho_{ng} + \frac{(l_n\rho_{ng})^2}{\theta_g}$.
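The Gamma-Poisson mixture identity (56) can be checked numerically. The following sketch uses arbitrary illustrative values of $\theta_g$, $\rho_{ng}$ and $l_n$ and compares the empirical distribution of the mixture with the negative binomial pmf from SciPy; note that scipy's nbinom parametrizes the pmf with the complementary probability $\theta_g/(\theta_g + l_n\rho_{ng})$.

```python
import numpy as np
from scipy.stats import nbinom

rng = np.random.default_rng(0)
theta, rho, l = 2.0, 3.0, 1.5                 # dispersion, expected frequency, library size
n_samples = 500_000

# w ~ Gamma(shape=theta, rate=theta/rho)  ->  scale = rho/theta (mean rho, var rho^2/theta)
w = rng.gamma(shape=theta, scale=rho / theta, size=n_samples)
y = rng.poisson(l * w)                        # Poisson(l * w) given w

ks = np.arange(6)
empirical = np.array([(y == k).mean() for k in ks])
exact = nbinom.pmf(ks, theta, theta / (theta + l * rho))
print(np.round(empirical, 4))
print(np.round(exact, 4))                     # the two rows should nearly coincide
```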
Now it is obvious that $x_{ng}$ obeys the zero-inflated negative binomial distribution, with probability mass function given by

$$\begin{aligned}
P(x_{ng} = 0|l_n,\rho_{ng},\theta_g,f_h^g) &= P(x_{ng} = 0|h_{ng} = 0,l_n,\rho_{ng},\theta_g,f_h^g)P(h_{ng} = 0|l_n,\rho_{ng},\theta_g,f_h^g) \\
&\quad + P(x_{ng} = 0|h_{ng} = 1,l_n,\rho_{ng},\theta_g,f_h^g)P(h_{ng} = 1|l_n,\rho_{ng},\theta_g,f_h^g) \\
&= P(y_{ng} = 0|l_n,\rho_{ng},\theta_g)P(h_{ng} = 0|f_h^g) + P(h_{ng} = 1|f_h^g) \\
&= \Big(\frac{\theta_g}{\theta_g + l_n\rho_{ng}}\Big)^{\theta_g}\frac{\exp(-f_h^g(z_n,s_n))}{1+\exp(-f_h^g(z_n,s_n))} + \frac{1}{1+\exp(-f_h^g(z_n,s_n))} \\
&= \exp\Big(S\Big(-f_h^g + \theta_g\log\frac{\theta_g}{\theta_g + l_n\rho_{ng}}\Big) - S(-f_h^g)\Big) \\
P(x_{ng} = k|l_n,\rho_{ng},\theta_g,f_h^g,\ k > 0) &= P(x_{ng} = k|h_{ng} = 0,l_n,\rho_{ng},\theta_g,f_h^g)P(h_{ng} = 0|l_n,\rho_{ng},\theta_g,f_h^g) \\
&= P(y_{ng} = k|l_n,\rho_{ng},\theta_g)P(h_{ng} = 0|f_h^g) \\
&= \frac{\Gamma(k+\theta_g)}{\Gamma(k+1)\Gamma(\theta_g)}\Big(\frac{\theta_g}{\theta_g + l_n\rho_{ng}}\Big)^{\theta_g}\Big(\frac{l_n\rho_{ng}}{\theta_g + l_n\rho_{ng}}\Big)^{k}\frac{\exp(-f_h^g(z_n,s_n))}{1+\exp(-f_h^g(z_n,s_n))} \\
&= \exp\Big(-f_h^g - S(-f_h^g) + \theta_g\log\frac{\theta_g}{\theta_g + l_n\rho_{ng}} + k\log\frac{l_n\rho_{ng}}{\theta_g + l_n\rho_{ng}} + \log\frac{\Gamma(k+\theta_g)}{\Gamma(k+1)\Gamma(\theta_g)}\Big)
\end{aligned} \qquad (58)$$

where $S(u) := \log(1 + e^u)$ is the softplus function. The above probability mass function can be represented in the compact form

$$\begin{aligned}
P(x_{ng}|z_n,s_n,l_n) &= \mathbb{1}_{x_{ng} = 0}\exp\Big(S\Big(-f_h^g + \theta_g\log\frac{\theta_g}{\theta_g + l_n\rho_{ng}}\Big) - S(-f_h^g)\Big) \\
&\quad + \mathbb{1}_{x_{ng} > 0}\exp\Big(-f_h^g - S(-f_h^g) + \theta_g\log\frac{\theta_g}{\theta_g + l_n\rho_{ng}} + x_{ng}\log\frac{l_n\rho_{ng}}{\theta_g + l_n\rho_{ng}} + \log\frac{\Gamma(x_{ng}+\theta_g)}{\Gamma(x_{ng}+1)\Gamma(\theta_g)}\Big)
\end{aligned} \qquad (59)$$

To get a sense of the neural network function $f_w(z_n, s_n)$, where $z_n \in \mathbb{R}^d$ and $s_n \in \mathbb{R}^B$ is the one-hot representation of the batch ID of cell $n$, we plot the cartoon in Figure 3. This network contains one hidden layer with $d_1$ neurons, $h_i = \mathrm{ReLU}\big(\sum_{j=1}^{d}W^{(z)}_{i,j}z_{n,j} + \sum_{j=1}^{B}W^{(s)}_{i,j}s_{n,j} + b_i\big)$, $i = 1, \dots, d_1$, where $\mathrm{ReLU}$ is the element-wise function $\mathrm{ReLU}(u) = \max(0, u)$, and $W^{(z)} \in \mathbb{R}^{d_1\times d}$, $W^{(s)} \in \mathbb{R}^{d_1\times B}$, $b \in \mathbb{R}^{d_1}$ are the weights of the hidden layer. The output layer is built on the hidden layer with a linear mapping followed by a Softmax function, i.e. $o_g = \sum_{j=1}^{d_1}W^{(h)}_{g,j}h_j + b^{(h)}_g$, $g = 1, \dots, G$, and $\rho_{ng} = \frac{e^{o_g}}{\sum_j e^{o_j}}$, $g = 1, \dots, G$, where $W^{(h)} \in \mathbb{R}^{G\times d_1}$, $b^{(h)} \in \mathbb{R}^G$ are the weights of the output layer. Finally, $f_w(z_n, s_n) = \rho_n$ and $w = \{W^{(z)}, W^{(s)}, b, W^{(h)}, b^{(h)}\}$.

Fig 3. The one-layer MLP representation of $f_w(z_n, s_n)$, with hidden neurons $h_i = \mathrm{ReLU}(\sum_{j=1}^{d}W^{(z)}_{i,j}z_{n,j} + \sum_{j=1}^{B}W^{(s)}_{i,j}s_{n,j} + b_i)$, $i = 1, \dots, d_1$, output neurons $o_g = \sum_{j=1}^{d_1}W^{(h)}_{g,j}h_j + b^{(h)}_g$, $g = 1, \dots, G$, final expected frequencies $\rho_{ng} = e^{o_g}/\sum_j e^{o_j}$, and $f_w(z_n, s_n) = \rho_n$.

This simple network captures the basic structure of neural networks, i.e. a linear transform followed by a nonlinear mapping, with this basic building block cascaded. In practice, the neural network can add more hidden layers to model the complex dependence between the input and output variables. For efficiency and algorithmic stability, we can add batch normalization (Ioffe and Szegedy (2015)) after the linear transform and before the nonlinear mapping, so that the gradients flow through a numerically stable path. We can also add Dropout layers (Srivastava et al. (2014)), which randomly drop connections between neurons with some fixed probability, to avoid overfitting. And we can add an l2 norm of the weights to the variational lower bound to stabilize the algorithm, which favors weights near the origin.

As we have prepared the basic information needed to understand scVI, we now view the whole picture of scVI, as shown in Figure 2. It is a conditional variational auto-encoder, which adds a conditioning variable, the batch ID, to the generative model.
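For readers who want to see the compact form (59) in code, here is a hedged PyTorch sketch of the zero-inflated negative binomial log-probability; the function name, the eps constant and the example numbers are illustrative choices, not part of scVI's published interface.

```python
import torch
from torch.nn.functional import softplus

def log_zinb(x, mu, theta, pi_logits, eps=1e-8):
    """Log-probability of the zero-inflated negative binomial of equation (59).

    x          observed counts x_ng
    mu         negative binomial mean l_n * rho_ng
    theta      gene-wise inverse dispersion theta_g (broadcast against x)
    pi_logits  dropout logits f_h^g(z_n, s_n); dropout probability is sigmoid(pi_logits)
    """
    log_theta_mu = theta * (torch.log(theta + eps) - torch.log(theta + mu + eps))
    # Case x = 0: S(-f + theta*log(theta/(theta+mu))) - S(-f)
    case_zero = softplus(-pi_logits + log_theta_mu) - softplus(-pi_logits)
    # Case x > 0: -f - S(-f) + negative binomial log-pmf
    case_nonzero = (
        -pi_logits
        - softplus(-pi_logits)
        + log_theta_mu
        + x * (torch.log(mu + eps) - torch.log(theta + mu + eps))
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1.0)
    )
    return torch.where(x < 0.5, case_zero, case_nonzero)

# Example: one cell, three genes (illustrative numbers).
x = torch.tensor([0.0, 4.0, 1.0])
mu = torch.tensor([0.8, 3.0, 1.2])
theta = torch.tensor([2.0, 2.0, 2.0])
pi_logits = torch.tensor([0.5, -1.0, 0.0])
print(log_zinb(x, mu, theta, pi_logits))
```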
scVI follows the same logic as the variational auto-encoder with a minimal change: we only need to condition all the probabilities on the batch ID variable $s_n$. So, following almost the same logic as in the deduction of the loss function of the variational auto-encoder, we now deduce the variational lower bound of the scVI model. Firstly, the log-likelihood conditioned on the batch ID is given by $\log p_\psi(x_n|s_n)$ for cell $n$ (here and after we use $\psi$ as the model parameter of the generative probability decoder, instead of the $\theta$ used above, since $\theta$ is now used as the gene dispersion parameter). Then we have the variational lower bound of the log-likelihood conditioned on the batch ID:

$$\mathcal{L}(\psi,\phi,x_n,s_n) = -D_{KL}\big(q_\phi(z_n,l_n|x_n,s_n)\,\|\,p_\psi(z_n,l_n|s_n)\big) + \mathbb{E}_{q_\phi(z_n,l_n|x_n,s_n)}\log p_\psi(x_n|z_n,s_n,l_n) \qquad (60)$$

Here $q_\phi(z_n,l_n|x_n,s_n)$ is a probability encoder which encodes the gene expression $x_n$ into the low-dimensional hidden variable $z_n$ and the surrogate $l_n$ of the library size, conditioned on $s_n$. The parameters $\phi$ are the collection of the weights of NN1, NN2, NN3, NN4 in Figure 2. The encoder consists of two sub-encoders: $q_\phi(z_n,l_n|x_n,s_n) = q_\phi(z_n|x_n,s_n)\,q_\phi(l_n|x_n,s_n)$. The variational distribution $q_\phi(z_n|x_n,s_n)$ is chosen to be Gaussian with a diagonal covariance, with mean given by an encoder network NN1 applied to $(x_n, s_n)$ and diagonal standard deviation (the square root of the diagonal variance) given by the encoder network NN2 applied to $(x_n, s_n)$. The variational distribution $q_\phi(l_n|x_n,s_n)$ is chosen to be log-normal with scalar mean and variance, with mean and standard deviation (the square root of the variance) given by the neural networks NN3 and NN4, respectively. We apply the reparameterization trick to the variational distributions:

$$\begin{aligned}
z_n &\sim q_\phi(z_n|x_n,s_n) = N\big(f_{NN1,\phi}(x_n,s_n),\ \mathrm{diag}\{(f_{NN2,\phi}(x_n,s_n))^2\}\big) \\
z_n &= f_{NN1,\phi}(x_n,s_n) + f_{NN2,\phi}(x_n,s_n)\odot\varepsilon_z, \qquad \varepsilon_z \sim N(0, I_d) \\
l_n &\sim q_\phi(l_n|x_n,s_n) = \mathrm{log\,normal}\big(f_{NN3,\phi}(x_n,s_n),\ (f_{NN4,\phi}(x_n,s_n))^2\big) \\
l_n &= \exp\big(f_{NN3,\phi}(x_n,s_n) + f_{NN4,\phi}(x_n,s_n)\,\varepsilon_l\big), \qquad \varepsilon_l \sim N(0, 1)
\end{aligned} \qquad (61)$$

The prior $p_\psi(z_n,l_n|s_n) = p(z_n|s_n)\,p(l_n|s_n)$ is chosen as a fixed distribution, where $p(z_n|s_n) \sim N(0, I_d)$ and $p(l_n|s_n) \sim \mathrm{log\,normal}(l_{\mu,b_n}, l_{\sigma,b_n}^2)$, where $b_n$ is the batch ID of cell $n$ and

$$l_{\mu,b_n} := \frac{\sum_{i\in\text{batch }b_n}\log\big(\sum_{g=1}^{G}x_{i,g}\big)}{\sum_{i\in\text{batch }b_n}1}, \qquad l_{\sigma,b_n}^2 := \frac{\sum_{i\in\text{batch }b_n}\big(\log\big(\sum_{g=1}^{G}x_{i,g}\big) - l_{\mu,b_n}\big)^2}{\sum_{i\in\text{batch }b_n}1}, \qquad (62)$$

i.e. the sample mean and variance of the log library size of the cells in batch $b_n$.

To get an analytical expression for the KL divergence, we first calculate the simple example $D_{KL}\big(N(m_1,\sigma_1^2)\,\|\,N(m_2,\sigma_2^2)\big)$.
$$\begin{aligned}
D_{KL}\big(N(m_1,\sigma_1^2)\,\|\,N(m_2,\sigma_2^2)\big) &= \int_{-\infty}^{\infty}\frac{1}{\sqrt{2\pi}\sigma_1}\exp\Big(-\frac{(u-m_1)^2}{2\sigma_1^2}\Big)\log\frac{\frac{1}{\sqrt{2\pi}\sigma_1}\exp\big(-\frac{(u-m_1)^2}{2\sigma_1^2}\big)}{\frac{1}{\sqrt{2\pi}\sigma_2}\exp\big(-\frac{(u-m_2)^2}{2\sigma_2^2}\big)}\,du \\
&= \int_{-\infty}^{\infty}\frac{1}{\sqrt{2\pi}\sigma_1}\exp\Big(-\frac{(u-m_1)^2}{2\sigma_1^2}\Big)\Big[\log\frac{\sigma_2}{\sigma_1} - \frac{(u-m_1)^2}{2\sigma_1^2} + \frac{(u-m_2)^2}{2\sigma_2^2}\Big]\,du \\
&= \log\frac{\sigma_2}{\sigma_1} - \frac{1}{2} + \frac{\sigma_1^2 + (m_1 - m_2)^2}{2\sigma_2^2}
\end{aligned} \qquad (63)$$

So we have

$$\begin{aligned}
D_{KL}\big(q_\phi(z_n|x_n,s_n)\,\|\,p(z_n)\big) &= D_{KL}\big(N(f_{NN1,\phi}(x_n,s_n),\mathrm{diag}\{(f_{NN2,\phi}(x_n,s_n))^2\})\,\|\,N(0, I_d)\big) \\
&= \sum_{i=1}^{d}\Big[-\log[f_{NN2,\phi}(x_n,s_n)]_i - \frac{1}{2} + \frac{[f_{NN2,\phi}(x_n,s_n)]_i^2 + [f_{NN1,\phi}(x_n,s_n)]_i^2}{2}\Big]
\end{aligned} \qquad (64)$$

If $Y_1 \sim \mathrm{log\,normal}(m_1,\sigma_1^2)$ with pdf $p_1(y)$ and $Y_2 \sim \mathrm{log\,normal}(m_2,\sigma_2^2)$ with pdf $p_2(y)$, then $\log Y_1 \sim N(m_1,\sigma_1^2)$ with pdf $p_1(e^u)e^u$ and $\log Y_2 \sim N(m_2,\sigma_2^2)$ with pdf $p_2(e^u)e^u$, so

$$D_{KL}\big(\mathrm{log\,normal}(m_1,\sigma_1^2)\,\|\,\mathrm{log\,normal}(m_2,\sigma_2^2)\big) = \int_0^\infty p_1(y)\log\frac{p_1(y)}{p_2(y)}\,dy = \int_{-\infty}^{\infty}p_1(e^u)e^u\log\frac{p_1(e^u)e^u}{p_2(e^u)e^u}\,du = D_{KL}\big(N(m_1,\sigma_1^2)\,\|\,N(m_2,\sigma_2^2)\big) = \log\frac{\sigma_2}{\sigma_1} - \frac{1}{2} + \frac{\sigma_1^2 + (m_1-m_2)^2}{2\sigma_2^2} \qquad (65)$$

So we have

$$D_{KL}\big(q_\phi(l_n|x_n,s_n)\,\|\,p(l_n)\big) = D_{KL}\big(\mathrm{log\,normal}(f_{NN3,\phi}(x_n,s_n),(f_{NN4,\phi}(x_n,s_n))^2)\,\|\,\mathrm{log\,normal}(l_{\mu,b_n},l_{\sigma,b_n}^2)\big) = \log\frac{l_{\sigma,b_n}}{f_{NN4,\phi}(x_n,s_n)} - \frac{1}{2} + \frac{(f_{NN4,\phi}(x_n,s_n))^2 + (f_{NN3,\phi}(x_n,s_n) - l_{\mu,b_n})^2}{2l_{\sigma,b_n}^2} \qquad (66)$$

Combining equations (64) and (66), and noting that $z_n$ and $l_n$ are independent under both $q_\phi$ and the prior, we get

$$\begin{aligned}
D_{KL}\big(q_\phi(z_n,l_n|x_n,s_n)\,\|\,p_\psi(z_n,l_n|s_n)\big) &= D_{KL}\big(q_\phi(z_n|x_n,s_n)q_\phi(l_n|x_n,s_n)\,\|\,p(z_n)p(l_n)\big) \\
&= D_{KL}\big(q_\phi(z_n|x_n,s_n)\,\|\,p(z_n)\big) + D_{KL}\big(q_\phi(l_n|x_n,s_n)\,\|\,p(l_n)\big) \\
&= \sum_{i=1}^{d}\Big[-\log[f_{NN2,\phi}(x_n,s_n)]_i - \frac{1}{2} + \frac{[f_{NN2,\phi}(x_n,s_n)]_i^2 + [f_{NN1,\phi}(x_n,s_n)]_i^2}{2}\Big] \\
&\quad + \log\frac{l_{\sigma,b_n}}{f_{NN4,\phi}(x_n,s_n)} - \frac{1}{2} + \frac{(f_{NN4,\phi}(x_n,s_n))^2 + (f_{NN3,\phi}(x_n,s_n) - l_{\mu,b_n})^2}{2l_{\sigma,b_n}^2}
\end{aligned} \qquad (67)$$
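The closed-form KL terms (64) and (66) translate directly into code. A minimal PyTorch sketch follows; the function names and the example values are illustrative assumptions.

```python
import torch

def kl_normal_standard(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions, as in (64)."""
    return torch.sum(-torch.log(sigma) - 0.5 + 0.5 * (sigma ** 2 + mu ** 2), dim=-1)

def kl_lognormal(mu_q, sigma_q, mu_p, sigma_p):
    """KL between two log-normal laws, equal to the KL of the underlying normals, as in (65)-(66)."""
    return (torch.log(sigma_p / sigma_q) - 0.5
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2))

# Example values for a 10-dimensional latent code and a scalar log-library posterior.
mu_z, sigma_z = torch.zeros(10), 0.5 * torch.ones(10)
print(kl_normal_standard(mu_z, sigma_z))
print(kl_lognormal(torch.tensor(8.0), torch.tensor(0.3), torch.tensor(8.2), torch.tensor(0.5)))
```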
We now only need the reconstruction error term to obtain the final computable objective function. We use the reparameterization trick on the reconstruction error term and use a sample average to estimate the expectation:

$$\begin{aligned}
\mathbb{E}_{q_\phi(z_n,l_n|x_n,s_n)}\log p_\psi(x_n|z_n,s_n,l_n) &= \mathbb{E}_{\varepsilon_z\sim N(0,I_d),\,\varepsilon_l\sim N(0,1)}\log p_\psi\big(x_n\,\big|\,f_{NN1,\phi}(x_n,s_n) + f_{NN2,\phi}(x_n,s_n)\odot\varepsilon_z,\ s_n,\ \exp(f_{NN3,\phi}(x_n,s_n) + f_{NN4,\phi}(x_n,s_n)\varepsilon_l)\big) \\
&\approx \frac{1}{K}\sum_{k=1}^{K}\log p_\psi\big(x_n\,\big|\,z_n^{(k)},\ s_n,\ l_n^{(k)}\big), \\
\text{where } z_n^{(k)} &= f_{NN1,\phi}(x_n,s_n) + f_{NN2,\phi}(x_n,s_n)\odot\varepsilon_z^{(k)}, \quad l_n^{(k)} = \exp\big(f_{NN3,\phi}(x_n,s_n) + f_{NN4,\phi}(x_n,s_n)\varepsilon_l^{(k)}\big), \quad \varepsilon_z^{(k)}\sim N(0,I_d),\ \varepsilon_l^{(k)}\sim N(0,1).
\end{aligned} \qquad (68)$$

Substituting equations (67) and (68) into equation (60), we get the computable variational lower bound of the log-likelihood of cell $n$:

$$\begin{aligned}
\mathcal{L}(\psi,\phi,x_n,s_n) \approx \tilde{\mathcal{L}}(\psi,\phi,x_n,s_n) &:= -\sum_{i=1}^{d}\Big[-\log[f_{NN2,\phi}(x_n,s_n)]_i - \frac{1}{2} + \frac{[f_{NN2,\phi}(x_n,s_n)]_i^2 + [f_{NN1,\phi}(x_n,s_n)]_i^2}{2}\Big] \\
&\quad - \Big[\log\frac{l_{\sigma,b_n}}{f_{NN4,\phi}(x_n,s_n)} - \frac{1}{2} + \frac{(f_{NN4,\phi}(x_n,s_n))^2 + (f_{NN3,\phi}(x_n,s_n) - l_{\mu,b_n})^2}{2l_{\sigma,b_n}^2}\Big] \\
&\quad + \frac{1}{K}\sum_{k=1}^{K}\sum_{g=1}^{G}\Big\{\mathbb{1}_{x_{ng}=0}\Big[S\Big(-f_h^g(z_n^{(k)},s_n) + \theta_g\log\frac{\theta_g}{\theta_g + l_n^{(k)}f_w^g(z_n^{(k)},s_n)}\Big) - S\big(-f_h^g(z_n^{(k)},s_n)\big)\Big] \\
&\qquad + \mathbb{1}_{x_{ng}>0}\Big[-f_h^g(z_n^{(k)},s_n) - S\big(-f_h^g(z_n^{(k)},s_n)\big) + \theta_g\log\frac{\theta_g}{\theta_g + l_n^{(k)}f_w^g(z_n^{(k)},s_n)} \\
&\qquad\qquad + x_{ng}\log\frac{l_n^{(k)}f_w^g(z_n^{(k)},s_n)}{\theta_g + l_n^{(k)}f_w^g(z_n^{(k)},s_n)} + \log\frac{\Gamma(x_{ng}+\theta_g)}{\Gamma(x_{ng}+1)\Gamma(\theta_g)}\Big]\Big\}, \\
\text{where } z_n^{(k)} &= f_{NN1,\phi}(x_n,s_n) + f_{NN2,\phi}(x_n,s_n)\odot\varepsilon_z^{(k)}, \quad l_n^{(k)} = \exp\big(f_{NN3,\phi}(x_n,s_n) + f_{NN4,\phi}(x_n,s_n)\varepsilon_l^{(k)}\big), \quad \varepsilon_z^{(k)}\sim N(0,I_d),\ \varepsilon_l^{(k)}\sim N(0,1),
\end{aligned} \qquad (69)$$

where $\psi$ is the collection of the parameters of the neural networks NN5, NN6 and $\theta$, and $\phi$ is the collection of the parameters of the neural networks NN1, NN2, NN3, NN4. Note that the gene dispersion parameter $\theta \in \mathbb{R}^G_+$ is constant for each gene $g$ in Figure 2; there are several variants of this choice of $\theta$:

1. Gene dispersions are constant within each batch and each gene, $\theta \in \mathbb{R}^{G\times B}_+$, where $B$ is the number of batches.
2. If the gene expression data has been annotated, gene dispersions can be chosen to be constant within each class and each gene, $\theta \in \mathbb{R}^{G\times C}_+$, where $C$ is the number of classes.
3. Gene dispersions are chosen to be specific for each cell and each gene, $\theta \in \mathbb{R}^{G\times N}_+$, where $N$ is the number of cells in the data. In this case, $\theta$ can be modeled by a neural network $\theta = f_{w_\theta}(x_n, s_n)$, where $w_\theta$ is the parameters of the neural network.

Now we can use a stochastic optimization method to maximize the objective function

$$\max_{\psi,\phi}\ \frac{1}{N}\sum_{n}\tilde{\mathcal{L}}(\psi,\phi,x_n,s_n), \qquad (70)$$

where $\tilde{\mathcal{L}}(\psi,\phi,x_n,s_n)$ is defined in equation (69).

Now I have finished introducing the mathematical model of scVI (Romain et al. (2018)); you can find more details about the numerical experiments in Romain et al. (2018). Since the amount of accumulated sequencing data has become huge, neural-network-based models (e.g. the variational auto-encoder) give us the power to handle such a huge amount of data. However, the neural network is a black box, which means it is hard to gain insights from its millions of parameters. This field still needs a large amount of exploration. Thank you for reading, and have a nice day!
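Putting the pieces together, the following condensed PyTorch sketch runs one stochastic optimization step of the objective (69)-(70) on a fake mini-batch. All architectural details here (layer sizes, a single hidden layer per network, K = 1 Monte-Carlo sample, a shared log-library prior instead of per-batch statistics, log-variance parametrization of the encoders) are simplifications made for illustration; this is not the reference implementation of Romain et al. (2018).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

G, B, d, H = 2000, 2, 10, 128      # genes, batches, latent dimension, hidden width

class Encoder(nn.Module):          # plays the role of NN1-NN4: q(z|x,s) and q(l|x,s)
    def __init__(self):
        super().__init__()
        self.h = nn.Sequential(nn.Linear(G + B, H), nn.ReLU())
        self.z_mu, self.z_logvar = nn.Linear(H, d), nn.Linear(H, d)
        self.l_mu, self.l_logvar = nn.Linear(H, 1), nn.Linear(H, 1)

    def forward(self, x, s):
        h = self.h(torch.cat([torch.log1p(x), s], dim=-1))   # log1p only stabilizes raw counts
        return self.z_mu(h), self.z_logvar(h), self.l_mu(h), self.l_logvar(h)

class Decoder(nn.Module):          # plays the role of NN5 (f_w) and NN6 (f_h)
    def __init__(self):
        super().__init__()
        self.h = nn.Sequential(nn.Linear(d + B, H), nn.ReLU())
        self.rho = nn.Sequential(nn.Linear(H, G), nn.Softmax(dim=-1))
        self.dropout_logits = nn.Linear(H, G)

    def forward(self, z, s):
        h = self.h(torch.cat([z, s], dim=-1))
        return self.rho(h), self.dropout_logits(h)

enc, dec = Encoder(), Decoder()
log_theta = nn.Parameter(torch.zeros(G))                      # gene-wise dispersion theta_g
opt = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()) + [log_theta], lr=1e-3)

def zinb_loglik(x, mu, theta, f):                              # equation (59), summed over genes
    lt = theta * (torch.log(theta) - torch.log(theta + mu))
    zero = F.softplus(-f + lt) - F.softplus(-f)
    nonzero = (-f - F.softplus(-f) + lt + x * (torch.log(mu + 1e-8) - torch.log(theta + mu))
               + torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1.0))
    return torch.where(x < 0.5, zero, nonzero).sum(-1)

# One training step on a fake mini-batch (random counts and batch labels).
x = torch.poisson(torch.rand(64, G) * 3)
s = F.one_hot(torch.randint(0, B, (64,)), B).float()
l_mu_prior, l_var_prior = torch.tensor(6.0), torch.tensor(1.0)  # per-batch statistics in practice

z_mu, z_logvar, l_mu, l_logvar = enc(x, s)
z = z_mu + torch.exp(0.5 * z_logvar) * torch.randn_like(z_mu)   # reparameterized z (61)
log_l = l_mu + torch.exp(0.5 * l_logvar) * torch.randn_like(l_mu)
rho, f = dec(z, s)
mu = torch.exp(log_l) * rho                                     # NB mean l_n * rho_ng

kl_z = 0.5 * torch.sum(z_mu ** 2 + z_logvar.exp() - z_logvar - 1.0, dim=-1)          # (64)
kl_l = (0.5 * torch.log(l_var_prior) - 0.5 * l_logvar
        + (l_logvar.exp() + (l_mu - l_mu_prior) ** 2) / (2 * l_var_prior) - 0.5).squeeze(-1)  # (66)
loss = (kl_z + kl_l - zinb_loglik(x, mu, log_theta.exp(), f)).mean()                  # -(69)

opt.zero_grad(); loss.backward(); opt.step()
print("negative ELBO:", loss.item())
```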
References
Bengio, Y. (2009). Learning Deep Architectures for AI.

Doersch, C. (2016). Tutorial on Variational Autoencoders. arXiv e-prints, arXiv:1606.05908.

Goodfellow, I., Bengio, Y., and Courville, A. (2016). Deep Learning. The MIT Press.

Ioffe, S. and Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. arXiv e-prints, arXiv:1502.03167.

Kingma, D. P. and Ba, J. (2014). Adam: A Method for Stochastic Optimization. arXiv e-prints, arXiv:1412.6980.

Kingma, D. P. and Welling, M. (2014). Auto-encoding variational bayes.

Lecun, Y., Bengio, Y., and Hinton, G. E. (2015). Deep learning. Nature, 521(7553):436–444.

Romain Lopez, Jeffrey Regier, Michael B. Cole, Michael I. Jordan (2018). Deep generative modeling for single-cell transcriptomics. Nature Methods.

Senior, A. W., Evans, R., Jumper, J., Kirkpatrick, J., Sifre, L., Green, T., Qin, C., Žídek, A., Nelson, A. W. R., Bridgland, A., Penedones, H., Petersen, S., Simonyan, K., Crossan, S., Kohli, P., Jones, D. T., Silver, D., Kavukcuoglu, K., and Hassabis, D. (2020). Improved protein structure prediction using potentials from deep learning. Nature, 577(7792):706–710.

Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. (2014). Dropout: A simple way to prevent neural networks from overfitting.