# raoblackwellizing_the_straightthrough_gumbelsoftmax_gradient_estimator__3c4fe327.pdf Published as a conference paper at ICLR 2021 RAO-BLACKWELLIZING THE STRAIGHT-THROUGH GUMBEL-SOFTMAX GRADIENT ESTIMATOR Max B. Paulus ETH Z urich MPI for Intelligent Systems, T ubingen max.paulus@inf.ethz.ch Chris J. Maddison University of Toronto Vector Institute cmaddis@cs.toronto.edu Andreas Krause ETH Z urich krausea@ethz.ch Gradient estimation in models with discrete latent variables is a challenging problem, because the simplest unbiased estimators tend to have high variance. To counteract this, modern estimators either introduce bias, rely on multiple function evaluations, or use learned, input-dependent baselines. Thus, there is a need for estimators that require minimal tuning, are computationally cheap, and have low mean squared error. In this paper, we show that the variance of the straight-through variant of the popular Gumbel-Softmax estimator can be reduced through Rao Blackwellization without increasing the number of function evaluations. This provably reduces the mean squared error. We empirically demonstrate that this leads to variance reduction, faster convergence, and generally improved performance in two unsupervised latent variable models. 1 INTRODUCTION Models with discrete latent variables are common in machine learning. Discrete random variables provide an effective way to parameterize multi-modal distributions, and some domains naturally have latent discrete structure (e.g, parse trees in NLP). Thus, discrete latent variable models can be found across a diverse set of tasks, including conditional density estimation, generative text modelling (Yang et al., 2017), multi-agent reinforcement learning (Mordatch & Abbeel, 2017; Lowe et al., 2017) or conditional computation (Bengio et al., 2013; Davis & Arel, 2013). The majority of these models are trained to minimize an expected loss using gradient-based optimization, so the problem of gradient estimation for discrete latent variable models has received considerable attention over recent years. Existing estimation techniques can be broadly categorized into two groups, based on whether they require one loss evaluation (Glynn, 1990; Williams, 1992; Bengio et al., 2013; Mnih & Gregor, 2014; Chung et al., 2017; Maddison et al., 2017; Jang et al., 2017; Grathwohl et al., 2018) or multiple loss evaluations (Gu et al., 2016; Mnih & Rezende, 2016; Tucker et al., 2017) per estimate. These estimators reduce variance by introducing bias or increasing the computational cost with the overall goal being to reduce the total mean squared error. Because loss evaluations are costly in the modern deep learning age, single evaluation estimators are particularly desirable. This family of estimators can be further categorized into those that relax the discrete randomness in the forward pass of the model (Maddison et al., 2017; Jang et al., 2017; Paulus et al., 2020) and those that leave the loss computation unmodified (Glynn, 1990; Williams, 1992; Bengio et al., 2013; Chung et al., 2017; Mnih & Gregor, 2014; Grathwohl et al., 2018). The ones that do not modify the loss computation are preferred, because they avoid the accumulation of errors in the forward direction and they allow the model to exploit the sparsity of discrete computation. Thus, there is a particular need for single evaluation estimators that do not modify the loss computation. Work done partly at the Institute for Advanced Study, Princeton, NJ. Published as a conference paper at ICLR 2021 In this paper we introduce such a method. In particular, we propose a Rao-Blackwellization scheme for the straight-through variant of the Gumbel-Softmax estimator (Jang et al., 2017; Maddison et al., 2017), which comes at a minimal cost, and does not increase the number of function evaluations. The straight-through Gumbel-Softmax estimator (ST-GS, Jang et al., 2017) is a lightweight stateof-the-art single-evaluation estimator based on the Gumbel-Max trick (see Maddison et al., 2014, and references therein). The ST-GS uses the argmax over Gumbel random variables to generate a discrete random outcome in the forward pass. It computes derivatives via backpropagation through a tempered softmax of the same Gumbel sample. Our Rao-Blackwellization scheme is based on the key insight that there are many configurations of Gumbels corresponding to the same discrete random outcome and that these can be marginalized over with Monte Carlo estimation. By design, there is no need to re-evaluate the loss and the additional cost of our estimator is linear only in the number of Gumbels needed for a single forward pass. As we show, the Rao-Blackwell theorem implies that our estimator has lower mean squared error than the vanilla ST-GS. We demonstrate the effectiveness of our estimator in unsupervised parsing on the List Ops dataset (Nangia & Bowman, 2018) and on a variational autoencoder loss (Kingma & Welling, 2013; Rezende et al., 2014). We find that in practice our estimator trains faster and achieves better test set performance. The magnitude of the improvement depends on several factors, but is particularly pronounced at small batch sizes and low temperatures. 2 BACKGROUND For clarity, we consider the following simplified scenario. Let D pθ be a discrete random variable D {0, 1}n in a one-hot encoding, P Di = 1, with distribution given by pθ(D) exp(DT θ) where θ Rn. Given a continuously differentiable f : R2n R, we wish to minimize, min θ E[f(D, θ)], (1) where the expectation is taken over all of the randomness. In general θ may be computed with some neural network, so our aim is to derive estimators of the total derivative of the expectation with respect to θ for use in stochastic gradient descent. This framework covers most simple discrete latent variable models, including variational autoencoders (Kingma & Welling, 2013; Rezende et al., 2014). The REINFORCE estimator (Glynn, 1990; Williams, 1992) is unbiased (under certain smoothness assumptions) and given by: REINF := f(D, θ) log pθ(D) θ + f(D, θ) Without careful use of control variates (Mnih & Gregor, 2014; Tucker et al., 2017; Grathwohl et al., 2018), the REINFORCE estimator tends to have prohibitively high variance. To simplify exposition we assume henceforth that f(D, θ) = f(D) does not depend on θ, because the dependence of f(D, θ) on θ is accounted for in the second term of (2), which is shared by most estimators and generally has low variance. One strategy for reducing the variance is to introduce bias through a relaxation (Jang et al., 2017; Maddison et al., 2017). Define the tempered softmax softmaxτ : Rn Rn by softmaxτ(x)i = exp(xi/τ)/ Pn j=1 exp(xj/τ). The relaxations are based on the observation that the sampling of D can be reparameterized using Gumbel random variables and the zero-temperature limit of the tempered softmax under the coupling: D = lim τ 0 Sτ; Sτ = softmaxτ(θ + G) (3) where G is a vector of i.i.d. Gi Gumbel random variables. At finite temperatures Sτ is known as a Gumbel-Softmax (GS) (Jang et al., 2017) or concrete (Maddison et al., 2017) random variable, and the relaxed loss E[f(Sτ, θ)] admits the following reparameterization gradient estimator for τ > 0:1 GS := f(Sτ) d softmaxτ(θ + G) 1For a function f(x1, x2), f(z1, z2)/ x1 is the partial derivative (e.g., a gradient vector) of f in the first variable evaluated at z1, z2. For a function g(θ), dg/dθ is the total derivative of g in θ. For example, d softmaxτ(θ + G)/dθ is the Jacobian of the tempered softmax evaluated at the random variable θ + G. Published as a conference paper at ICLR 2021 This is an unbiased estimator of the gradient of E[f(Sτ, θ)], but a biased estimator of our original problem (1). For this to be well-defined f must be defined on the interior of the simplex (where Sτ sits). This estimator has the advantage that it is easy to implement and generally low-variance, but the disadvantage that it modifies the forward computation of f and is biased. Henceforth, we assume D, Sτ, and G are coupled almost surely through (3). Another popular family of estimators are the so-called straight-through estimators (c.f., Bengio et al., 2013; Chung et al., 2017). In this family, the forward computation of f is unchanged, but backpropagation is computed through a surrogate. One popular variant takes as a surrogate the tempered probabilities of D, resulting in the slope-annealed straight-through estimator (ST): D d softmaxτ(θ) For binary D, a lower bias variant of this estimator (Fou ST) was proposed in Pervez et al. (2020). The most popular straight-through estimator is known as the straight-through Gumbel-Softmax (STGS, Jang et al., 2017). The surrogate for ST-GS is Sτ, whose Gumbels are coupled to D through (3): STGS := f(D) D d softmaxτ(θ + G) The straight-through family has the advantage that they tend to be low-variance and f need not be defined on the interior of the simplex (although f must be differentiable at the corners). This family has the disadvantage that they are not known to be unbiased estimators of any gradient. These estimators are quite popular in practice, because they preserve the forward computation of f, which prevents the forward propagation of errors and maintains sparsity (Choi et al., 2017; Chung et al., 2017; Bengio et al., 2013). All of the estimators discussed in this paper can be computed by any of the standard automatic differentiation software packages using a single evaluation of f on a realization of D or some underlying randomness. We present implementation details for these and our Gumbel-Rao estimator in the Appendix, emphasizing the surrogate loss framework (Schulman et al., 2015; Weber et al., 2019) and considering the multiple stochastic layer case not covered by (1). 3 GUMBEL-RAO GRADIENT ESTIMATOR 3.1 RAO-BLACKWELLIZATION OF ST-GUMBEL-SOFTMAX We now derive our Rao-Blackwelization scheme for the ST-GS estimator. Our approach is based on the observation that there is a many-to-one relationship between realizations of θ + G and D in the coupling described by (3) and that the variance introduced by θ + G can be marginalized out. The resulting estimator, which we call the Gumbel-Rao (GR) estimator, is guaranteed by the Rao-Blackwell theorem to have lower variance than ST-GS. In the next subsection we turn to the practical question of carrying out this marginalization. In the Gumbel-max trick (3), D is a one-hot indicator of the index of arg maxi {θi + Gi}. Because this argmax operation is non-invertible, there are many configurations of θ + G that correspond to a single D outcome. Consider an alternate factorization of the joint distribution of (θ + G, D): first sample D pθ, and then θ + G given D. In this view, the Gumbels are auxillary random variables, at which the Jacobian of the tempered softmax is evaluated and which locally increase the variance of the estimator. This local variance can be removed by marginalization. This is the key insight of our GR estimator, which is given by, D E d softmaxτ(θ + G) It is not too difficult to see that GR = E [ STGS|D]. By the tower rule of expectation, GR has the same expected value as ST-GS and is an instance of a Rao-Blackwell estimator (Blackwell, 1947; Rao, 1992). Thus, it has the same mean as ST-GS, but a lower variance. Taken together, these facts imply that GR enjoys a lower mean squared error (not a lower bias) than ST-GS. Published as a conference paper at ICLR 2021 Proposition 1. Let STGS and GR be the estimators defined in (6) and (7). Let θ := d E[f(D)]/dθ be the true gradient that we are trying to estimate. We have E h GR θ 2i E h STGS θ 2i . (8) Proof. The proposition follows from Jensen s inequality and the linearity of expectations, see C.1. While GR is only guaranteed to reduce the variance of ST-GS, Proposition 1 guarantees that, as a function of τ, the MSE of GR is a pointwise lower bound on ST-GS. This means GR can be used for estimation at temperatures, where ST-GS has low bias but prohibitively high variance. Thus, GR extends the region of suitable temperatures over which one can tune. This allows a practitioner to explore an expanded set when trading-off of bias and variance. Empirically, lower temperatures tend to reduce the bias of ST-GS, but we are not aware of any work that studies the convergence of the derivative in the temperature limit. In our experiments, we observe that our estimator facilitates training at lower temperatures to improve in both bias and variance over ST-GS. Thus, our estimator retains the favourable properties of ST-GS (single, unmodified evaluation of f) while improving its performance. 3.2 MONTE CARLO APPROXIMATION The GR estimator requires computing the expected value of the Jacobian of the tempered softmax over the distribution θ + G|D. Unfortunately, an analytical expression for this is only available in the simplest cases.2 In this section we provide a simple Monte Carlo (MC) estimator with sample size K for E[d Sτ/dθ|D], which we call the Gumbel-Rao Monte Carlo Estimator (GR-MCK). This estimator can be computed locally at a cost that only scales like n K (the arity of D times K). They key property exploited by GR-MCK is that θ + G|D can be reparameterized in the following closed form. Given a realization of D such that Di = 1, Z(θ) = Pn i=1 exp(θi), and Ej exponential i.i.d., we have the following equivalence in distribution (Maddison et al., 2014; Maddison, 2016; Tucker et al., 2017). θj + Gj|D d= ( log (Ej) + log Z(θ) if j = i log Ej exp(θj) + Ei Z(θ) o.w. (9) With this in mind, we define the GR-MCK estimator: GRMCK := f(D) d softmaxτ(Gk θ) dθ where Gk θ θ + G|D i.i.d. using the reparameterization (9). For the case K = 1, our estimator reduces to the standard ST-GS estimator. The cost for drawing multiple samples Gk θ θ + G|D scales only linearly in the arity of D and is usually negligible in modern applications, where the bulk of computation accrues from the computation of f. Moreover, drawing multiple samples of θ + G|D can easily be parallelized on modern workstations (GPUs, etc.). Our estimator remains a single-evaluation estimator under this scheme, because the loss function f is still only evaluated at D. Finally, as with GR, the GR-MCK is guaranteed to improve in MSE over ST-GS for any K 1, as confirmed in Proposition 2. Proposition 2. Let STGS and GRMCK be the estimators defined in (6) and (10). Let θ := d E[f(D)]/dθ be the true gradient that we are trying to estimate. For all K 1, we have E h GRMCK θ 2i E h STGS θ 2i . (11) Proof. The proposition follows from Jensen s inequality and the linearity of expectations, see C.2. 2For example, in the case of n = 2 (binary) and τ = 1 an analytical expression for the GR estimator is available. Published as a conference paper at ICLR 2021 3.3 VARIANCE REDUCTION IN MINIBATCHES The variance of GR-MCK can be reduced by increasing K or by averaging B i.i.d. samples of the GR-MCK estimator. An average of i.i.d. samples b GRMCK for b {1, . . . , B} is a generalization of minibatching by sampling data points with replacement. In particular, θ may depend on an additional source of randomness, i.e., θ = h(X) for X P. In this case, STGS is a random variable that depends not only on D and G, but also on X. In this subsection, we consider the effect of increasing K and B separately. Expectations are taken over all the randomness. Let b GRMCK be i.i.d. as GRMCK for b {1, . . . , B} and define the following minibatched GR-MCK estimator: 1:B GRMCK := 1 b=1 b GRMCK. (12) Proposition 3 summarizes the scaling of the variance of (12), and is an elementary application of the law of total variance. Proposition 3. Let STGS, GR and 1:B GRMCK be the estimators defined in (6), (7) and (12). We have 1:B GRMCK i = E [var [ STGS|D, X]] BK + var [ GR] where var is the trace of the covariance matrix. Proof. The proposition follows directly from the law of total variance, see C.3. As expected the total variance of 1:B GRMCK decreases like 1/B. The key point of Proposition 3 is that the component of the variance that K reduces can also be reduced by increasing the batch size B. This suggests that the effect of GR-MCK will be most pronounced at small batch sizes. Proposition 3 also indicates that there are diminishing returns to increasing K for a fixed batch size B, such that the variance of GR-MCK will eventually be dominated by the right-hand term of (13). In our experimental section, we explore various K and study the effect on gradient estimation in more detail. Finally, we note that the choice of a Monte Carlo scheme to approximate E [d Sτ/dθ|D] permits the use of additional well-known variance reduction methods to improve the estimation properties of our gradient estimator. For example, antithetic variates or importance sampling are sensible methods to explore in this setting (Kroese et al., 2013). For low-dimensional discrete random variables, Gaussian quadrature or other numerical methods could be employed. However, we found the simple Monte Carlo scheme described above effective in practice and report results based on this procedure in the experimental section. 4 RELATED WORK The idea of using Rao-Blackwellization to reduce the variance of gradient estimators for discrete latent variable models has been explored in machine learning. For example, Liu et al. (2018) describe a sum-and-sample style estimator that analytically computes part of the expectation to reduce the variance of the gradient estimates. The favorable properties of their estimator are due to the Rao Blackwell theorem. Kool et al. (2020) describe a gradient estimator based on sampling without replacement. Their estimator emerges naturally as the Rao-Blackwell estimator of the importanceweighted estimator (Vieira, 2017) and the estimator described by Liu et al. (2018). Both of these estimators rely on multiple function evaluations to compute a gradient estimate. In contrast, our work is the first to consider Rao-Blackwellisation in the context of a single-evaluation estimator. Recently, Paulus et al. (2020) extend the Gumbel-Softmax gradient estimator to other discrete structures. Our approach can be used to reduce the variance of the corresponding straight-through variants, when an efficient reparameterization of the perturbation conditional on the discrete structure is available (Gane et al., 2014). Published as a conference paper at ICLR 2021 (a) Objective function (b) log10(var) (τ=0.1) (c) log10(var) (τ=0.5) (d) log10(var) (τ=1.0) Figure 1: Our estimator (GR-MCK) effectively reduces the variance over the entire simplex and is particularly effective at low temperatures. Contours for the quadratic programme in three dimensions (1a) and difference in log10-trace of the covariance matrix between ST-GS and GR-MC1000 at different temperatures (1b, 1c, 1d). Warmer means difference is larger. 5 EXPERIMENTS 5.1 PROTOCOL In this section, we study the effectiveness of our gradient estimator in practice. In particular, we evaluate its performance with respect to the temperature τ, the number of MC samples K and the batch size B. We measure the variance reduction and improvements in MSE our estimator achieves in practice, and assess whether its lower variance gradient estimates accelerate the convergence on the objective or improve final test set performance. Our focus is on single-evaluation gradient estimation and we compare against other non-relaxing estimators (ST, Fou ST, ST-GS and REINFORCE with a running mean as a baseline) and relaxing estimators (GS), where permissible. Experimental details are given in Appendix D. First, we consider a toy example which allows us to explore and visualize the variance of our estimator and suggests that it is particularly effective at low temperatures. Next, we evaluate the effect of τ and K in a latent parse tree task which does not permit the use of relaxed gradient estimators. Here, our estimator facilitates training at low temperatures to improve overall performance and is effective even with few MC samples. Finally, we train variational auto-encoders with discrete latent variables (Kingma & Welling, 2013; Rezende et al., 2014). Our estimator yields improvements at small batch sizes and obtains competitive or better performance than the GS estimator at the largest arity. 5.2 QUADRATIC PROGRAMMING ON THE SIMPLEX As a toy problem, we consider the problem of minimizing a quadratic program (p c) Q(p c) over the probability simplex n 1 = {p Rn : pi 0, Pn i=1 pi = 1} for Q Rn n positive-definite and c Rn. This problem may be reframed as the following stochastic optimization problem, min p n 1 E[(D c) A(p)(D c)], where D Discrete(p) and Aii(p) = (pi ci)2 pi 2pici+c2 i Qii and Aij(p) = (pi ci)(pj cj) cicj picj cipj Qij for i = j. While solving the above problem is simple using standard methods, it provides a useful testbed to evaluate the effectiveness of our variance reduction scheme. For this purpose, we consider Qij = exp ( 2|i j|) and ci = 1 3 in three dimensions. Our estimator reduces the variance in the gradient estimation over the entire simplex and is particularly effective at low temperatures in this problem. In Figure 1, we compare the log10-trace of the covariance matrix of ST-GS and GR-MC1000 at three different temperatures and display their difference over the entire domain. The improvement is universal. The pattern is not always intuitive (oval bull s eyes), despite the simplicity of the objective function. Compared with ST-GS, our estimator on this example appears more effective closer to the corners and edges, which is important for learning discrete distributions. At lower temperatures, the difference between the two estimators becomes particularly acute. This suggests that our estimator may train better at lower temperatures and be more responsive to optimizing over the temperature to successfully trade off bias and variance. Published as a conference paper at ICLR 2021 Table 1: Our estimator (GR-MCK) facilitates training at lower temperatures with improved performance on the latent parse tree task. Best test classification accuracy on the List Ops dataset selected on the validation set. Best estimator at given temperature in bold, best estimator across temperatures in italics. Higher is better. L 10 L 25 L 50 ESTIMATOR τ = 0.01 τ = 0.1 τ = 1.0 τ = 0.01 τ = 0.1 τ = 1.0 τ = 0.01 τ = 0.1 τ = 1.0 ST-GS 38.8 59.3 65.8 41.2 57.1 60.2 46.8 56.8 59.6 GR-MC10 66.4 66.9 66.7 60.7 60.8 60.9 58.7 59.1 59.6 GR-MC100 65.6 66.3 65.9 60.0 61.3 61.2 59.6 59.1 59.6 GR-MC1000 66.5 67.1 67.0 60.2 60.9 61.2 60.0 59.8 59.9 5.3 UNSUPERVISED PARSING ON LISTOPS Straight-through estimators feature prominently in NLP (Martins et al., 2019) where latent discrete structure arises naturally, but the use of relaxations is often infeasible. Therefore, we evaluate our estimator in a latent parse tree task on subsets of the List Ops dataset (Nangia & Bowman, 2018). This dataset contains sequences of prefix arithmetic expressions x (e.g., max[ 3 min[ 8 2 ]]) that evaluate to an integer y {0, 1, . . . 9}. The arithmetic syntax induces a latent parse tree T. We consider the model by (Choi et al., 2017) that learns a distribution over plausible parse trees of a given sequence to maximize E qθ(T |x) [log pφ(y|T, x)] . Both the conditional distribution over parse trees qθ(T|x) and the classifier pφ(y|T, x) are parameterized using neural networks. In this model, a parse tree T qθ(T|x) for a given sentence is sampled bottom-up by successively combining the embeddings of two tokens that appear in a given sequence until a single embedding for the entire sequence remains. This is then used for performing the subsequent classification. Because it is computationally infeasible to marginalize over all trees, Choi et al. (2017) rely on the ST-GS estimator for training. We compare this estimator against our estimator GR-MCK with K {10, 100, 1000}. We consider temperatures τ {0.01, 0.1, 1.0} and experiment with shallow and deeper trees by considering sequences of length L up to 10, 25 and 50. All models are trained with stochastic gradient descent with a batch size equal to the maximum L. Because we are interested in a controlled setting to investigate the effect of τ and K, our experimental set-up is significantly simpler than elsewhere (e.g., Havrylov et al., 2019). We give details and highlight important differences in Appendix D.1. Our estimator facilitates training at lower temperatures and achieves better final test set accuracy than ST-GS (Table 1). Increasing K improves the performance at low temperatures, where the differences between the estimators are most pronounced. Overall, across all temperatures this results in modest improvements, particularly for shallow trees and small batch sizes. We also find evidence for diminishing returns: The differences between ST-GS and GR-MC10 are larger than between GR-MC100 or GR-MC1000, suggesting that our estimator is effective even with few MC samples. 5.4 GENERATIVE MODELING WITH DISCRETE VARIATIONAL AUTO-ENCODERS Finally, we train variational auto-encoders (Kingma & Welling, 2013; Rezende et al., 2014) with discrete latent random variables on the MNIST dataset of handwritten digits (Le Cun & Cortes, 2010). We used the fixed binarization of (Salakhutdinov & Murray, 2008) and the standard split into train, validation and test sets. Our objective is to maximize the following variational lower bound on the log-likelihood, log p(x) > E qθ(Di|x) where x denotes the input image and Di qθ(Di|x) denotes a vector of discrete latent random variables. This objective takes a form in equation (1). For training, the bound is approximated using Published as a conference paper at ICLR 2021 (a) log10(var) vs Iterations (b) log10(mse) vs τ (c) Iterations vs ELBO Figure 2: Our estimator (GR-MCK) effectively reduces the variance over the entire training trajectory (2a), achieves a lower mean squared error at a lower temperature (2b) and converges faster than ST and ST-GS on the discrete VAE objective (2c). Log10-trace of the covariance matrix over a training trajectory (2a) and log10-MSE (2b) at different temperatures during training, average number of iterations and standard error to reach various thresholds of the objective on the validation set (2c). only a single sample (M = 1). For final validation and testing, we use 5000 samples (M = 5000). Both the generative model pφ(x, D) and the variational distributions qθ(D|x) were parameterized using neural networks. We experiment with different batch sizes and discrete random variables of arities in {2, 4, 8, 16} as in Maddison et al. (2017). To facilitate comparisons, we do not alter the total dimension of the latent space and train all models for 50,000 iterations using stochastic gradient descent with momentum. Hyperparameters are optimised for each estimator using random search (Bergstra & Bengio, 2012) over twenty independent runs. More details are given in Appendix D.2. Our estimator effectively reduces the variance over the entire training trajectory (Figure 2a). Even a small number of MC samples (K = 10) results in sizable variance reductions. The variance reduction compares favorably to the magnitude of the minibatch variance (Appendix E). Empirically, we find that lower temperatures tend to reduce bias. Our estimator facilitates training at lower temperatures and thus features a lower MSE (Figure 2b). During training our estimator can trade off bias and variance to improve the gradient estimation. Empirically, we observed that on this task, the best models using ST-GS trained at an average temperature of 0.65, while the best models using GR-MC1000 trained at an average temperature of 0.35. This is interesting, because it indicates that our estimator may make the use of temperature annealing during training more effective. We find lower variance gradient estimates improve convergence of the objective (Figure 2c). GR-MC1000 reaches various performance thresholds on the validation set with reliably fewer iterations than ST or ST-GS. This effect is observable at different arities and persistent over the entire training trajectory. For final test set performance, our estimator outperforms REINFORCE and all other straight-through estimators (Table 2). The improvements over ST-GS extend up to two nats (for batch size 20, 16-ary) at small batch sizes and are more modest at large batch sizes as expected (also see Appendix E). This confirms that our estimator might be particularly effective in settings, where training at high batch sizes is prohibitively expensive. The improvements from increasing the number of MC samples tend to saturate at K = 100 on this task. Further, our results suggest that relaxed estimators may be preferred (if they can be used) for discrete random variables of smaller arity. For example, the GS estimator outperforms all straight-through estimators for binary variables for both batch sizes. For large arities however, we find that straight-through estimators can perform competitively: Our estimator GR-MC1000 achieves the best performance overall and outperforms the GS estimator for 16-ary variables. 6 CONCLUSION We introduced the Gumbel-Rao estimator, a new single-evaluation non-relaxing gradient estimator for models with discrete random variables. Our estimator is a Rao-Blackwellization of the state-ofthe-art straight-through Gumbel-Softmax estimator. It enjoys lower variance and can be implemented efficiently using Monte Carlo methods. In particular and in contrast to most other work, it does not require additional function evaluations. Empirically, our estimator improved final test set performance in an unsupervised parsing task and on a variational auto-encoder loss. It accelerated convergence on the objective and compared favorably to other standard gradient estimators. Even though the Published as a conference paper at ICLR 2021 Table 2: Our estimator (GR-MCK) outperforms other straight-through estimators for discrete-latentspace VAE objectives on the MNIST dataset and is competitive with the Gumbel-Softmax (GS) at large arities. Best bound on the test negative log-likelihood selected on the validation set. Best straight-through estimator in bold, best estimator in italics. Lower is better. BINARY 4-ARY 8-ARY 16-ARY ESTIMATOR B = 20 B = 200 B = 20 B = 200 B = 20 B = 200 B = 20 B = 200 GS 98.2 96.4 95.7 93.8 95.5 92.3 96.8 94.3 REINFORCE 202.6 121.4 173.7 122.2 203.9 124.9 169.4 129.5 ST 105.5 103.1 106.2 104.5 107.2 105.1 108.2 104.5 FOUST 101.5 97.8 - - - - - - ST-GS 100.7 97.1 99.1 93.7 98.0 92.8 98.8 92.6 GR-MC10 100.7 97.4 97.8 93.8 97.4 93.1 97.9 92.4 GR-MC100 100.6 96.8 97.5 94.0 96.8 92.2 97.3 92.4 GR-MC1000 100.5 97.0 97.6 93.5 96.5 92.5 96.8 92.2 gains were sometimes modest, they were persistent and particularly pronounced when models must be trained at low temperatures or with small batch sizes. We expect that our estimator will be most effective in such settings and that further gains may be uncovered when combining our Rao Blackwellisation scheme with an annealing schedule for the temperature. Finally, we hope that our work inspires further exploration of the use of Rao-Blackwellisation for gradient estimation. ACKNOWLEDGEMENTS MBP gratefully acknowledges support from the Max Planck ETH Center for Learning Systems. CJM is grateful for the support of the James D. Wolfensohn Fund at the Institute of Advanced Studies in Princeton, NJ. Resources used in preparing this research were provided, in part, by the Sustainable Chemical Processes through Catalysis (Suchcat) National Center of Competence in Research (NCCR), the Province of Ontario, the Government of Canada through CIFAR, and companies sponsoring the Vector Institute. Yoshua Bengio, Nicholas L eonard, and Aaron Courville. Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. ar Xiv e-prints, art. ar Xiv:1308.3432, Aug 2013. James Bergstra and Yoshua Bengio. Random search for hyper-parameter optimization. Journal of machine learning research, 13(Feb):281 305, 2012. David Blackwell. Conditional expectation and unbiased sequential estimation. Ann. Math. Statist., 18(1):105 110, 03 1947. doi: 10.1214/aoms/1177730497. Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. ar Xiv preprint ar Xiv:1509.00519, 2015. Jihun Choi, Kang Min Yoo, and Sang-goo Lee. Unsupervised learning of task-specific tree structures with tree-lstms. In Co RR, 2017. Junyoung Chung, Sungjin Ahn, and Yoshua Bengio. Hierarchical multiscale recurrent neural networks. In International Conference on Learning Representations, 2017. Andrew Davis and Itamar Arel. Low-rank approximations for conditional feedforward computation in deep neural networks. ar Xiv preprint ar Xiv:1312.4461, 2013. Published as a conference paper at ICLR 2021 Andreea Gane, Tamir Hazan, and Tommi Jaakkola. Learning with maximum a-posteriori perturbation models. In Artificial Intelligence and Statistics, pp. 247 256. PMLR, 2014. Peter W Glynn. Likelihood ratio gradient estimation for stochastic systems. Communications of the ACM, 33(10):75 84, 1990. Will Grathwohl, Dami Choi, Yuhuai Wu, Geoffrey Roeder, and David Duvenaud. Backpropagation through the void: Optimizing control variates for black-box gradient estimation. In International Conference on Learning Representations, 2018. Shixiang Gu, Sergey Levine, Ilya Sutskever, and Andriy Mnih. Muprop: Unbiased backpropagation for stochastic neural networks. In International Conference on Learning Representations, 2016. Serhii Havrylov, Germ an Kruszewski, and Armand Joulin. Cooperative learning of disjoint syntax and semantics. ar Xiv preprint ar Xiv:1902.09393, 2019. Eric Jang, Shixiang Gu, and Ben Poole. Categorical Reparametrization with Gumble-Softmax. In International Conference on Learning Representations (ICLR 2017), 2017. Diederik P Kingma and Max Welling. Auto-Encoding Variational Bayes. ar Xiv e-prints, art. ar Xiv:1312.6114, Dec 2013. Wouter Kool, Herke van Hoof, and Max Welling. Estimating gradients for discrete random variables by sampling without replacement. In International Conference on Learning Representations, 2020. Dirk P Kroese, Thomas Taimre, and Zdravko I Botev. Handbook of monte carlo methods, volume 706. John Wiley & Sons, 2013. Yann Le Cun and Corinna Cortes. MNIST handwritten digit database. URL http://yann. lecun. com/exdb/mnist, 2010. URL http://yann.lecun.com/exdb/mnist/. Runjing Liu, Jeffrey Regier, Nilesh Tripuraneni, Michael I Jordan, and Jon Mc Auliffe. Raoblackwellized stochastic gradients for discrete distributions. ar Xiv preprint ar Xiv:1810.04777, 2018. Ryan Lowe, Yi Wu, Aviv Tamar, Jean Harb, Pieter Abbeel, and Igor Mordatch. Multi-agent actorcritic for mixed cooperative-competitive environments. Co RR, abs/1706.02275, 2017. URL http://arxiv.org/abs/1706.02275. Chris J. Maddison. A Poisson process model for Monte Carlo. In Tamir Hazan, George Papandreou, and Daniel Tarlow (eds.), Perturbation, Optimization, and Statistics. MIT Press, 2016. Chris J. Maddison, Daniel Tarlow, and Tom Minka. A* Sampling. In Advances in Neural Information Processing Systems 27, 2014. Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. In International Conference on Learning Representations, 2017. Andr e F. T. Martins, Tsvetomila Mihaylova, Nikita Nangia, and Vlad Niculae. Latent structure models for natural language processing. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics: Tutorial Abstracts, pp. 1 5, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-4001. URL https: //www.aclweb.org/anthology/P19-4001. Andriy Mnih and Karol Gregor. Neural variational inference and learning in belief networks. In Proceedings of the 31st International Conference on International Conference on Machine Learning-Volume 32, pp. II 1791, 2014. Andriy Mnih and Danilo J Rezende. Variational inference for monte carlo objectives. In Proceedings of the 33rd International Conference on International Conference on Machine Learning-Volume 48, pp. 2188 2196, 2016. Igor Mordatch and Pieter Abbeel. Emergence of grounded compositional language in multi-agent populations. Co RR, abs/1703.04908, 2017. URL http://arxiv.org/abs/1703.04908. Published as a conference paper at ICLR 2021 Nikita Nangia and Samuel R. Bowman. Listops: A diagnostic dataset for latent tree learning, 2018. Max Paulus, Dami Choi, Daniel Tarlow, Andreas Krause, and Chris J Maddison. Gradient estimation with stochastic softmax tricks. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin (eds.), Advances in Neural Information Processing Systems, volume 33, pp. 5691 5704. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/paper/2020/file/ 3df80af53dce8435cf9ad6c3e7a403fd-Paper.pdf. Adeel Pervez, Taco Cohen, and Efstratios Gavves. Low Bias Low Variance Gradient Estimates for Hierarchical Boolean Stochastic Networks. In ICML, 2020. URL https://proceedings. icml.cc/static/paper_files/icml/2020/3626-Paper.pdf. C. Radhakrishna Rao. Information and the accuracy attainable in the estimation of statistical parameters. In Samuel Kotz and Norman L. Johnson (eds.), Breakthroughs in Statistics: Foundations and Basic Theory, pp. 235 247, New York, NY, 1992. Springer New York. Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, pp. 1278 1286, 2014. Ruslan Salakhutdinov and Iain Murray. On the quantitative analysis of deep belief networks. In Proceedings of the 25th international conference on Machine learning, pp. 872 879, 2008. John Schulman, Nicolas Heess, Theophane Weber, and Pieter Abbeel. Gradient estimation using stochastic computation graphs. In Advances in Neural Information Processing Systems, pp. 3528 3536, 2015. George Tucker, Andriy Mnih, Chris J. Maddison, and Jascha Sohl-Dickstein. REBAR : Low-variance, unbiased gradient estimates for discrete latent variable models. In Neural Information Processing Systems, 2017. Tim Vieira. Estimating means in a finite universe, 2017. URL https://timvieira. github. io/blog/post/2017/07/03/estimating-means-in-a-finite-universe, 2017. Th eophane Weber, Nicolas Heess, Lars Buesing, and David Silver. Credit assignment techniques in stochastic computation graphs. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 2650 2660, 2019. Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229 256, 1992. Zichao Yang, Zhiting Hu, Ruslan Salakhutdinov, and Taylor Berg-Kirkpatrick. Improved variational autoencoders for text modeling using dilated convolutions. Co RR, abs/1702.08139, 2017. URL http://arxiv.org/abs/1702.08139.