Variational Inference with Tail-adaptive f-Divergence

Dilin Wang (UT Austin), dilin@cs.utexas.edu; Hao Liu (UESTC), uestcliuhao@gmail.com; Qiang Liu (UT Austin), lqiang@cs.utexas.edu

Work done at UT Austin. 32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montréal, Canada.

Abstract

Variational inference with α-divergences has been widely used in modern probabilistic machine learning. Compared to the Kullback-Leibler (KL) divergence, a major advantage of using α-divergences (with positive α values) is their mass-covering property. However, estimating and optimizing α-divergences requires importance sampling, which may have large or infinite variance due to heavy tails of the importance weights. In this paper, we propose a new class of tail-adaptive f-divergences that adaptively change the convex function f with the tail distribution of the importance weights, in a way that theoretically guarantees finite moments while simultaneously achieving mass-covering properties. We test our method on Bayesian neural networks, and apply it to improve a recent soft actor-critic (SAC) algorithm (Haarnoja et al., 2018) in deep reinforcement learning. Our results show that our approach yields significant advantages compared with existing methods based on classical KL and α-divergences.

1 Introduction

Variational inference (VI) (e.g., Jordan et al., 1999; Wainwright et al., 2008) has been established as a powerful tool in modern probabilistic machine learning for approximating intractable posterior distributions. The basic idea is to turn the approximation problem into an optimization problem, which finds the best approximation of an intractable distribution from a family of tractable distributions by minimizing a divergence objective function. Compared with Markov chain Monte Carlo (MCMC), which is known to be consistent but suffers from slow convergence, VI provides biased results but is often practically faster. Combined with techniques like stochastic optimization (Ranganath et al., 2014; Hoffman et al., 2013) and the reparameterization trick (Kingma & Welling, 2014), VI has become a major technical approach for advancing Bayesian deep learning, deep generative models, and deep reinforcement learning (e.g., Kingma & Welling, 2014; Gal & Ghahramani, 2016; Levine, 2018).

A key component of successful variational inference lies in choosing a proper divergence metric. Typically, closeness is defined by the KL divergence KL(q || p) (e.g., Jordan et al., 1999), where p is the intractable distribution of interest and q is a simpler distribution constructed to approximate p. However, VI with KL divergence often under-estimates the variance and may miss important local modes of the true posterior (e.g., Christopher, 2016; Blei et al., 2017). To mitigate this issue, alternative metrics have been studied in the literature, a large portion of which are special cases of the f-divergence (e.g., Csiszár & Shields, 2004):

\[
D_f(p \,\|\, q) = \mathbb{E}_{x\sim q}\!\left[ f\!\left(\frac{p(x)}{q(x)}\right) - f(1) \right], \tag{1}
\]

where f : R_+ → R is any convex function. The most notable class of f-divergence that has been exploited in VI is the α-divergence, which takes f(t) = t^α / (α(α−1)) for α ∈ R \ {0, 1}. By choosing different α, we obtain a large number of well-known divergences as special cases, including the standard KL divergence objective KL(q || p) (α → 0), the KL divergence in the reverse direction KL(p || q) (α → 1), and the χ²-divergence (α = 2).
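As a quick sanity check of these two limiting cases (a sketch added here for completeness, using definition (1) and writing w := p(x)/q(x), so that E_{x∼q}[w] = 1):

```latex
% Limits of D_{f_\alpha}(p\,\|\,q) = \mathbb{E}_{x\sim q}[(w^\alpha - 1)/(\alpha(\alpha-1))], with w = p(x)/q(x).
\begin{align*}
\alpha \to 0 &:\quad
  \frac{w^{\alpha}-1}{\alpha(\alpha-1)}
  = \frac{\alpha\log w + O(\alpha^{2})}{\alpha(\alpha-1)}
  \;\to\; -\log w,
  \qquad \mathbb{E}_{x\sim q}[-\log w] = \mathrm{KL}(q\,\|\,p); \\
\alpha \to 1 &:\quad
  \frac{w^{\alpha}-1}{\alpha(\alpha-1)}
  = \frac{(w-1) + (\alpha-1)\,w\log w + O\big((\alpha-1)^{2}\big)}{\alpha(\alpha-1)},
  \quad \mathbb{E}_{x\sim q}[w-1]=0
  \;\Rightarrow\; \lim_{\alpha\to 1} D_{f_\alpha} = \mathbb{E}_{x\sim q}[w\log w] = \mathrm{KL}(p\,\|\,q).
\end{align*}
```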
In particular, the use of general α-divergences in VI has been widely discussed (e.g., Minka et al., 2005; Hernández-Lobato et al., 2016; Li & Turner, 2016); the reverse KL divergence is used in expectation propagation (Minka, 2001; Opper & Winther, 2005), importance weighted auto-encoders (Burda et al., 2016), and the cross entropy method (De Boer et al., 2005); the χ²-divergence is exploited for VI (e.g., Dieng et al., 2017), but is more extensively studied in the context of adaptive importance sampling (IS) (e.g., Cappé et al., 2008; Ryu & Boyd, 2014; Cotter et al., 2015), since it coincides with the variance of the IS estimator with q as the proposal.

A major motivation for using α-divergences is their mass-covering property: when α > 0, the optimal approximation q tends to cover more modes of p, and hence better accounts for the uncertainty in p. Typically, larger values of α enforce stronger mass-covering properties. In practice, however, the α-divergence and its gradient need to be estimated empirically using samples from q. Using large α values may cause high or infinite variance in the estimation, because it involves estimating the α-th power of the density ratio p(x)/q(x), which is likely distributed with a heavy or fat tail (e.g., Resnick, 2007). In fact, when q is very different from p, the expectation of the ratio (p(x)/q(x))^α can be infinite (that is, the α-divergence does not exist). This makes it problematic to use large α values, despite the mass-covering property they promise. In addition, it is reasonable to expect that the optimal setting of α should vary across training processes and learning tasks. Therefore, it is desirable to design an approach that chooses α adaptively and automatically as q changes during the training iterations, according to the distribution of the ratio p(x)/q(x).

Based on theoretical observations on f-divergences and fat-tailed distributions, we design a new class of f-divergences that is tail-adaptive, in that it uses different f functions according to the tail distribution of the density ratio p(x)/q(x), so as to simultaneously obtain stable empirical estimation and the strongest possible mass-covering property. This allows us to derive a new adaptive f-divergence-based variational inference algorithm by combining it with stochastic optimization and reparameterization gradient estimates. Our main method (Algorithm 1) has a simple form, which replaces the f function in (1) with a rank-based function of the empirical density ratio w = p(x)/q(x) at each gradient descent step of q, whose variation depends on the distribution of w and does not explode regardless of the tail of w. Empirically, we show that our method can better recover multiple modes in variational inference. In addition, we apply our method to improve a recent soft actor-critic (SAC) algorithm (Haarnoja et al., 2018) in reinforcement learning (RL), showing that our method can be used to optimize multi-modal loss functions in RL more efficiently.

2 f-Divergence and Friends

Given a distribution p(x) of interest, we want to approximate it with a simpler distribution from a family {q_θ(x) : θ ∈ Θ}, where θ is the variational parameter that we want to optimize. We approach this problem by minimizing the f-divergence between q_θ and p:

\[
D_f(p \,\|\, q_\theta) = \mathbb{E}_{x\sim q_\theta}\!\left[ f\!\left(\frac{p(x)}{q_\theta(x)}\right) - f(1) \right], \tag{2}
\]

where f : R_+ → R is any twice differentiable convex function. It can be shown by Jensen's inequality that D_f(p || q) ≥ 0 for any p and q. Further, if f(t) is strictly convex at t = 1, then D_f(p || q) = 0 implies p = q.
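As a purely illustrative example of estimating (2) from samples of q (not code from the paper; the one-dimensional Gaussians and the sample size are placeholder choices), a naive Monte Carlo estimator looks as follows:

```python
# A minimal sketch (not the paper's code): Monte Carlo estimate of
#   D_f(p || q) = E_{x~q}[ f(p(x)/q(x)) - f(1) ]
# for several convex f, using placeholder 1-D Gaussians for p and q.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
p = norm(loc=0.0, scale=1.3)                      # target p (illustrative)
q = norm(loc=0.0, scale=1.0)                      # approximation q (illustrative)

x = rng.normal(loc=0.0, scale=1.0, size=200_000)  # samples from q
w = p.pdf(x) / q.pdf(x)                           # density ratios p(x)/q(x)

f_choices = {
    "KL(q||p), f(t) = -log t ": lambda t: -np.log(t),
    "KL(p||q), f(t) = t log t": lambda t: t * np.log(t),
    "chi^2,    f(t) = (t-1)^2": lambda t: (t - 1.0) ** 2,
}
for name, f in f_choices.items():
    # note: the chi^2 estimate is the noisiest, since it involves the square of the ratio
    print(f"{name:26s} D_f estimate = {np.mean(f(w) - f(1.0)):.4f}")
```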
The optimization in (2) can be solved approximately with stochastic optimization in practice, by approximating the expectation E_{x∼q_θ}[·] using samples drawn from q_θ at each iteration.

The f-divergence includes a large spectrum of important divergence measures. It includes the KL divergence in both directions,

\[
\mathrm{KL}(q \,\|\, p) = \mathbb{E}_{x\sim q}\!\left[ \log\frac{q(x)}{p(x)} \right],
\qquad
\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{x\sim q}\!\left[ \frac{p(x)}{q(x)} \log\frac{p(x)}{q(x)} \right], \tag{3}
\]

which correspond to f(t) = −log t and f(t) = t log t, respectively. KL(q || p) is the typical objective function used in variational inference; the reverse direction KL(p || q) is also used in various settings (e.g., Minka, 2001; Opper & Winther, 2005; De Boer et al., 2005; Burda et al., 2016).

More generally, the f-divergence includes the class of α-divergences, which take f_α(t) = t^α / (α(α−1)) for α ∈ R \ {0, 1}, and hence

\[
D_{f_\alpha}(p \,\|\, q) = \frac{1}{\alpha(\alpha-1)}\, \mathbb{E}_{x\sim q}\!\left[ \left(\frac{p(x)}{q(x)}\right)^{\alpha} - 1 \right]. \tag{4}
\]

One can show that KL(q || p) and KL(p || q) are the limits of D_{f_α}(p || q) when α → 0 and α → 1, respectively. Further, one obtains the Hellinger distance and the χ²-divergence for α = 1/2 and α = 2, respectively. In particular, the χ²-divergence (α = 2) plays an important role in adaptive importance sampling, because it equals the variance of the importance weight w = p(x)/q(x), and minimizing the χ²-divergence corresponds to finding an optimal importance sampling proposal.

3 α-Divergence and Fat Tails

A major motivation for using α-divergences as the objective function for approximate inference is their mass-covering property (also known as the zero-avoiding behavior). This is because the α-divergence is proportional to the α-th moment of the density ratio p(x)/q(x). When α is positive and large, large values of p(x)/q(x) are strongly penalized, preventing the case q(x) ≪ p(x). In fact, whenever D_{f_α}(p || q) < ∞, p(x) > 0 implies q(x) > 0. This means that the probability mass and local modes of p are properly taken into account in q. Note that the case α ≤ 0 exhibits the opposite property, that is, p(x) = 0 must imply q(x) = 0 in order to make D_{f_α}(p || q) finite; this includes the typical KL divergence KL(q || p) (α → 0), which is often criticized for its tendency to under-estimate the uncertainty.

Typically, using larger values of α enforces stronger mass-covering properties. In practice, however, larger values of α also increase the variance of the empirical estimators, making the objective highly challenging to optimize. In fact, the expectation in (4) may not even exist when α is too large. This is because the density ratio w := p(x)/q(x) often has a fat-tailed distribution.

A non-negative random variable w is called fat-tailed (e.g., Resnick, 2007) if its tail probability F_w(t) := Pr(w ≥ t) is asymptotically equivalent to t^{−α*} as t → +∞ for some finite positive number α* (denoted by F_w(t) ∼ t^{−α*}), which means that F_w(t) = t^{−α*} L(t), where L is a slowly varying function that satisfies lim_{t→+∞} L(ct)/L(t) = 1 for any c > 0. (Fat-tailed distributions are a sub-class of heavy-tailed distributions, whose tail probabilities decay more slowly than any exponential, that is, lim_{t→+∞} exp(λt) F_w(t) = ∞ for all λ > 0.) Here α* determines the fatness of the tail and is called the tail index of w. For a fat-tailed distribution with index α*, its α-th moment exists only if α < α*, that is, E[w^α] < ∞ iff α < α*.

It turns out that the density ratio w := p(x)/q(x), with x ∼ q, tends to have a fat-tailed distribution when q is more peaked than p. The example below illustrates this with simple Gaussian distributions.

Example 3.1. Assume p(x) = N(x; 0, σ_p²) and q(x) = N(x; 0, σ_q²). Let x ∼ q and let w = p(x)/q(x) be the density ratio. If σ_p > σ_q, then w has a fat-tailed distribution with index α* = σ_p² / (σ_p² − σ_q²). On the other hand, if σ_p ≤ σ_q, then w is bounded and not fat-tailed (effectively, α* = +∞).
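The following small sketch (not from the paper; the variances and sample sizes are arbitrary choices) illustrates Example 3.1 numerically: with σ_p > σ_q, empirical moments of w are stable only for α below the tail index α*.

```python
# A small illustration (not from the paper) of Example 3.1: with sigma_p > sigma_q,
# the importance weight w = p(x)/q(x), x ~ q, is fat-tailed with index
# alpha* = sigma_p^2 / (sigma_p^2 - sigma_q^2), so E[w^alpha] is finite only for alpha < alpha*.
import numpy as np
from scipy.stats import norm

sigma_p, sigma_q = 2.0, 1.0                              # illustrative values
alpha_star = sigma_p**2 / (sigma_p**2 - sigma_q**2)      # = 4/3 here
print(f"theoretical tail index alpha* = {alpha_star:.3f}")

rng = np.random.default_rng(0)
for n in [10**4, 10**5, 10**6]:
    x = rng.normal(0.0, sigma_q, size=n)                 # x ~ q
    w = norm.pdf(x, 0.0, sigma_p) / norm.pdf(x, 0.0, sigma_q)
    # alpha = 1 (< alpha*): the sample mean hovers near E[w] = 1;
    # alpha = 2 (> alpha*): E[w^2] is infinite, so the sample mean keeps growing with n.
    print(f"n={n:>8d}  mean(w^1) = {np.mean(w):8.3f}   mean(w^2) = {np.mean(w**2):12.1f}")
```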
By the definition above, if the importance weight w = p(x)/q(x) has a tail index α*, the α-divergence D_{f_α}(p || q) exists only if α < α*. Although it is desirable to use the α-divergence with large values of α as the VI objective function, it is important to keep α smaller than α* to ensure that the objective and its gradient are well defined. The problem, however, is that the tail index α* is unknown in practice, and may change dramatically (e.g., even from finite to infinite) as q is updated during the optimization process. This makes it suboptimal to use a pre-fixed α value. One potential way to address this problem is to estimate the tail index α* empirically at each iteration using a tail index estimator (e.g., Hill et al., 1975; Vehtari et al., 2015). Unfortunately, tail index estimation is often challenging and requires a large number of samples, and the algorithm may become unstable if α* is over-estimated.

4 Hessian-based Representation of f-Divergence

In this work, we address the aforementioned problem by designing a generalization of the f-divergence in which f adaptively changes with p and q, in a way that always guarantees the existence of the expectation, while simultaneously achieving (theoretically) a mass-covering property as strong as that of the α-divergence with α = α*.

One challenge in designing such an adaptive f is that the convexity constraint on f is difficult to express computationally. Our first key observation is that it is easier to specify a convex function f through its second-order derivative f'', which can be any non-negative function. It turns out that the f-divergence, as well as its gradient, can be conveniently expressed using f'', without explicitly defining the original f.

Proposition 4.1. 1) Any twice differentiable convex function f : R_+ ∪ {0} → R with finite f(0) can be decomposed into linear and nonlinear components as follows:

\[
f(t) = (at + b) + \int_0^{\infty} (t - \mu)_+ \, h(\mu)\, d\mu, \tag{5}
\]

where h is a non-negative function, (t)_+ = max(0, t), and a, b ∈ R. In this case, h(t) = f''(t), a = f'(0) and b = f(0). Conversely, any non-negative function h together with a, b ∈ R specifies a convex function.

2) This allows us to derive an alternative representation of the f-divergence:

\[
D_f(p \,\|\, q) = \int_0^{\infty} f''(\mu)\, \mathbb{E}_{x\sim q}\!\left[ \left( \frac{p(x)}{q(x)} - \mu \right)_{+} \right] d\mu \;-\; c, \tag{6}
\]

where c := ∫_0^1 f''(µ)(1 − µ) dµ = f(1) − f(0) − f'(0) is a constant.

Proof. If f(t) = (at + b) + ∫_0^∞ (t − µ)_+ h(µ) dµ, a direct calculation shows

\[
f'(t) = a + \int_0^{t} h(\mu)\, d\mu, \qquad f''(t) = h(t).
\]

Therefore, f is convex iff h is non-negative. See the Appendix for the complete proof.

Eq (6) shows that all f-divergences are conical combinations of a set of special f-divergences of the form E_{x∼q}[(p(x)/q(x) − µ)_+ − f(1)] with f(t) = (t − µ)_+. Also, every f-divergence is completely specified by the Hessian f'', meaning that adding any linear function at + b to f does not change D_f(p || q). Such integral representations of the f-divergence are not new; see, e.g., Feldman & Osterreicher (1989); Osterreicher (2003); Liese & Vajda (2006); Reid & Williamson (2011); Sason (2018).
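As a concrete sanity check of (5) and (6) (a sketch added here, not part of the original derivation), consider the χ² case f(t) = (t − 1)², for which h(µ) = f''(µ) = 2, a = f'(0) = −2 and b = f(0) = 1:

```latex
% Decomposition (5) for f(t) = (t-1)^2:
\[
f(t) = -2t + 1 + \int_0^{\infty} (t-\mu)_+\, 2\, d\mu
     = -2t + 1 + 2\int_0^{t}(t-\mu)\, d\mu
     = -2t + 1 + t^{2} = (t-1)^{2}.
\]
% Representation (6), with c = \int_0^1 2(1-\mu)\, d\mu = 1 = f(1)-f(0)-f'(0):
\[
D_f(p \,\|\, q)
  = \int_0^{\infty} 2\, \mathbb{E}_{x\sim q}\!\left[\Big(\tfrac{p(x)}{q(x)}-\mu\Big)_{+}\right] d\mu - 1
  = \mathbb{E}_{x\sim q}\!\left[\tfrac{p(x)^{2}}{q(x)^{2}}\right] - 1,
\]
% which is the chi-square divergence, in agreement with evaluating (2) directly
% via E_{x~q}[(p/q - 1)^2].
```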
For the purpose of minimizing D_f(p || q_θ) over θ ∈ Θ in variational inference, we are more concerned with calculating the gradient than with the f-divergence itself. It turns out that the gradient of D_f(p || q_θ) is also directly related to the Hessian f'' in a simple way.

Proposition 4.2. 1) Assume log q_θ(x) is differentiable w.r.t. θ and f is a differentiable convex function. For the f-divergence defined in (2), we have

\[
\nabla_\theta D_f(p \,\|\, q_\theta) = -\,\mathbb{E}_{x\sim q_\theta}\!\left[ \rho_f\!\left( \frac{p(x)}{q_\theta(x)} \right) \nabla_\theta \log q_\theta(x) \right], \tag{7}
\]

where ρ_f(t) = f'(t)t − f(t) (equivalently, ρ'_f(t) = f''(t)t if f is twice differentiable).

2) Assume x ∼ q_θ is generated by x = g_θ(ξ), where ξ ∼ q_0 is a random seed and g_θ is a function that is differentiable w.r.t. θ. Assume f is twice differentiable and ∇_x log(p(x)/q_θ(x)) exists. We have

\[
\nabla_\theta D_f(p \,\|\, q_\theta) = -\,\mathbb{E}_{x = g_\theta(\xi),\, \xi\sim q_0}\!\left[ \gamma_f\!\left( \frac{p(x)}{q_\theta(x)} \right) \nabla_\theta g_\theta(\xi)\, \nabla_x \log\frac{p(x)}{q_\theta(x)} \right], \tag{8}
\]

where γ_f(t) = ρ'_f(t)t = f''(t)t².

The result above shows that the gradient of the f-divergence depends on f only through ρ_f or γ_f. Taking the α-divergence (α ∉ {0, 1}) as an example, we have f(t) = t^α/(α(α−1)), ρ_f(t) = t^α/α, and γ_f(t) = t^α, all of which are proportional to the power function t^α. For KL(q || p), we have f(t) = −log t, yielding ρ_f(t) = log t − 1 and γ_f(t) = 1; for KL(p || q), we have f(t) = t log t, yielding ρ_f(t) = t and γ_f(t) = t.

The formulas in (7) and (8) are called the score-function gradient and the reparameterization gradient (Kingma & Welling, 2014), respectively. Both equal the gradient in expectation, but they are computationally different and yield empirical estimators with different variances. In particular, the score-function gradient in (7) is gradient-free in that it does not require calculating the gradient of the distribution p(x) of interest, while (8) is gradient-based in that it involves ∇_x log p(x). It has been shown that optimizing with reparameterization gradients tends to give better empirical results, because it leverages the gradient information ∇_x log p(x) and yields a lower-variance estimator of the gradient (e.g., Kingma & Welling, 2014).

Our key observation is that we can directly specify f through any increasing function ρ_f, or any non-negative function γ_f, in the gradient estimators, without explicitly defining f.

Proposition 4.3. Assume f : R_+ → R is convex and twice differentiable. Then

1) ρ_f in (7) is a monotonically increasing function on R_+. In addition, for any differentiable increasing function ρ, there exists a convex function f such that ρ_f = ρ;

2) γ_f in (8) is non-negative on R_+, that is, γ_f(t) ≥ 0 for all t ∈ R_+. In addition, for any non-negative function γ, there exists a convex function f such that γ_f = γ;

3) if ρ_f(t) is strictly increasing at t = 1 (i.e., ρ'_f(1) > 0), or γ_f(t) is strictly positive at t = 1 (i.e., γ_f(1) > 0), then D_f(p || q) = 0 implies p = q.

Proof. Because f is convex (f''(t) ≥ 0), we have γ_f(t) = f''(t)t² ≥ 0 and ρ'_f(t) = f''(t)t ≥ 0 on R_+, that is, γ_f is non-negative and ρ_f is increasing on R_+. If ρ_f is strictly increasing (or γ_f is strictly positive) at t = 1, then f is strictly convex at t = 1, which guarantees that D_f(p || q) = 0 implies p = q. For a non-negative function γ(t) (or an increasing function ρ(t)) on R_+, any convex function f whose second-order derivative equals γ(t)/t² (or ρ'(t)/t) satisfies γ_f = γ (resp. ρ_f = ρ).
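To make the role of ρ_f concrete, the sketch below implements the score-function estimator (7) as a surrogate loss for a diagonal Gaussian q_θ; this is illustrative code rather than the authors' implementation, and the target log_p, the choice of ρ, the sample size and the step size are all placeholder assumptions.

```python
# A minimal sketch (not the authors' code) of the score-function gradient (7):
#   grad_theta D_f(p||q_theta) = -E_{x~q_theta}[ rho(p(x)/q_theta(x)) * grad_theta log q_theta(x) ],
# written as a surrogate loss whose autograd gradient is the Monte Carlo version of (7).
import torch

def log_p(x):                                  # placeholder unnormalized target: a wide Gaussian
    return -0.5 * (x ** 2).sum(-1) / 4.0

mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)

def score_function_grad_step(rho, n=256, lr=1e-2):
    q = torch.distributions.Normal(mu, log_sigma.exp())
    x = q.sample((n,))                         # detached samples from q_theta (no pathwise gradient)
    log_q = q.log_prob(x).sum(-1)
    with torch.no_grad():                      # rho(w) enters only as a (signed) weight
        w = torch.exp(log_p(x) - log_q)
        weights = rho(w)
    # Gradient of this surrogate w.r.t. theta = -mean(rho(w) * grad_theta log q_theta(x)),
    # i.e. the Monte Carlo estimate of (7); descending it descends D_f.
    loss = -(weights * log_q).mean()
    loss.backward()
    with torch.no_grad():
        for param in (mu, log_sigma):
            param -= lr * param.grad
            param.grad.zero_()

rho_kl = lambda t: torch.log(t) - 1.0          # the KL(q||p) case: rho_f(t) = log t - 1
for _ in range(100):
    score_function_grad_step(rho_kl)
print(mu.detach(), log_sigma.exp().detach())
```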
5 Safe f-Divergence with Inverse Tail Probability

The results above show that it is sufficient to find an increasing function ρ_f, or a non-negative function γ_f, to obtain an adaptive f-divergence with computable gradients. In order to make the f-divergence "safe", we need ρ_f or γ_f to depend adaptively on p and q such that the expectations in (7) and (8) always exist.

Because the magnitudes of ∇_θ log q_θ(x), ∇_x log(p(x)/q_θ(x)) and ∇_θ g_θ(ξ) are relatively small compared with the ratio p(x)/q(x), we can mainly consider designing the function ρ (or γ) such that it yields a finite expectation E_{x∼q}[ρ(p(x)/q(x))] < ∞; meanwhile, we should also keep the function large, preferably with the same magnitude as t^{α*}, to provide a strong mass-covering property. As it turns out, the inverse of the tail probability naturally achieves all of these goals.

Proposition 5.1. For any random variable w with tail probability F_w(t) := Pr(w ≥ t) and tail index α*, we have E[F_w(w)^β] < ∞ for any β > −1. Also, F_w(t)^β ∼ t^{−βα*}, and F_w(t)^β is always non-negative and monotonically increasing in t when β < 0.

Proof. Simply note that E[F_w(w)^β] = −∫ F_w(t)^β dF_w(t) = ∫_0^1 u^β du, which is finite only when β > −1. The non-negativity and monotonicity of F_w(t)^β for β < 0 are obvious, and F_w(t)^β ∼ t^{−βα*} follows directly from the definition of the tail index.

This motivates us to use F_w(t)^β to define ρ_f or γ_f, yielding two versions of safe, tail-adaptive f-divergences. Note that here f is defined implicitly through ρ_f or γ_f. Although it is possible to derive the corresponding f and D_f(p || q), there is no computational need to do so, since optimizing the objective function only requires calculating the gradient, which is defined by ρ_f or γ_f.

In practice, the explicit form of F_w(t)^β is unknown. We can approximate it based on empirical data drawn from q. Let {x_i} be drawn from q and w_i = p(x_i)/q(x_i); then we can approximate the tail probability with F̂_w(t) = (1/n) Σ_{i=1}^n I(w_i ≥ t). Intuitively, this corresponds to assigning each data point a weight according to the rank of its density ratio in the population. Substituting the empirical tail probability into the reparameterization gradient formula in (8) and running gradient descent with stochastic approximation yields our main algorithm, shown in Algorithm 1. The version with the score-function gradient is similar and is shown in Algorithm 2 in the Appendix. Both algorithms can be viewed as minimizing implicitly constructed adaptive f-divergences, but they correspond to using different f.

Algorithm 1: Variational Inference with Tail-adaptive f-Divergence (with Reparameterization Gradient)
  Goal: find the best approximation of p(x) from {q_θ : θ ∈ Θ}. Assume x ∼ q_θ is generated by x = g_θ(ξ), where ξ is a random sample from a noise distribution q_0.
  Initialize θ; set an index β (e.g., β = −1).
  for each iteration do
    Draw {x_i}_{i=1}^n ∼ q_θ, generated by x_i = g_θ(ξ_i).
    Let w_i = p(x_i)/q_θ(x_i), F̂_w(t) = (1/n) Σ_{j=1}^n I(w_j ≥ t), and set γ_i = F̂_w(w_i)^β.
    Update θ ← θ + ε Δθ, where ε is the step size and
      Δθ = (1/z_γ) Σ_{i=1}^n [ γ_i ∇_θ g_θ(ξ_i) ∇_x log(p(x_i)/q_θ(x_i)) ],   with z_γ = Σ_{i=1}^n γ_i.
  end for

Compared with typical VI with reparameterized gradients, our method assigns each sample a weight γ_i = F̂_w(w_i)^β, which is proportional to (#w_i)^β, where #w_i denotes the rank of w_i in the population {w_j}. When taking −1 < β < 0, this allows us to penalize places with a high ratio p(x)/q(x) while avoiding being overly aggressive. In practice, we find that simply taking β = −1 almost always yields the best empirical performance (despite needing β > −1 in theory). By comparison, minimizing the classical α-divergence would use weights w_i^α; if α is too large, the weight of a single data point becomes dominant, making the gradient estimate unstable.
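To make the procedure concrete, here is a minimal self-contained sketch of Algorithm 1 for a diagonal Gaussian q_θ (illustrative code, not the released implementation; the target log_p, the use of Adam, and all hyperparameters are placeholder choices):

```python
# A sketch of Algorithm 1: reparameterization gradient (8) with rank-based
# tail-adaptive weights gamma_i = Fhat_w(w_i)^beta.
import torch

def log_p(x):                                     # placeholder unnormalized target density
    return -0.5 * (x ** 2).sum(-1) / 4.0

mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=1e-2)

def tail_adaptive_step(beta=-1.0, n=256):
    sigma = log_sigma.exp()
    x = mu + sigma * torch.randn(n, 2)            # x_i = g_theta(xi_i), reparameterized
    # Evaluate log q_theta at x with *detached* parameters so autograd only propagates
    # through x, i.e. through grad_theta g_theta(xi) grad_x log(p/q_theta), as in (8).
    q_stop = torch.distributions.Normal(mu.detach(), sigma.detach())
    log_w = log_p(x) - q_stop.log_prob(x).sum(-1)  # log w_i = log p(x_i) - log q_theta(x_i)
    with torch.no_grad():
        # empirical tail probability Fhat_w(w_i) = (1/n) sum_j I(w_j >= w_i)
        tail_prob = (log_w[None, :] >= log_w[:, None]).float().mean(dim=1)
        gamma = tail_prob ** beta
        gamma = gamma / gamma.sum()               # normalize by z_gamma = sum_i gamma_i
    # Gradient descent on this surrogate gives the update direction of Algorithm 1
    # (Adam is used here in place of the plain step theta <- theta + eps * delta).
    loss = -(gamma * log_w).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()

for _ in range(1000):
    tail_adaptive_step()
print("mu:", mu.detach(), "sigma:", log_sigma.exp().detach())
```

The pairwise comparison implements the empirical tail probability F̂_w, and evaluating log q_θ with detached parameters makes autograd return exactly the pathwise term ∇_θ g_θ(ξ) ∇_x log(p/q_θ) that appears in (8).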
6 Experiments

In this section, we evaluate our adaptive f-divergence on different models. We use reparameterization gradients by default, since they have smaller variance (Kingma & Welling, 2014) and normally yield better performance than score-function gradients. Our code is available at https://github.com/dilinwang820/adaptive-f-divergence.

6.1 Gaussian Mixture Toy Example

We first illustrate the approximation quality of our proposed adaptive f-divergence on Gaussian mixture models. We set the target distribution to be a Gaussian mixture p(x) = Σ_{i=1}^k (1/k) N(x; ν_i, 1) for x ∈ R^d, where the elements of each mean vector ν_i are drawn from Uniform([−s, s]). Here s can be viewed as controlling the non-Gaussianity of the target distribution: p reduces to a standard Gaussian distribution when s = 0 and becomes increasingly multi-modal as s increases. We fix the number of components to k = 10, and initialize the proposal distribution as q(x) = Σ_{i=1}^{20} w_i N(x; µ_i, σ_i²), where Σ_{i=1}^{20} w_i = 1. We evaluate how well q covers the modes of p using a mode-shift distance, dist(p, q) := (1/10) Σ_{i=1}^{10} min_j ||ν_i − µ_j||_2, which is the average distance of each mode in p to its nearest mode in q.

The model is optimized using Adagrad with a constant learning rate of 0.05. We use a minibatch of size 256 to approximate the gradient in each iteration, and train the model for 10,000 iterations. To learn the component weights, we apply the Gumbel-Softmax trick (Jang et al., 2017; Maddison et al., 2017) with a temperature of 0.1. Figure 1 shows the results for random mixtures p generated with s = 5, when the dimension d of x equals 2 and 10, respectively.

Figure 1: (a) plots the mode-shift distance between p and q; (b-c) show the MSE of the mean and variance between the true posterior p and our approximation q, respectively. The x-axis is the choice of α (for α-divergence) or β (for our adaptive method); curves compare Adaptive and Alpha at d = 2 and d = 10. All results are averaged over 10 random trials.

Figure 2: Results on randomly generated Gaussian mixture models as the non-Gaussianity s increases (d = 10), comparing Adaptive (β = −1) against α-divergence with α ∈ {0, 0.5, 1.0}. (a) plots the average mode-shift distance; (b-c) show the MSE of the mean and variance. All results are averaged over 10 random trials.

We can see that when the dimension is low (d = 2), all algorithms perform similarly well. However, as we increase the dimension to 10, our approach with the tail-adaptive f-divergence achieves the best performance. To examine the quality of the variational approximation more closely, we show in Figure 2 the average mode-shift distance and the MSE of the estimated mean and variance as we gradually increase the non-Gaussianity of p(x) by changing s, with the dimension fixed to 10. We can see from Figure 2 that when p is close to Gaussian (small s), all algorithms perform well; when p is highly non-Gaussian (large s), our approach with adaptive weights significantly outperforms the other baselines.
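For reference, the mode-shift distance defined above can be computed as in the short sketch below (illustrative code; the arrays of component means are placeholders):

```python
# Mode-shift distance dist(p, q) = (1/k) * sum_i min_j ||nu_i - mu_j||_2,
# i.e. the average distance from each mode of p to its nearest mode of q.
import numpy as np

def mode_shift_distance(nu, mu):
    """nu: (k, d) means of p's components; mu: (m, d) means of q's components."""
    diffs = nu[:, None, :] - mu[None, :, :]      # (k, m, d) pairwise differences
    dists = np.linalg.norm(diffs, axis=-1)       # (k, m) pairwise Euclidean distances
    return dists.min(axis=1).mean()              # nearest q-mode, averaged over p's modes

# toy usage with random placeholder means
rng = np.random.default_rng(0)
nu = rng.uniform(-5, 5, size=(10, 2))            # k = 10 modes of p
mu = rng.uniform(-5, 5, size=(20, 2))            # 20 components of q
print(mode_shift_distance(nu, mu))
```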
6.2 Bayesian Neural Network

We evaluate our approach on Bayesian neural network regression tasks. The datasets are collected from the UCI dataset repository (https://archive.ics.uci.edu/ml/datasets.html). Similarly to Li & Turner (2016), we use a single-layer neural network with 50 hidden units and ReLU activation, except that we use 100 hidden units for the Protein and Year datasets, which are relatively large. We use a fully factorized Gaussian approximation to the true posterior and a Gaussian prior on the neural network weights. All datasets are randomly partitioned into 90% for training and 10% for testing.

We use the Adam optimizer (Kingma & Ba, 2015) with a constant learning rate of 0.001. The gradient is approximated with n = 100 draws x_i ∼ q_θ and a minibatch of size 32 from the training data. All results are averaged over 20 random partitions, except for Protein and Year, for which 5 trials are repeated. We summarize the average test RMSE and test log-likelihood with standard errors in Table 1. We compare our algorithm with α = 0 (KL divergence) and α = 0.5, which are reported as the best settings for this task in Li & Turner (2016). More comparisons with different choices of α are provided in the appendix. We can see from Table 1 that our approach takes advantage of the adaptive choice of f-divergence and achieves the best performance in both test RMSE and test log-likelihood in most cases.

Table 1: Average test RMSE and test log-likelihood for Bayesian neural network regression (our method uses β = −1.0).

             Average Test RMSE                                Average Test Log-likelihood
Dataset      β = −1.0        α = 0.0         α = 0.5          β = −1.0         α = 0.0          α = 0.5
Boston       2.828 ± 0.177   2.956 ± 0.171   2.990 ± 0.173    −2.476 ± 0.177   −2.547 ± 0.171   −2.506 ± 0.173
Concrete     5.371 ± 0.115   5.592 ± 0.124   5.381 ± 0.111    −3.099 ± 0.115   −3.149 ± 0.124   −3.103 ± 0.111
Energy       1.377 ± 0.034   1.431 ± 0.029   1.531 ± 0.047    −1.758 ± 0.034   −1.795 ± 0.029   −1.854 ± 0.047
Kin8nm       0.085 ± 0.001   0.088 ± 0.001   0.083 ± 0.001     1.055 ± 0.001    1.012 ± 0.001    1.080 ± 0.001
Naval        0.001 ± 0.000   0.001 ± 0.000   0.004 ± 0.000     5.468 ± 0.000    5.269 ± 0.000    4.086 ± 0.000
Combined     4.116 ± 0.032   4.161 ± 0.034   4.154 ± 0.042    −2.835 ± 0.032   −2.845 ± 0.034   −2.843 ± 0.042
Wine         0.636 ± 0.008   0.634 ± 0.007   0.634 ± 0.008    −0.962 ± 0.008   −0.959 ± 0.007   −0.971 ± 0.008
Yacht        0.849 ± 0.059   0.861 ± 0.056   1.146 ± 0.092    −1.711 ± 0.059   −1.751 ± 0.056   −1.875 ± 0.092
Protein      4.487 ± 0.019   4.565 ± 0.026   4.564 ± 0.040    −2.921 ± 0.019   −2.938 ± 0.026   −2.928 ± 0.040
Year         8.831 ± 0.037   8.859 ± 0.036   8.985 ± 0.042    −3.570 ± 0.037   −3.600 ± 0.036   −3.518 ± 0.042

6.3 Application in Reinforcement Learning

We now demonstrate an application of our method in reinforcement learning, using it as an inner loop to improve the recent soft actor-critic (SAC) algorithm (Haarnoja et al., 2018). We start with a brief introduction to the background of SAC and then test our method on MuJoCo (http://www.mujoco.org/) environments.

Reinforcement Learning Background. Reinforcement learning considers the problem of finding optimal policies for agents that interact with uncertain environments to maximize the long-term cumulative reward. This is formally framed as a Markov decision process in which an agent iteratively takes an action a based on an observable state s, and receives a reward signal r(s, a) immediately after performing action a at state s. The change of states is governed by an unknown environmental dynamic defined by a transition probability T(s'|s, a). The agent's action a is selected by a conditional probability distribution π(a|s) called the policy.
In policy gradient methods, we consider a set of candidate policies π_θ(a|s) parameterized by θ and obtain the optimal policy by maximizing the expected cumulative reward

\[
J(\theta) = \mathbb{E}_{s\sim d_\pi,\, a\sim\pi(a|s)}\,[\, r(s, a)\,],
\]

where d_π(s) = Σ_{t=1}^∞ γ^{t−1} Pr(s_t = s) is the unnormalized discounted state-visitation distribution with discount factor γ ∈ (0, 1).

Soft Actor-Critic (SAC) is an off-policy optimization algorithm derived from maximizing the expected reward with an entropy regularization. It iteratively updates a Q-function Q(a, s), which predicts the cumulative reward of taking action a at state s, as well as a policy π(a|s), which selects actions a to maximize the expected value of Q(s, a). The update rule of Q(s, a) is based on a variant of Q-learning that matches the Bellman equation; the details can be found in Haarnoja et al. (2018). At each iteration of SAC, the policy π is updated by minimizing a KL divergence,

\[
\pi_{\mathrm{new}} = \arg\min_{\pi}\; \mathbb{E}_{s\sim d}\big[ \mathrm{KL}\big( \pi(\cdot|s) \,\|\, p_Q(\cdot|s) \big) \big], \tag{9}
\]
\[
p_Q(a|s) = \exp\!\Big( \tfrac{1}{\tau}\big( Q(a, s) - V(s) \big) \Big), \tag{10}
\]

where τ is a temperature parameter and V(s) = τ log ∫_a exp(Q(a, s)/τ) da, serving as a normalization constant here, is a soft version of the value function that is also iteratively updated in SAC. Here, d(s) is a visitation distribution over states s, which is taken to be the empirical distribution of the states in the current replay buffer of SAC. We can see that (9) can be viewed as a variational inference problem on the conditional distribution p_Q(a|s), with the typical KL objective function (α = 0).

SAC With Tail-adaptive f-Divergence. To apply the f-divergence, we first rewrite (9) to transform the conditional distributions into joint distributions. Define the joint distributions p_Q(a, s) = exp((Q(a, s) − V(s))/τ) d(s) and q_π(a, s) = π(a|s) d(s); then we can show that E_{s∼d}[KL(π(·|s) || p_Q(·|s))] = KL(q_π || p_Q). This motivates us to extend the objective function in (9) to more general f-divergences,

\[
D_f(p_Q \,\|\, q_\pi) = \mathbb{E}_{s\sim d}\, \mathbb{E}_{a\sim\pi(\cdot|s)}\!\left[ f\!\left( \frac{\exp\big( (Q(a, s) - V(s))/\tau \big)}{\pi(a|s)} \right) - f(1) \right].
\]

By using our tail-adaptive f-divergence, we can readily apply Algorithm 1 (or Algorithm 2 in the Appendix) to update π in SAC, allowing us to obtain a π that accounts for the multi-modality of Q(a, s) more efficiently. Note that the standard α-divergence with a fixed α also yields a new variant of SAC that has not yet been studied in the literature.
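To sketch how Algorithm 1 plugs into the SAC policy update (illustrative code, not the authors' implementation; `policy.rsample_with_log_prob`, `q_fn` and `v_fn` are assumed interfaces that return reparameterized actions with their log-probabilities and the Q- and V-values):

```python
# A sketch of the tail-adaptive policy loss inside SAC; `policy`, `q_fn` and `v_fn`
# are assumed to exist, and tau, beta and the batch handling are illustrative.
import torch

def tail_adaptive_policy_loss(policy, q_fn, v_fn, states, tau=1.0, beta=-1.0):
    # a = g_theta(xi; s): reparameterized action samples from pi_theta(.|s)
    actions, log_pi = policy.rsample_with_log_prob(states)      # assumed interface
    # log w = (Q(s,a) - V(s)) / tau - log pi(a|s), the log ratio p_Q / q_pi
    log_w = (q_fn(states, actions) - v_fn(states)) / tau - log_pi
    with torch.no_grad():
        # rank-based tail-adaptive weights gamma_i = Fhat_w(w_i)^beta, normalized
        tail_prob = (log_w[None, :] >= log_w[:, None]).float().mean(dim=1)
        gamma = tail_prob ** beta
        gamma = gamma / gamma.sum()
    # With uniform gamma_i = 1/n this reduces (up to the 1/tau scaling and the
    # action-independent V(s) term) to SAC's usual reparameterized policy loss.
    return -(gamma * log_w).sum()
```

Here log_pi keeps its full gradient, as in standard SAC implementations; one could instead evaluate it with detached parameters, as in the sketch after Algorithm 1, to match (8) exactly.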
Empirical Results. We follow the experimental setup of Haarnoja et al. (2018). The policy π, the value function V(s), and the Q-function Q(s, a) are neural networks with two fully connected layers of 128 hidden units each. We use Adam (Kingma & Ba, 2015) with a constant learning rate of 0.0003 for optimization. The size of the replay buffer is 10^7 for HalfCheetah, and we fix the size to 10^6 on the other environments, in a way similar to Haarnoja et al. (2018). We compare with the original SAC (α = 0), and also with other α-divergences, such as α = 0.5 and α = +∞ (the α = max setting suggested in Li & Turner (2016)).

Figure 3: Average reward on MuJoCo environments (Ant, HalfCheetah, Humanoid (rllab), Walker, Hopper, Swimmer (rllab)) for SAC with the policy updated by Algorithm 1 with β = −1, or by α-divergence VI with different α (α = 0 corresponds to the original SAC). The reparameterization gradient estimator is used in all cases. In the legend, α = max denotes setting α = +∞ in the α-divergence.

Figure 3 summarizes the total average reward of evaluation rollouts during training on various MuJoCo environments. For non-negative α settings, methods with larger α give a higher average reward than the original KL-based SAC in most cases. Overall, our adaptive f-divergence substantially outperforms all other α-divergences on all of the benchmark tasks in terms of final performance, and learns faster than all the baselines in most environments. We find that the improvement is especially significant on high-dimensional and complex environments such as Ant and Humanoid.

7 Conclusion

In this paper, we present a new class of tail-adaptive f-divergences and exploit their application in variational inference and reinforcement learning. Compared to classic α-divergences, our approach guarantees finite moments of the density ratio and provides more stable importance weights and gradient estimates. Empirical results on Bayesian neural networks and reinforcement learning indicate that our approach outperforms the standard α-divergence, especially for high-dimensional, multi-modal distributions.

Acknowledgement

This work is supported in part by NSF CRII 1830161. We would like to acknowledge Google Cloud for their support.

References

Blei, David M, Kucukelbir, Alp, and McAuliffe, Jon D. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859-877, 2017.

Burda, Yuri, Grosse, Roger, and Salakhutdinov, Ruslan. Importance weighted autoencoders. International Conference on Learning Representations (ICLR), 2016.

Cappé, Olivier, Douc, Randal, Guillin, Arnaud, Marin, Jean-Michel, and Robert, Christian P. Adaptive importance sampling in general mixture classes. Statistics and Computing, 18(4):447-459, 2008.

Christopher, M Bishop. Pattern Recognition and Machine Learning. Springer-Verlag New York, 2016.

Cotter, Colin, Cotter, Simon, and Russell, Paul. Parallel adaptive importance sampling. arXiv preprint arXiv:1508.01132, 2015.

Csiszár, I. and Shields, P.C. Information theory and statistics: A tutorial. Foundations and Trends in Communications and Information Theory, 1(4):417-528, 2004.

De Boer, Pieter-Tjerk, Kroese, Dirk P, Mannor, Shie, and Rubinstein, Reuven Y. A tutorial on the cross-entropy method. Annals of Operations Research, 134(1):19-67, 2005.

Dieng, Adji Bousso, Tran, Dustin, Ranganath, Rajesh, Paisley, John, and Blei, David. Variational inference via χ upper bound minimization. In Advances in Neural Information Processing Systems (NIPS), pp. 2732-2741, 2017.

Feldman, Dorian and Osterreicher, Ferdinand. A note on f-divergences. Studia Sci. Math. Hungar., 24:191-200, 1989.

Gal, Yarin and Ghahramani, Zoubin. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning (ICML), pp. 1050-1059, 2016.

Haarnoja, Tuomas, Zhou, Aurick, Abbeel, Pieter, and Levine, Sergey. Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor. International Conference on Machine Learning (ICML), 2018.

Hernández-Lobato, José Miguel, Li, Yingzhen, Rowland, Mark, Hernández-Lobato, Daniel, Bui, Thang, and Turner, Richard Eric. Black-box α-divergence minimization. International Conference on Machine Learning (ICML), 2016.

Hill, Bruce M et al. A simple general approach to inference about the tail of a distribution. The Annals of Statistics, 3(5):1163-1174, 1975.

Hoffman, Matthew D, Blei, David M, Wang, Chong, and Paisley, John. Stochastic variational inference.
The Journal of Machine Learning Research, 14(1):1303-1347, 2013.

Jang, Eric, Gu, Shixiang, and Poole, Ben. Categorical reparameterization with Gumbel-softmax. International Conference on Learning Representations (ICLR), 2017.

Jordan, Michael I, Ghahramani, Zoubin, Jaakkola, Tommi S, and Saul, Lawrence K. An introduction to variational methods for graphical models. Machine Learning, 37(2):183-233, 1999.

Kingma, Diederik P and Ba, Jimmy. Adam: A method for stochastic optimization. International Conference on Learning Representations (ICLR), 2015.

Kingma, Diederik P and Welling, Max. Auto-encoding variational Bayes. International Conference on Learning Representations (ICLR), 2014.

Levine, Sergey. Reinforcement learning and control as probabilistic inference: Tutorial and review. arXiv preprint arXiv:1805.00909, 2018.

Li, Yingzhen and Turner, Richard E. Rényi divergence variational inference. In Advances in Neural Information Processing Systems (NIPS), pp. 1073-1081, 2016.

Liese, Friedrich and Vajda, Igor. On divergences and informations in statistics and information theory. IEEE Transactions on Information Theory, 52(10):4394-4412, 2006.

Maddison, Chris J, Mnih, Andriy, and Teh, Yee Whye. The concrete distribution: A continuous relaxation of discrete random variables. International Conference on Learning Representations (ICLR), 2017.

Minka, Thomas P. Expectation propagation for approximate Bayesian inference. In Proceedings of the Seventeenth Conference on Uncertainty in Artificial Intelligence (UAI), pp. 362-369. Morgan Kaufmann Publishers Inc., 2001.

Minka, Tom et al. Divergence measures and message passing. Technical report, Microsoft Research, 2005.

Opper, Manfred and Winther, Ole. Expectation consistent approximate inference. Journal of Machine Learning Research, 6(Dec):2177-2204, 2005.

Osterreicher, Ferdinand. f-divergences: representation theorem and metrizability. Inst. Math., Univ. Salzburg, Salzburg, Austria, 2003.

Ranganath, Rajesh, Gerrish, Sean, and Blei, David. Black box variational inference. In Artificial Intelligence and Statistics, pp. 814-822, 2014.

Reid, Mark D and Williamson, Robert C. Information, divergence and risk for binary experiments. Journal of Machine Learning Research, 12(Mar):731-817, 2011.

Resnick, Sidney I. Heavy-tail phenomena: Probabilistic and statistical modeling. Springer Science & Business Media, 2007.

Ryu, Ernest K and Boyd, Stephen P. Adaptive importance sampling via stochastic convex programming. arXiv preprint arXiv:1412.4845, 2014.

Sason, Igal. On f-divergences: Integral representations, local behavior, and inequalities. Entropy, 20(5):383, 2018.

Vehtari, Aki, Gelman, Andrew, and Gabry, Jonah. Pareto smoothed importance sampling. arXiv preprint arXiv:1507.02646, 2015.

Wainwright, Martin J, Jordan, Michael I, et al. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2):1-305, 2008.