Variational Inference with Tail-adaptive f-Divergence
Dilin Wang, UT Austin, [email protected]
Hao Liu*, UESTC, [email protected]
Qiang Liu, UT Austin, [email protected]

*Work done at UT Austin.
Abstract
Variational inference with α-divergences has been widely used in modern probabilistic machine learning. Compared to the Kullback-Leibler (KL) divergence, a major advantage of using α-divergences (with positive α values) is their mass-covering property. However, estimating and optimizing α-divergences requires importance sampling, which may have large or infinite variance due to heavy tails of the importance weights. In this paper, we propose a new class of tail-adaptive f-divergences that adaptively change the convex function f with the tail distribution of the importance weights, in a way that theoretically guarantees finite moments while simultaneously achieving mass-covering properties. We test our method on Bayesian neural networks, and apply it to improve a recent soft actor-critic (SAC) algorithm (Haarnoja et al., 2018) in deep reinforcement learning. Our results show that our approach yields significant advantages compared with existing methods based on classical KL and α-divergences.

1 Introduction

Variational inference (VI) (e.g., Jordan et al., 1999; Wainwright et al., 2008) has been established as a powerful tool in modern probabilistic machine learning for approximating intractable posterior distributions. The basic idea is to turn the approximation problem into an optimization problem, which finds the best approximation of an intractable distribution from a family of tractable distributions by minimizing a divergence objective function. Compared with Markov chain Monte Carlo (MCMC), which is known to be consistent but suffers from slow convergence, VI provides biased results but is often practically faster. Combined with techniques like stochastic optimization (Ranganath et al., 2014; Hoffman et al., 2013) and the reparameterization trick (Kingma & Welling, 2014), VI has become a major technical approach for advancing Bayesian deep learning, deep generative models and deep reinforcement learning (e.g., Kingma & Welling, 2014; Gal & Ghahramani, 2016; Levine, 2018).

A key component of successful variational inference lies in choosing a proper divergence metric. Typically, closeness is defined by the KL divergence
KL(q||p) (e.g., Jordan et al., 1999), where p is the intractable distribution of interest and q is a simpler distribution constructed to approximate p. However, VI with KL divergence often under-estimates the variance and may miss important local modes of the true posterior (e.g., Christopher, 2016; Blei et al., 2017). To mitigate this issue, alternative metrics have been studied in the literature, a large portion of which are special cases of the f-divergence (e.g., Csiszár & Shields, 2004):

$$D_f(p~\|~q) = \mathbb{E}_{x\sim q}\left[f\!\left(\frac{p(x)}{q(x)}\right) - f(1)\right], \qquad (1)$$

where f: ℝ⁺ → ℝ is any convex function. The most notable class of f-divergence that has been exploited in VI is the α-divergence, which takes f(t) = t^α/(α(α−1)) for α ∈ ℝ \ {0, 1}. By choosing different α, we get a large number of well-known divergences as special cases, including the standard
KL divergence objective KL(q||p) (α → 0), the KL divergence in the reverse direction KL(p||q) (α → 1), and the χ²-divergence (α = 2). In particular, the use of general α-divergence in VI has been widely discussed (e.g., Minka et al., 2005; Hernández-Lobato et al., 2016; Li & Turner, 2016); the reverse KL divergence is used in expectation propagation (Minka, 2001; Opper & Winther, 2005), importance weighted auto-encoders (Burda et al., 2016), and the cross entropy method (De Boer et al., 2005); the χ²-divergence is exploited for VI (e.g., Dieng et al., 2017), but is more extensively studied in the context of adaptive importance sampling (IS) (e.g., Cappé et al., 2008; Ryu & Boyd, 2014; Cotter et al., 2015), since it coincides with the variance of the IS estimator with q as the proposal.

A major motivation for using α-divergence is its mass-covering property: when α > 0, the optimal approximation q tends to cover more modes of p, and hence better accounts for the uncertainty in p. Typically, larger values of α enforce stronger mass-covering properties. In practice, however, the α-divergence and its gradient need to be estimated empirically using samples from q. Using large α values may cause high or infinite variance in the estimation, because it involves estimating the α-th power of the density ratio p(x)/q(x), which is likely distributed with a heavy or fat tail (e.g., Resnick, 2007). In fact, when q is very different from p, the expectation of the ratio (p(x)/q(x))^α can be infinite (that is, the α-divergence does not exist). This makes it problematic to use large α values, despite the mass-covering property they promise. In addition, it is reasonable to expect that the optimal setting of α should vary across training processes and learning tasks. It is therefore desirable to design an approach that chooses α adaptively and automatically as q changes during the training iterations, according to the distribution of the ratio p(x)/q(x).

Based on theoretical observations on f-divergences and fat-tailed distributions, we design a new class of f-divergences which is tail-adaptive in that it uses different f functions according to the tail distribution of the density ratio p(x)/q(x), so as to simultaneously obtain stable empirical estimation and the strongest possible mass-covering property. This allows us to derive a new adaptive f-divergence-based variational inference algorithm by combining it with stochastic optimization and reparameterization gradient estimates. Our main method (Algorithm 1) has a simple form, which replaces the f function in (1) with a rank-based function of the empirical density ratio w = p(x)/q(x) at each gradient descent step of q, whose variation depends on the distribution of w and does not explode regardless of the tail of w.

Empirically, we show that our method can better recover multiple modes in variational inference. In addition, we apply our method to improve a recent soft actor-critic (SAC) algorithm (Haarnoja et al., 2018) in reinforcement learning (RL), showing that our method can be used to optimize multi-modal loss functions in RL more efficiently.

2 f-Divergence and Friends

Given a distribution p(x) of interest, we want to approximate it with a simpler distribution from a family {q_θ(x) : θ ∈ Θ}, where θ is the variational parameter that we want to optimize.
We approach this problem by minimizing the f-divergence between q_θ and p:

$$\min_{\theta\in\Theta}\left\{ D_f(p~\|~q_\theta) = \mathbb{E}_{x\sim q_\theta}\left[f\!\left(\frac{p(x)}{q_\theta(x)}\right) - f(1)\right]\right\}, \qquad (2)$$

where f: ℝ⁺ → ℝ is any twice differentiable convex function. It can be shown by Jensen's inequality that D_f(p||q) ≥ 0 for any p and q. Further, if f(t) is strictly convex at t = 1, then D_f(p||q) = 0 implies p = q. The optimization in (2) can be solved approximately in practice using stochastic optimization, by approximating the expectation E_{x∼q_θ}[·] with samples drawn from q_θ at each iteration.

The f-divergence includes a large spectrum of important divergence measures. It includes the KL divergence in both directions,

$$\mathrm{KL}(q~\|~p) = \mathbb{E}_{x\sim q}\left[\log\frac{q(x)}{p(x)}\right], \qquad \mathrm{KL}(p~\|~q) = \mathbb{E}_{x\sim q}\left[\frac{p(x)}{q(x)}\log\frac{p(x)}{q(x)}\right], \qquad (3)$$

which correspond to f(t) = −log t and f(t) = t log t, respectively. KL(q||p) is the typical objective function used in variational inference; the reverse direction KL(p||q) is also used in various settings (e.g., Minka, 2001; Opper & Winther, 2005; De Boer et al., 2005; Burda et al., 2016).

More generally, the f-divergence includes the class of α-divergences, which take f_α(t) = t^α/(α(α−1)) with α ∈ ℝ \ {0, 1}, and hence

$$D_{f_\alpha}(p~\|~q) = \frac{1}{\alpha(\alpha-1)}\,\mathbb{E}_{x\sim q}\left[\left(\frac{p(x)}{q(x)}\right)^{\!\alpha} - 1\right]. \qquad (4)$$

One can show that KL(q||p) and KL(p||q) are the limits of D_{f_α}(p||q) as α → 0 and α → 1, respectively. Further, one obtains the Hellinger distance and the χ²-divergence at α = 1/2 and α = 2, respectively. In particular, the χ²-divergence (α = 2) plays an important role in adaptive importance sampling, because it equals the variance of the importance weight w = p(x)/q(x), and minimizing the χ²-divergence corresponds to finding an optimal importance sampling proposal.

3 α-Divergence and Fat Tails

A major motivation for using α-divergences as the objective function for approximate inference is their mass-covering property (also known as the zero-avoiding behavior). This is because the α-divergence is proportional to the α-th moment of the density ratio p(x)/q(x). When α is positive and large, large values of p(x)/q(x) are strongly penalized, preventing the case of q(x) ≪ p(x). In fact, whenever D_{f_α}(p||q) < ∞, p(x) > 0 implies q(x) > 0. This means that the probability mass and local modes of p are properly taken into account in q.

Note that the case α ≤ 0 exhibits the opposite property, that is, p(x) = 0 must imply q(x) = 0 to make D_{f_α}(p||q) finite when α ≤ 0; this includes the typical KL divergence KL(q||p) (α → 0), which is often criticized for its tendency to under-estimate the uncertainty.

Typically, using larger values of α enforces stronger mass-covering properties. In practice, however, larger values of α also increase the variance of the empirical estimators, making the objective highly challenging to optimize. In fact, the expectation in (4) may not even exist when α is too large.
This is because the density ratio w := p(x)/q(x) often has a fat-tailed distribution.

A non-negative random variable w is called fat-tailed (e.g., Resnick, 2007) if its tail probability $\bar F_w(t) := \Pr(w \ge t)$ is asymptotically equivalent to $t^{-\alpha^*}$ as t → +∞ for some finite positive number α* (denoted by $\bar F_w(t) \sim t^{-\alpha^*}$), which means that $\bar F_w(t) = t^{-\alpha^*} L(t)$, where L is a slowly varying function that satisfies $\lim_{t\to+\infty} L(ct)/L(t) = 1$ for any c > 0. (Fat-tailed distributions are a sub-class of heavy-tailed distributions, which are distributions whose tail probabilities decay more slowly than exponential functions, that is, $\lim_{t\to+\infty} \exp(\lambda t)\,\bar F_w(t) = \infty$ for all λ > 0.) Here α* determines the fatness of the tail and is called the tail index of w. For a fat-tailed distribution with index α*, its α-th moment exists only if α < α*, that is, E[w^α] < ∞ iff α < α*. It turns out that the density ratio w := p(x)/q(x), when x ∼ q, tends to have a fat-tailed distribution when q is more peaked than p. The example below illustrates this with simple Gaussian distributions.
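Before stating the example formally, a quick numerical illustration of this phenomenon (a sketch, assuming only numpy; the Gaussian setting anticipates the example that follows, where the tail index works out to α* = σ_p²/(σ_p² − σ_q²)):

```python
import numpy as np

# Fat tail of w = p(x)/q(x) for p = N(0, sigma_p^2), q = N(0, sigma_q^2)
# with sigma_p > sigma_q. With sigma_p = 2 and sigma_q = 1 the tail index is
# alpha* = sigma_p^2 / (sigma_p^2 - sigma_q^2) = 4/3, so E[w^alpha] is
# finite only for alpha < 4/3.
rng = np.random.default_rng(0)
sigma_p, sigma_q = 2.0, 1.0
alpha_star = sigma_p**2 / (sigma_p**2 - sigma_q**2)

def log_w(x):
    # log p(x) - log q(x) for the two zero-mean Gaussians
    return np.log(sigma_q / sigma_p) + 0.5 * x**2 * (1 / sigma_q**2 - 1 / sigma_p**2)

for alpha in [0.5, 1.0, 1.5, 2.0]:
    estimates = [np.exp(alpha * log_w(rng.normal(0.0, sigma_q, 10**5))).mean()
                 for _ in range(5)]
    print(f"alpha = {alpha} (alpha* = {alpha_star:.2f}):", np.round(estimates, 2))
# For alpha > alpha*, the five estimates disagree wildly across runs and keep
# growing with the sample size, reflecting an infinite alpha-th moment.
```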
Example 3.1. Assume p(x) = N(x; 0, σ_p²) and q(x) = N(x; 0, σ_q²). Let x ∼ q and let w = p(x)/q(x) be the density ratio. If σ_p > σ_q, then w has a fat-tailed distribution with tail index α* = σ_p²/(σ_p² − σ_q²). On the other hand, if σ_p ≤ σ_q, then w is bounded and not fat-tailed (effectively, α* = +∞).

By the definition above, if the importance weight w = p(x)/q(x) has tail index α*, the α-divergence D_{f_α}(p||q) exists only if α < α*. Although it is desirable to use an α-divergence with large values of α as the VI objective function, it is important to keep α smaller than α* to ensure that the objective and its gradient are well defined. The problem, however, is that the tail index α* is unknown in practice, and may change dramatically (e.g., even from finite to infinite) as q is updated during the optimization process. This makes it suboptimal to use a pre-fixed α value. One potential way to address this problem is to estimate the tail index α* empirically at each iteration using a tail index estimator (e.g., Hill et al., 1975; Vehtari et al., 2015). Unfortunately, tail index estimation is often challenging and requires a large number of samples, and the algorithm may become unstable if α* is over-estimated.

4 Adaptive f-Divergence

In this work, we address the aforementioned problem by designing a generalization of f-divergence in which f adaptively changes with p and q, in a way that always guarantees the existence of the divergence while achieving a mass-covering strength comparable to that of the α-divergence with α = α*.

One challenge in designing such an adaptive f is that the convexity constraint on f is difficult to express computationally. Our first key observation is that it is easier to specify a convex function f through its second-order derivative f″, which can be any non-negative function. It turns out that the f-divergence, as well as its gradient, can be conveniently expressed using f″, without explicitly defining the original f.
Proposition 4.1. 1) Any twice differentiable convex function f: ℝ⁺ ∪ {0} → ℝ with finite f(0) can be decomposed into linear and nonlinear components as follows:

$$f(t) = (at + b) + \int_0^\infty (t-\mu)_+\, h(\mu)\, d\mu, \qquad (5)$$

where h is a non-negative function, (t)₊ = max(0, t), and a, b ∈ ℝ. In this case, h(t) = f″(t), a = f′(0) and b = f(0). Conversely, any non-negative function h together with a, b ∈ ℝ specifies a convex function.

2) This allows us to derive an alternative representation of the f-divergence:

$$D_f(p~\|~q) = \int_0^\infty f''(\mu)\, \mathbb{E}_{x\sim q}\left[\left(\frac{p(x)}{q(x)} - \mu\right)_{\!+}\right] d\mu \;-\; c, \qquad (6)$$

where $c := \int_0^1 f''(\mu)(1-\mu)\, d\mu = f(1) - f(0) - f'(0)$ is a constant.

Proof. If $f(t) = (at + b) + \int_0^\infty (t-\mu)_+\, h(\mu)\, d\mu$, calculation shows

$$f'(t) = a + \int_0^t h(\mu)\, d\mu, \qquad f''(t) = h(t).$$

Therefore, f is convex iff h is non-negative. See the Appendix for the complete proof.

Eq. (6) suggests that all f-divergences are conical combinations of a set of special f-divergences of the form E_{x∼q}[(p(x)/q(x) − µ)₊ − f(1)] with f(t) = (t − µ)₊. Also, every f-divergence is completely specified by the Hessian f″, meaning that adding to f any linear function at + b does not change D_f(p||q). Such integral representations of f-divergence are not new; see, e.g., Feldman & Osterreicher (1989); Osterreicher (2003); Liese & Vajda (2006); Reid & Williamson (2011); Sason (2018).
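As a quick sanity check of the decomposition (5), the snippet below (a sketch, assuming numpy) reconstructs f(t) = (t − 1)², for which f(0) = 1, f′(0) = −2 and f″ ≡ 2, from its linear part plus the integral of hinge functions:

```python
import numpy as np

# Verify f(t) = (f'(0) t + f(0)) + \int_0^inf (t - mu)_+ f''(mu) dmu
# for f(t) = (t - 1)^2: f(0) = 1, f'(0) = -2, f''(mu) = 2.
f = lambda t: (t - 1.0) ** 2

for t in [0.3, 1.0, 2.5, 7.0]:
    mu, dmu = np.linspace(0.0, t, 200_000, endpoint=False, retstep=True)
    integral = np.sum((t - mu) * 2.0) * dmu  # (t - mu)_+ vanishes for mu > t
    print(t, f(t), (-2.0 * t + 1.0) + integral)
# The last two columns agree up to the Riemann-sum discretization error.
```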
For the purpose of minimizing D_f(p||q_θ) (θ ∈ Θ) in variational inference, we are more concerned with calculating the gradient than with the f-divergence itself. It turns out that the gradient of D_f(p||q_θ) is also directly related to the Hessian f″ in a simple way.

Proposition 4.2. 1) Assume log q_θ(x) is differentiable w.r.t. θ, and f is a differentiable convex function. For the f-divergence defined in (2), we have

$$\nabla_\theta D_f(p~\|~q_\theta) = -\,\mathbb{E}_{x\sim q_\theta}\left[\rho_f\!\left(\frac{p(x)}{q_\theta(x)}\right)\nabla_\theta \log q_\theta(x)\right], \qquad (7)$$

where ρ_f(t) = f′(t)t − f(t) (equivalently, ρ′_f(t) = f″(t)t if f is twice differentiable).

2) Assume x ∼ q_θ is generated by x = g_θ(ξ), where ξ ∼ q₀ is a random seed and g_θ is a function that is differentiable w.r.t. θ. Assume f is twice differentiable and ∇_x log(p(x)/q_θ(x)) exists. We have

$$\nabla_\theta D_f(p~\|~q_\theta) = -\,\mathbb{E}_{x = g_\theta(\xi),\,\xi\sim q_0}\left[\gamma_f\!\left(\frac{p(x)}{q_\theta(x)}\right)\nabla_\theta g_\theta(\xi)\, \nabla_x \log\big(p(x)/q_\theta(x)\big)\right], \qquad (8)$$

where γ_f(t) = ρ′_f(t)t = f″(t)t².

The result above shows that the gradient of the f-divergence depends on f only through ρ_f or γ_f. Taking the α-divergence (α ∉ {0, 1}) as an example, we have

$$f(t) = \frac{t^\alpha}{\alpha(\alpha-1)}, \qquad \rho_f(t) = \frac{t^\alpha}{\alpha}, \qquad \gamma_f(t) = t^\alpha.$$

For KL(q||p), we have f(t) = −log t, yielding ρ_f(t) = log t − 1 and γ_f(t) = 1; for KL(p||q), we have f(t) = t log t, yielding ρ_f(t) = t and γ_f(t) = t.
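These closed forms are easy to verify symbolically; a small check (a sketch, assuming sympy):

```python
import sympy as sp

# Symbolic check of rho_f(t) = f'(t) t - f(t) and gamma_f(t) = f''(t) t^2
# for the alpha-divergence generator f(t) = t^alpha / (alpha (alpha - 1)).
t, alpha = sp.symbols("t alpha", positive=True)
f = t**alpha / (alpha * (alpha - 1))

rho = sp.simplify(sp.diff(f, t) * t - f)      # expect t**alpha / alpha
gamma = sp.simplify(sp.diff(f, t, 2) * t**2)  # expect t**alpha
print(rho, gamma)

# The two KL cases: f = -log t gives gamma_f = 1; f = t log t gives gamma_f = t.
for g in (-sp.log(t), t * sp.log(t)):
    print(sp.simplify(sp.diff(g, t, 2) * t**2))
```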
The formulas in (7) and (8) are called the score-function gradient and the reparameterization gradient (Kingma & Welling, 2014), respectively. Both equal the gradient in expectation, but they are computationally different and yield empirical estimators with different variances. In particular, the score-function gradient in (7) is "gradient-free" in that it does not require calculating the gradient of the distribution p(x) of interest, while (8) is "gradient-based" in that it involves ∇_x log p(x). It has been shown that optimizing with reparameterization gradients tends to give better empirical results, because it leverages the gradient information ∇_x log p(x) and yields a lower-variance estimator of the gradient (e.g., Kingma & Welling, 2014).

Our key observation is that we can directly specify f through any increasing function ρ_f, or any non-negative function γ_f, in the gradient estimators, without explicitly defining f.

Proposition 4.3. Assume f: ℝ⁺ → ℝ is convex and twice differentiable. Then:

1) ρ_f in (7) is a monotonically increasing function on ℝ⁺. In addition, for any differentiable increasing function ρ, there exists a convex function f such that ρ_f = ρ;

2) γ_f in (8) is non-negative on ℝ⁺, that is, γ_f(t) ≥ 0 for all t ∈ ℝ⁺. In addition, for any non-negative function γ, there exists a convex function f such that γ_f = γ;

3) if ρ_f(t) is strictly increasing at t = 1 (i.e., ρ′_f(1) > 0), or γ_f(t) is strictly positive at t = 1 (i.e., γ_f(1) > 0), then D_f(p||q) = 0 implies p = q.

Proof. Because f is convex (f″(t) ≥ 0), we have γ_f(t) = f″(t)t² ≥ 0 and ρ′_f(t) = f″(t)t ≥ 0 on t ∈ ℝ⁺, that is, γ_f is non-negative and ρ_f is increasing on ℝ⁺. If ρ_f is strictly increasing (or γ_f is strictly positive) at t = 1, then f is strictly convex at t = 1, which guarantees that D_f(p||q) = 0 implies p = q.

For a non-negative function γ(t) (or an increasing function ρ(t)) on ℝ⁺, any convex function f whose second-order derivative equals γ(t)/t² (resp. ρ′(t)/t) satisfies γ_f = γ (resp. ρ_f = ρ).

5 Safe f-Divergence with Inverse Tail Probability

The results above show that it is sufficient to find an increasing function ρ_f, or a non-negative function γ_f, to obtain an adaptive f-divergence with computable gradients. In order to make the f-divergence "safe", we need to find a ρ_f or γ_f that adaptively depends on p and q such that the expectations in (7) and (8) always exist. Because the magnitudes of ∇_θ log q_θ(x), ∇_x log(p(x)/q_θ(x)) and ∇_θ g_θ(ξ) are relatively small compared with the ratio p(x)/q(x), we mainly consider designing a function ρ (or γ) that yields a finite expectation E_{x∼q}[ρ(p(x)/q(x))] < ∞; meanwhile, we should also keep the function large, preferably with the same magnitude as t^{α*}, to provide a strong mass-covering property. As it turns out, the inverse of the tail probability naturally achieves all these goals.
Proposition 5.1. For any random variable w with tail distribution $\bar F_w(t) := \Pr(w \ge t)$ and tail index α*, we have

$$\mathbb{E}\big[\bar F_w(w)^{\beta}\big] < \infty \quad \text{for any } \beta > -1.$$

Also, we have $\bar F_w(t)^\beta \sim t^{-\beta\alpha^*}$, and $\bar F_w(t)^\beta$ is always non-negative and monotonically increasing when β < 0.

Proof. Simply note that $\mathbb{E}[\bar F_w(w)^\beta] = \int_0^1 u^\beta\, du$, which is finite only when β > −1. The non-negativity and monotonicity of $\bar F_w(t)^\beta$ are obvious, and $\bar F_w(t)^\beta \sim t^{-\beta\alpha^*}$ follows directly from the definition of the tail index.

This motivates us to use $\bar F_w(t)^\beta$ to define ρ_f or γ_f, yielding two versions of "safe" tail-adaptive f-divergences. Note that here f is defined implicitly through ρ_f or γ_f. Although it is possible to derive the corresponding f and D_f(p||q), there is no computational need to do so, since optimizing the objective function only requires calculating the gradient, which is defined by ρ_f or γ_f.

Algorithm 1: Variational Inference with Tail-adaptive f-Divergence (with Reparameterization Gradient)

Goal: find the best approximation of p(x) from {q_θ : θ ∈ Θ}. Assume x ∼ q_θ is generated by x = g_θ(ξ), where ξ is a random sample from the noise distribution q₀.
Initialize θ; set an index β (e.g., β = −1).
for each iteration do
    Draw {x_i}_{i=1}^n ∼ q_θ, generated by x_i = g_θ(ξ_i).
    Let w_i = p(x_i)/q_θ(x_i), $\hat{\bar F}_w(t) = \frac{1}{n}\sum_{j=1}^n \mathbb{I}(w_j \ge t)$, and set $\gamma_i = \hat{\bar F}_w(w_i)^\beta$.
    Update θ ← θ + ε Δθ, where ε is the step size and
    $$\Delta\theta = \frac{1}{z_\gamma}\sum_{i=1}^n \big[\gamma_i\, \nabla_\theta g_\theta(\xi_i)\, \nabla_x \log(p(x_i)/q_\theta(x_i))\big], \qquad z_\gamma = \sum_{i=1}^n \gamma_i.$$
end for

In practice, the explicit form of $\bar F_w(t)^\beta$ is unknown. We can approximate it with empirical data drawn from q. Let {x_i} be drawn from q and w_i = p(x_i)/q(x_i); then we can approximate the tail probability with $\hat{\bar F}_w(t) = \frac{1}{n}\sum_{i=1}^n \mathbb{I}(w_i \ge t)$. Intuitively, this corresponds to assigning each data point a weight according to the rank of its density ratio in the population. Substituting the empirical tail probability into the reparameterization gradient formula in (8) and running gradient descent with stochastic approximation yields our main algorithm, shown in Algorithm 1. The version with the score-function gradient is similar and is shown in Algorithm 2 in the Appendix. Both algorithms can be viewed as minimizing implicitly constructed adaptive f-divergences, but correspond to using different f.

Compared with typical VI with reparameterized gradients, our method assigns each point a weight $\gamma_i = \hat{\bar F}_w(w_i)^\beta$, which is proportional to $r_i^\beta$, where $r_i$ denotes the rank of $w_i$ in the population {w_j} (in descending order, so the largest ratio has rank 1). When taking −1 < β < 0, this allows us to penalize places with a high ratio p(x)/q(x) while avoiding being overly aggressive. In practice, we find that simply taking β = −1 almost always yields the best empirical performance (despite β > −1 being needed theoretically). By comparison, minimizing the classical α-divergence would use a weight of w_i^α; if α is too large, the weight of a single data point can become dominant, making the gradient estimate unstable.
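To make the update in Algorithm 1 concrete, here is a minimal PyTorch sketch (not the authors' released implementation; a toy one-dimensional Gaussian q and target p are assumed, and log_p stands for the possibly unnormalized log density of p):

```python
import torch
from torch.distributions import Normal

def tail_adaptive_weights(log_w, beta=-1.0):
    # gamma_i = F_hat(w_i)^beta with the empirical tail probability
    # F_hat(t) = #{j : w_j >= t} / n; the weights depend only on the ranks
    # of the w_i, so they never overflow however heavy the tail of w is.
    tail_prob = (log_w.unsqueeze(1) >= log_w.unsqueeze(0)).float().mean(dim=0)
    gamma = tail_prob ** beta
    return (gamma / gamma.sum()).detach()

log_p = lambda x: Normal(0.0, 1.0).log_prob(x)  # toy target density

mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=0.05)

for step in range(500):
    q = Normal(mu, log_sigma.exp())
    x = q.rsample((256,))                        # x = g_theta(xi), reparameterized
    # evaluate log q with detached parameters so that gradients flow only
    # through x, matching the pathwise form of Eq. (8)
    q_bar = Normal(mu.detach(), log_sigma.exp().detach())
    log_w = (log_p(x) - q_bar.log_prob(x)).squeeze(-1)
    gamma = tail_adaptive_weights(log_w, beta=-1.0)
    loss = -(gamma * log_w).sum()                # descending this ascends Eq. (8)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Note that the weights are computed purely from the ranks of the density ratios and are detached from the computation graph, so they enter the gradient as constants, as in Algorithm 1.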
6 Experiments

In this section, we evaluate our adaptive f-divergence on different models. We use reparameterization gradients by default, since they have smaller variances (Kingma & Welling, 2014) and normally yield better performance than score-function gradients. Our code is available at https://github.com/dilinwang820/adaptive-f-divergence.

We first illustrate the approximation quality of our proposed adaptive f-divergence on Gaussian mixture models. In this case, we set our target distribution to be a Gaussian mixture $p(x) = \sum_{i=1}^{k} \frac{1}{k}\, \mathcal{N}(x;\, \nu_i,\, I)$ for x ∈ ℝ^d, where the elements of each mean vector ν_i are drawn from Uniform([−s, s]). Here s can be viewed as controlling the non-Gaussianity of the target distribution: p reduces to a standard Gaussian distribution when s = 0 and becomes increasingly multi-modal as s increases. We fix the number of components to k = 10, and initialize the approximate distribution as $q(x) = \sum_{i} w_i\, \mathcal{N}(x;\, \mu_i,\, \sigma_i^2)$, where $\sum_{i} w_i = 1$.

We evaluate how well q covers the modes of p using a "mode-shift distance", $\mathrm{dist}(p, q) := \frac{1}{k}\sum_{i} \min_j \|\nu_i - \mu_j\|_2$, which is the average distance of each mode of p to its nearest mode in q. The model is optimized using Adagrad with a constant learning rate, using a minibatch of size 256 to approximate the gradient in each iteration. To learn the component weights, we apply the Gumbel-Softmax trick (Jang et al., 2017; Maddison et al., 2017) with a fixed temperature. Figure 1 shows the results for random mixtures p obtained with s = 5, when the dimension d of x equals 2 and 10, respectively.
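The mode-shift metric itself is simple to compute; a sketch (numpy assumed, with hypothetical array arguments for the true and learned component means):

```python
import numpy as np

def mode_shift_distance(true_modes, approx_modes):
    """Average distance from each mode of p to its nearest mode of q.

    true_modes:   (k, d) array of mixture means nu_i of the target p.
    approx_modes: (m, d) array of learned component means mu_j of q.
    """
    # pairwise Euclidean distances, shape (k, m)
    diff = true_modes[:, None, :] - approx_modes[None, :, :]
    dists = np.linalg.norm(diff, axis=-1)
    return dists.min(axis=1).mean()

# Example: perfect recovery gives distance 0.
nu = np.array([[0.0, 0.0], [5.0, 5.0]])
print(mode_shift_distance(nu, nu))  # 0.0
```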
Figure 1: (a) plots the mode-shift distance between p and q; (b-c) show the MSE of the mean and variance between the true posterior p and our approximation q, respectively, as the choice of α (for α-divergence) or β (for our method) varies. All results are averaged over 10 random trials. (Panels compare Adaptive and Alpha at d = 2 and d = 10.)

Figure 2: Results on randomly generated Gaussian mixture models as the non-Gaussianity s varies. (a) plots the average mode-shift distance; (b-c) show the MSE of the mean and variance. All results are averaged over 10 random trials. (Legend: Adaptive (β = −1), Alpha (α = 0), Alpha (α = 0.5), Alpha (α = 1.0).)
We can see that when the dimension is low (d = 2), all algorithms perform similarly well. However, as we increase the dimension to 10, our approach with the tail-adaptive f-divergence achieves the best performance.

To examine the quality of the variational approximation more closely, we show in Figure 2 the average mode-shift distance and the MSE of the estimated mean and variance as we gradually increase the non-Gaussianity of p(x) by changing s. We fix the dimension to 10. We can see from Figure 2 that when p is close to Gaussian (small s), all algorithms perform well; when p is highly non-Gaussian (large s), our approach with adaptive weights significantly outperforms the other baselines.

We next evaluate our approach on Bayesian neural network regression tasks. The datasets are collected from the UCI dataset repository (https://archive.ics.uci.edu/ml/datasets.html). Similarly to Li & Turner (2016), we use a single-layer neural network with 50 hidden units and ReLU activation, except that we take 100 hidden units for the Protein and Year datasets, which are relatively large. We use a fully factorized Gaussian approximation to the true posterior and a Gaussian prior for the neural network weights. All datasets are randomly partitioned into 90% for training and 10% for testing. We use the Adam optimizer (Kingma & Ba, 2015) with a constant learning rate. The gradient is approximated with n = 100 draws of x_i ∼ q_θ and a minibatch of size 32 from the training data points. All results are averaged over 20 random partitions, except for Protein and Year, on which fewer random trials are used.

We summarize the average RMSE and test log-likelihood with standard errors in Table 1. We compare our algorithm with α = 0 (KL divergence) and α = 0.5, which are reported as the best for this task in Li & Turner (2016). More comparisons with different choices of α are provided in the appendix. We can see from Table 1 that our approach takes advantage of an adaptive choice of f-divergence and achieves the best performance in both test RMSE and test log-likelihood in most of the cases.
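For reference, the test log-likelihood for this kind of Bayesian regression model is commonly computed from posterior samples as a log-mean-exp of per-sample Gaussian likelihoods; a sketch (numpy/scipy assumed; the paper's exact evaluation protocol may differ in details such as how the noise variance is estimated):

```python
import numpy as np
from scipy.special import logsumexp

def test_log_likelihood(mean_preds, noise_var, y):
    """Monte Carlo estimate of the average log p(y | x) for Bayesian regression.

    mean_preds: (S, N) predictive means from S posterior weight samples.
    noise_var:  observation noise variance (scalar).
    y:          (N,) test targets.
    log p(y|x) ~= logsumexp_s log N(y; f_s(x), noise_var) - log S.
    """
    S = mean_preds.shape[0]
    log_probs = (-0.5 * np.log(2 * np.pi * noise_var)
                 - 0.5 * (y[None, :] - mean_preds) ** 2 / noise_var)
    return (logsumexp(log_probs, axis=0) - np.log(S)).mean()
```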
[Table 1 appears here. Rows: the UCI datasets Boston, Concrete, Energy, Kin8nm, Naval, Combined, Wine, Yacht, Protein, and Year. Columns: average test RMSE and average test log-likelihood (mean ± standard error) for our method (β = −1.0) and for α = 0.0 and α = 0.5. The numerical entries are not legible in this copy.]

Table 1: Average test RMSE and log-likelihood for Bayesian neural network regression.
We now demonstrate an application of our method in reinforcement learning, applying it as an inner loop to improve the recent soft actor-critic (SAC) algorithm (Haarnoja et al., 2018); see related applications in Belousov & Peters (2017; 2018; 2019). We start with a brief introduction to the background of SAC and then test our method in MuJoCo environments.

Reinforcement Learning Background
Reinforcement learning considers the problem of finding optimal policies for agents that interact with uncertain environments to maximize the long-term cumulative reward. This is formally framed as a Markov decision process in which agents iteratively take actions a based on observable states s, and receive a reward signal r(s, a) immediately after performing action a at state s. The change of states is governed by an unknown environmental dynamic defined by a transition probability T(s′|s, a). The agent's action a is selected by a conditional probability distribution π(a|s) called the policy. In policy gradient methods, we consider a set of candidate policies π_θ(a|s) parameterized by θ and obtain the optimal policy by maximizing the expected cumulative reward

$$J(\theta) = \mathbb{E}_{s\sim d_\pi,\, a\sim\pi(a|s)}\big[r(s, a)\big],$$

where $d_\pi(s) = \sum_{t=1}^{\infty} \gamma^{t-1} \Pr(s_t = s)$ is the unnormalized discounted state visitation distribution with discount factor γ ∈ (0, 1).

Soft actor-critic (SAC) is an off-policy optimization algorithm derived from maximizing the expected reward with an entropy regularization. It iteratively updates a Q-function Q(a, s), which predicts the cumulative reward of taking action a under state s, as well as a policy π(a|s), which selects actions a to maximize the expected value of Q(s, a). The update rule of Q(s, a) is based on a variant of Q-learning that matches the Bellman equation, whose details can be found in Haarnoja et al. (2018). At each iteration of SAC, the update of the policy π is achieved by minimizing a KL divergence:

$$\pi_{\text{new}} = \arg\min_\pi\; \mathbb{E}_{s\sim d}\big[\mathrm{KL}(\pi(\cdot|s)~\|~p_Q(\cdot|s))\big], \qquad (9)$$

$$p_Q(a|s) = \exp\!\Big(\frac{1}{\tau}\big(Q(a,s) - V(s)\big)\Big), \qquad (10)$$

where τ is a temperature parameter, and $V(s) = \tau \log \int_a \exp(Q(a,s)/\tau)\, da$, serving as a normalization constant here, is a soft version of the value function and is also iteratively updated in SAC. Here, d(s) is a visitation distribution on states s, which is taken to be the empirical distribution of the states in the current replay buffer in SAC. We can see that (9) can be viewed as a variational inference problem on the conditional distribution p_Q(a|s), with the typical KL objective function (α = 0).

SAC With Tail-adaptive f-Divergence

To apply f-divergence, we first rewrite (9) to transform the conditional distributions into joint distributions. We define the joint distributions $p_Q(a, s) = \exp\big((Q(a,s) - V(s))/\tau\big)\, d(s)$ and $q_\pi(a, s) = \pi(a|s)\, d(s)$; then we can show that $\mathbb{E}_{s\sim d}[\mathrm{KL}(\pi(\cdot|s)~\|~p_Q(\cdot|s))] = \mathrm{KL}(q_\pi~\|~p_Q)$. This motivates us to extend the objective function in (9) to more general f-divergences:

$$D_f(p_Q~\|~q_\pi) = \mathbb{E}_{s\sim d}\,\mathbb{E}_{a|s\sim\pi}\left[f\!\left(\frac{\exp\big((Q(a,s) - V(s))/\tau\big)}{\pi(a|s)}\right) - f(1)\right].$$
Figure 3: Soft actor-critic (SAC) with the policy updated by Algorithm 1 with β = −1, or by α-divergence VI with different α (α = 0 corresponds to the original SAC). The reparameterization gradient estimator is used in all cases. In the legend, "α = max" denotes setting α = +∞ in the α-divergence. (Panels: Ant, HalfCheetah, Humanoid(rllab), Walker, Hopper, Swimmer(rllab); y-axis: average reward; x-axis: environment steps, in millions.)

By using our tail-adaptive f-divergence, we can readily apply Algorithm 1 (or Algorithm 2 in the Appendix) to update π in SAC, allowing us to obtain a π that accounts for the multi-modality of Q(a, s) more efficiently. Note that the standard α-divergence with a fixed α also yields a new variant of SAC that has not yet been studied in the literature.
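A sketch of what this policy update can look like in code (hypothetical tensor arguments; it assumes actions are drawn from π with reparameterization so that log_pi and q_values are differentiable through them, and it reuses the rank-based weights of Algorithm 1):

```python
import torch

def tail_adaptive_policy_loss(log_pi, q_values, v_values, tau=1.0, beta=-1.0):
    # density ratio w_i = exp((Q - V)/tau) / pi(a_i|s_i), kept in log space
    log_w = (q_values - v_values) / tau - log_pi
    # rank-based weights gamma_i = F_hat(w_i)^beta, as in Algorithm 1
    tail_prob = (log_w.unsqueeze(1) >= log_w.unsqueeze(0)).float().mean(dim=0)
    gamma = (tail_prob ** beta).detach()
    gamma = gamma / gamma.sum()
    # gradient descent on this loss pushes pi toward p_Q; the original SAC
    # update (KL, alpha = 0) corresponds instead to uniform weights
    return -(gamma * log_w).sum()
```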
Empirical Results

We follow the experimental setup of Haarnoja et al. (2018). The policy π, the value function V(s), and the Q-function Q(s, a) are neural networks with two fully-connected layers of 128 hidden units each. We use Adam (Kingma & Ba, 2015) with a constant learning rate for optimization. The replay buffer is larger for HalfCheetah than for the other environments, in a way similar to Haarnoja et al. (2018).

We compare with the original SAC (α = 0), and also with other α-divergences, such as α = 0.5 and α = +∞ (the "α = max" setting suggested in Li & Turner (2016)). Figure 3 summarizes the total average reward of evaluation rollouts during training on various MuJoCo environments. For the non-negative α settings, methods with larger α give a higher average reward than the original KL-based SAC in most of the cases. Overall, our adaptive f-divergence substantially outperforms all other α-divergences on all of the benchmark tasks in terms of final performance, and learns faster than all the baselines in most environments. We find that our improvement is especially significant on high-dimensional and complex environments like Ant and Humanoid.

7 Conclusion

In this paper, we present a new class of tail-adaptive f-divergences and exploit their application in variational inference and reinforcement learning. Compared to the classic α-divergence, our approach guarantees finite moments of the density ratio and provides more stable importance weights and gradient estimates. Empirical results on Bayesian neural networks and reinforcement learning indicate that our approach outperforms the standard α-divergence, especially for high-dimensional multi-modal distributions.

Acknowledgement
This work is supported in part by NSF CRII 1830161. We would like to acknowledge Google Cloud for their support.

References
Belousov, Boris and Peters, Jan. f-divergence constrained policy improvement. arXiv preprint arXiv:1801.00056, 2017.

Belousov, Boris and Peters, Jan. Mean squared advantage minimization as a consequence of entropic policy improvement regularization. European Workshop on Reinforcement Learning, 2018.

Belousov, Boris and Peters, Jan. Entropic regularization of Markov decision processes. Entropy, 21(7):674, 2019.

Blei, David M, Kucukelbir, Alp, and McAuliffe, Jon D. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877, 2017.

Burda, Yuri, Grosse, Roger, and Salakhutdinov, Ruslan. Importance weighted autoencoders. International Conference on Learning Representations (ICLR), 2016.

Cappé, Olivier, Douc, Randal, Guillin, Arnaud, Marin, Jean-Michel, and Robert, Christian P. Adaptive importance sampling in general mixture classes. Statistics and Computing, 18(4):447–459, 2008.

Christopher, M Bishop. Pattern Recognition and Machine Learning. Springer-Verlag New York, 2016.

Cotter, Colin, Cotter, Simon, and Russell, Paul. Parallel adaptive importance sampling. arXiv preprint arXiv:1508.01132, 2015.

Csiszár, I. and Shields, P.C. Information theory and statistics: A tutorial. Foundations and Trends in Communications and Information Theory, 1(4):417–528, 2004.

De Boer, Pieter-Tjerk, Kroese, Dirk P, Mannor, Shie, and Rubinstein, Reuven Y. A tutorial on the cross-entropy method. Annals of Operations Research, 134(1):19–67, 2005.

Dieng, Adji Bousso, Tran, Dustin, Ranganath, Rajesh, Paisley, John, and Blei, David. Variational inference via χ upper bound minimization. In Advances in Neural Information Processing Systems (NIPS), pp. 2732–2741, 2017.

Feldman, Dorian and Osterreicher, Ferdinand. A note on f-divergences. Studia Sci. Math. Hungar., 24:191–200, 1989.

Gal, Yarin and Ghahramani, Zoubin. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning (ICML), pp. 1050–1059, 2016.

Haarnoja, Tuomas, Zhou, Aurick, Abbeel, Pieter, and Levine, Sergey. Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor. International Conference on Machine Learning (ICML), 2018.

Hernández-Lobato, José Miguel, Li, Yingzhen, Rowland, Mark, Hernández-Lobato, Daniel, Bui, Thang, and Turner, Richard Eric. Black-box α-divergence minimization. International Conference on Machine Learning (ICML), 2016.

Hill, Bruce M et al. A simple general approach to inference about the tail of a distribution. The Annals of Statistics, 3(5):1163–1174, 1975.

Hoffman, Matthew D, Blei, David M, Wang, Chong, and Paisley, John. Stochastic variational inference. The Journal of Machine Learning Research, 14(1):1303–1347, 2013.

Jang, Eric, Gu, Shixiang, and Poole, Ben. Categorical reparameterization with Gumbel-softmax. International Conference on Learning Representations (ICLR), 2017.

Jordan, Michael I, Ghahramani, Zoubin, Jaakkola, Tommi S, and Saul, Lawrence K. An introduction to variational methods for graphical models. Machine Learning, 37(2):183–233, 1999.

Kingma, Diederik P and Ba, Jimmy. Adam: A method for stochastic optimization. International Conference on Learning Representations (ICLR), 2015.

Kingma, Diederik P and Welling, Max. Auto-encoding variational Bayes. International Conference on Learning Representations (ICLR), 2014.

Levine, Sergey. Reinforcement learning and control as probabilistic inference: Tutorial and review. arXiv preprint arXiv:1805.00909, 2018.

Li, Yingzhen and Turner, Richard E. Rényi divergence variational inference. In Advances in Neural Information Processing Systems (NIPS), pp. 1073–1081, 2016.

Liese, Friedrich and Vajda, Igor. On divergences and informations in statistics and information theory. IEEE Transactions on Information Theory, 52(10):4394–4412, 2006.

Maddison, Chris J, Mnih, Andriy, and Teh, Yee Whye. The concrete distribution: A continuous relaxation of discrete random variables. International Conference on Learning Representations (ICLR), 2017.

Minka, Thomas P. Expectation propagation for approximate Bayesian inference. In Proceedings of the Seventeenth Conference on Uncertainty in Artificial Intelligence (UAI), pp. 362–369. Morgan Kaufmann Publishers Inc., 2001.

Minka, Tom et al. Divergence measures and message passing. Technical report, Microsoft Research, 2005.

Opper, Manfred and Winther, Ole. Expectation consistent approximate inference. Journal of Machine Learning Research, 6(Dec):2177–2204, 2005.

Osterreicher, Ferdinand. f-divergences—representation theorem and metrizability. Technical report, Inst. Math., Univ. Salzburg, Salzburg, Austria, 2003.

Ranganath, Rajesh, Gerrish, Sean, and Blei, David. Black box variational inference. In Artificial Intelligence and Statistics, pp. 814–822, 2014.

Reid, Mark D and Williamson, Robert C. Information, divergence and risk for binary experiments. Journal of Machine Learning Research, 12(Mar):731–817, 2011.

Resnick, Sidney I. Heavy-tail Phenomena: Probabilistic and Statistical Modeling. Springer Science & Business Media, 2007.

Ryu, Ernest K and Boyd, Stephen P. Adaptive importance sampling via stochastic convex programming. arXiv preprint arXiv:1412.4845, 2014.

Sason, Igal. On f-divergences: Integral representations, local behavior, and inequalities. Entropy, 20(5):383, 2018.

Vehtari, Aki, Gelman, Andrew, and Gabry, Jonah. Pareto smoothed importance sampling. arXiv preprint arXiv:1507.02646, 2015.

Wainwright, Martin J, Jordan, Michael I, et al. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1–2):1–305, 2008.

A Proof of Proposition 4.1
Proof.
Taking h(t) = f″(t), a = f′(0) and b = f(0) in Eq. (5), we have

$$
\begin{aligned}
(f'(0)\,t + f(0)) + \int_0^\infty (t-\mu)_+\, h(\mu)\, d\mu
&= (f'(0)\,t + f(0)) + \int_0^t (t-\mu)\, f''(\mu)\, d\mu \\
&= (f'(0)\,t + f(0)) + \Big[(t-\mu) f'(\mu)\Big]_{\mu=0}^{t} + \int_0^t f'(\mu)\, d\mu \quad \text{(integration by parts)} \\
&= f(0) + \int_0^t f'(\mu)\, d\mu = f(t).
\end{aligned}
$$

Conversely, if $f(t) = (at + b) + \int_0^\infty (t-\mu)_+\, h(\mu)\, d\mu$, calculation shows

$$f'(t) = a + \int_0^t h(\mu)\, d\mu, \qquad f''(t) = h(t).$$

Therefore, f is convex if h is non-negative.

To prove Eq. (6), we substitute $f(t) = f'(0)\,t + f(0) + \int_0^\infty (t-\mu)_+\, f''(\mu)\, d\mu$ into the definition of the f-divergence:

$$
\begin{aligned}
D_f(p~\|~q) &= \mathbb{E}_q\left[f\!\Big(\frac{p(x)}{q(x)}\Big) - f(1)\right] \\
&= \mathbb{E}_q\left[f'(0)\frac{p(x)}{q(x)} + f(0) + \int_0^\infty \Big(\frac{p(x)}{q(x)} - \mu\Big)_{\!+} f''(\mu)\, d\mu - f(1)\right] \\
&= \big[f'(0) + f(0) - f(1)\big] + \int_0^\infty \mathbb{E}_q\left[\Big(\frac{p(x)}{q(x)} - \mu\Big)_{\!+}\right] f''(\mu)\, d\mu.
\end{aligned}
$$

This completes the proof.
B Proof of Proposition 4.2
Proof.
By the chain rule and the "score-function trick" $\nabla_\theta q_\theta(x) = q_\theta(x)\nabla_\theta \log q_\theta(x)$, we have

$$
\begin{aligned}
\nabla_\theta D_f(p~\|~q_\theta)
&= \mathbb{E}_{q_\theta}\left[\nabla_\theta f\!\Big(\frac{p(x)}{q_\theta(x)}\Big) + f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta \log q_\theta(x)\right] \\
&= \mathbb{E}_{q_\theta}\left[f'\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta\Big(\frac{p(x)}{q_\theta(x)}\Big) + f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta \log q_\theta(x)\right] \\
&= \mathbb{E}_{q_\theta}\left[-f'\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\frac{p(x)}{q_\theta(x)}\nabla_\theta \log q_\theta(x) + f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta \log q_\theta(x)\right] \\
&= -\,\mathbb{E}_{q_\theta}\left[\rho_f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta \log q_\theta(x)\right],
\end{aligned}
$$

where ρ_f(t) = f′(t)t − f(t). This proves Eq. (7).

To prove Eq. (8), note that for any function φ, we have by the reparameterization trick:

$$\nabla_\theta \mathbb{E}_{q_\theta}[\phi(x)] = \mathbb{E}_{x\sim q_\theta}\big[\phi(x)\nabla_\theta \log q_\theta(x)\big] \;\;\text{(score function)} \;=\; \mathbb{E}_{\xi\sim q_0}\big[\nabla_x \phi(x)\, \nabla_\theta g_\theta(\xi)\big] \;\;\text{(reparameterization trick)},$$

where we assume x ∼ q_θ is generated by x = g_θ(ξ), ξ ∼ q₀.

Taking φ(x) = ρ_f(p(x)/q_θ(x)) in Eq. (7), we have

$$
\begin{aligned}
\nabla_\theta D_f(p~\|~q_\theta)
&= -\,\mathbb{E}_{x\sim q_\theta}\left[\rho_f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta \log q_\theta(x)\right] \\
&= -\,\mathbb{E}_{\xi\sim q_0}\left[\nabla_x \rho_f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta g_\theta(\xi)\right] \\
&= -\,\mathbb{E}_{\xi\sim q_0}\left[\rho'_f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_x\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta g_\theta(\xi)\right] \\
&= -\,\mathbb{E}_{\xi\sim q_0}\left[\rho'_f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\frac{p(x)}{q_\theta(x)}\nabla_x \log\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta g_\theta(\xi)\right] \\
&= -\,\mathbb{E}_{\xi\sim q_0}\left[\gamma_f\!\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_x \log\Big(\frac{p(x)}{q_\theta(x)}\Big)\nabla_\theta g_\theta(\xi)\right],
\end{aligned}
$$

where γ_f(t) = ρ′_f(t)t.

C Tail-adaptive f-Divergence with Score-Function Gradient Estimator

Algorithm 2 summarizes our method using the score-function gradient estimator (7).
Algorithm 2: Variational Inference with Tail-adaptive f-Divergence (with Score-Function Gradient)

Goal: find the best approximation of p(x) from {q_θ : θ ∈ Θ}.
Initialize θ; set an index β (e.g., β = −1).
for each iteration do
    Draw {x_i}_{i=1}^n ∼ q_θ. Set $\hat{\bar F}(t) = \frac{1}{n}\sum_{j=1}^n \mathbb{I}(p(x_j)/q(x_j) \ge t)$ and $\rho_i = \hat{\bar F}(p(x_i)/q(x_i))^\beta$.
    Update θ ← θ + ε Δθ, where ε is the step size and
    $$\Delta\theta = \frac{1}{z_\rho}\sum_{i=1}^n \big[\rho_i\, \nabla_\theta \log q_\theta(x_i)\big], \qquad z_\rho = \sum_{i=1}^n \rho_i.$$
end for
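A minimal sketch of one Algorithm 2 step (assumes PyTorch; log_p is the possibly unnormalized target log density, and q_dist is a torch.distributions object whose parameters are being optimized):

```python
import torch

def score_function_step(log_p, q_dist, optimizer, n=256, beta=-1.0):
    # plain (non-reparameterized) samples: no gradient flows through x
    x = q_dist.sample((n,))
    log_q = q_dist.log_prob(x).squeeze(-1)
    log_w = (log_p(x).squeeze(-1) - log_q).detach()
    # rank-based weights rho_i = F_hat(w_i)^beta, as in Algorithm 2
    tail_prob = (log_w.unsqueeze(1) >= log_w.unsqueeze(0)).float().mean(dim=0)
    rho = tail_prob ** beta
    rho = rho / rho.sum()
    # gradient of this loss is -(1/z_rho) sum_i rho_i grad log q_theta(x_i)
    loss = -(rho * log_q).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```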
D More Results for Bayesian Neural Network

Table 2 shows more results on the Bayesian neural networks with more choices of α for the α-divergence. We can see that our approach achieves the best performance in most of the cases.

[The average test RMSE part of Table 2 is not legible in this copy. The average test log-likelihood part is reproduced below; one entry of the Kin8nm row was lost, and its position in the row is uncertain.]

Average Test Log-likelihood
Dataset    β=-1.0   β=-0.5   α=-1     α=0      α=0.5    α=1.0    α=2.0    α=+∞
Boston     -2.476   -2.523   -2.561   -2.547   -2.506   -2.493   -2.516   -2.509
Concrete   -3.099   -3.133   -3.171   -3.149   -3.103   -3.106   -3.116   -3.109
Energy     -1.758   -1.814   -1.946   -1.795   -1.854   -1.801   -1.828   -1.832
Kin8nm      1.055    1.017    1.024    1.012    1.080    1.075    1.074    --
Naval      -2.835   -2.842   -2.845   -2.845   -2.843   -2.839   -2.850   -2.842
Wine       -0.962   -0.956   -0.961   -0.959   -0.971   -0.968   -0.972   -0.971
Yacht      -1.711   -1.718   -2.201   -1.751   -1.875   -1.946   -1.963   -1.986
Protein    -2.921   -2.930   -2.934   -2.938   -2.928   -2.930   -2.947   -2.932
Year       -3.570   -3.597   -3.599   -3.600   -3.518   -3.529   -3.524   -3.524

Table 2: Test RMSE and LL results for Bayesian neural network regression.
E Reinforcement Learning
In this section, we provide more information and results for the reinforcement learning experiments, including comparisons of algorithms using score-function gradient estimators (Algorithm 2).
E.1 MuJoCo Environments
We test six MuJoCo environments in this paper: HalfCheetah, Hopper, Swimmer(rllab), Humanoid(rllab), Walker, and Ant, for which the dimensions of the action space are 6, 3, 2, 21, 6, and 8, respectively. Figure 4 shows examples of the environments used in our experiments.

Figure 4: MuJoCo environments used in our reinforcement learning experiments. From left to right: HalfCheetah, Hopper, Swimmer(rllab), Humanoid(rllab), Walker, and Ant.
E.2 Different Choices of α

In this section, we present the average reward of α-divergences with different choices of α on Hopper and Walker, with both score-function and reparameterization gradient estimators. We can see from Figure 5 that α = 0.5 and α = +∞ (denoted by "α = max" in the legends) perform consistently better than the standard KL divergence (α = 0), which is used in the original SAC paper.

[Figure 5 appears here; panels: Hopper (Score function), Hopper (Reparameterization), Walker (Score function), Walker (Reparameterization); y-axis: average reward.]