# Scaling Structured Inference with Randomization

Yao Fu¹, John P. Cunningham²·³, Mirella Lapata¹

¹School of Informatics, University of Edinburgh. ²Statistics Department, Columbia University. ³Zuckerman Institute, Columbia University. Correspondence to: Yao Fu, John P. Cunningham, Mirella Lapata.

*Proceedings of the 39th International Conference on Machine Learning, Baltimore, Maryland, USA, PMLR 162, 2022.*

**Abstract**

Deep discrete structured models have seen considerable progress recently, but traditional inference using dynamic programming (DP) typically works with a small number of states (typically fewer than a few hundred), which severely limits model capacity. At the same time, across machine learning there is a recent trend of using randomized truncation techniques to accelerate computations involving large sums. Here, we propose a family of randomized dynamic programming (RDP) algorithms for scaling structured models to tens of thousands of latent states. Our method is widely applicable to classical DP-based inference (partition function, marginal, reparameterization, entropy) and to different graph structures (chains, trees, and more general hypergraphs). It is also compatible with automatic differentiation: it can be integrated with neural networks seamlessly and learned with gradient-based optimizers. Our core technique approximates the sum-product by restricting and reweighting DP on a small subset of nodes, which reduces computation by orders of magnitude. We further achieve low bias and variance via Rao-Blackwellization and importance sampling. Experiments over different graphs demonstrate the accuracy and efficiency of our approach. RDP can also be used to learn a structured variational autoencoder with a scaled inference network, which outperforms baselines in terms of test likelihood and successfully prevents posterior collapse.

## 1. Introduction

Deep discrete structured models (Martins et al., 2019; Rush, 2020) have enjoyed great progress recently, improving performance and interpretability in a range of tasks including sequence tagging (Ma & Hovy, 2016), parsing (Zhang et al., 2020; Yang et al., 2021), and text generation (Wiseman et al., 2018). However, their capacity is limited by issues of scaling (Sun et al., 2019; Chiu & Rush, 2020; Yang et al., 2021; Chiu et al., 2021). Traditional dynamic-programming-based inference for exponential families has limited scalability in large combinatorial spaces. When integrating exponential families with neural networks (e.g., a VAE with a CRF inference network, detailed later), the small latent combinatorial space severely restricts model capacity. Existing work has already observed improved performance from scaling certain types of structures (Yang et al., 2021; Li et al., 2020; Chiu & Rush, 2020), and researchers are eager to know if there are general techniques for scaling classical structured models.

Challenges in scaling structured models primarily come from memory complexity. For example, consider linear-chain CRFs (Sutton & McCallum, 2006), the classical sequence model that uses the Forward algorithm, a dynamic programming algorithm, for computing the partition function exactly. This algorithm requires $O(TN^2)$ computation, where $N$ is the number of latent states and $T$ is the length of the sequence. It is precisely the $N^2$ term that is problematic in terms of memory and computation.
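To ground the complexity claim, here is a minimal log-space Forward recursion in PyTorch (an illustration we add; the tensor shapes and the convention of folding emissions into the potentials are our assumptions, not the paper's implementation):

```python
import torch

def forward_log_partition(log_init, log_trans):
    """Exact Forward algorithm on a linear chain (log space).

    log_init:  (N,)        scores for the state at step 1
    log_trans: (T-1, N, N) scores for moving from state i to state j,
                           with emission scores folded into the potentials.
    Returns log Z. Each step is the O(N^2) operation that dominates memory
    under autodiff, since every intermediate alpha is kept for backprop.
    """
    alpha = log_init
    for log_phi_t in log_trans:
        # alpha'[j] = logsumexp_i(alpha[i] + log_phi_t[i, j])
        alpha = torch.logsumexp(alpha.unsqueeze(1) + log_phi_t, dim=0)
    return torch.logsumexp(alpha, dim=0)
```

Under automatic differentiation, every intermediate `alpha` (and the $N \times N$ broadcast sum that produces it) is retained for the backward pass, which is the memory bottleneck discussed next.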
This limitation is more severe under automatic differentiation (AD) frameworks, as all intermediate DP computations are stored for gradient construction. Generally, DP-based inference algorithms are not optimized for modern computational hardware like GPUs and typically work in small-data regimes, with $N$ in the range $[10, 100]$ (Ma & Hovy, 2016; Wiseman et al., 2018). With larger $N$, inference becomes intractable since the computation graph does not easily fit into GPU memory (Sun et al., 2019).

Aligning with a recent trend of exploiting randomization techniques for machine learning problems (Oktay et al., 2020; Potapczynski et al., 2021; Beatson & Adams, 2019), this work proposes a randomization framework for scaling structured models, which encompasses a family of randomized dynamic programming algorithms with wide coverage of different structures and inference operations (Table 1). Within our randomization framework, instead of summing over all possible combinations of latent states, we only sum over paths with the most probable states and sample a subset of less likely paths to correct the bias according to a reasonable proposal. Since we only calculate the chosen paths, memory consumption can be reduced to a reasonably small budget. We thus recast the computation challenge into a tradeoff between memory budget, proposal accuracy, and estimation error. In practice, we show RDP scales existing models by two orders of magnitude with memory complexity as small as one percent.

Table 1. Dynamic programming algorithms for different inference operations over different graphs. Our randomization technique covers a spectrum of classical DP algorithms on different graphs. Scaled algorithms shown in red are randomized in this work.

| Graph | Model | Partition & Marginal | Entropy | Sampling & Reparameterization |
| --- | --- | --- | --- | --- |
| Chains | HMM, CRF, Semi-Markov | Forward-Backward | Backward Entropy | Gumbel Sampling (Fu et al., 2020) |
| Hypertrees | PCFG, Dependency CRF | Inside-Outside | Outside Entropy | Stochastic Softmax (Paulus et al., 2020) |
| General graphs | General Exponential Family | Sum-Product | Bethe Entropy | Stochastic Softmax (Paulus et al., 2020) |

In addition to the significantly increased scale, we highlight the following advantages of RDP: (1) applicability to different structures (chains, trees, and hypergraphs) and inference operations (partition function, marginal, reparameterization, and entropy); (2) compatibility with automatic differentiation and existing efficient libraries (like Torch-Struct; Rush, 2020); and (3) statistically principled control of bias and variance.

As a concrete application, we show that RDP can be used for learning a structured VAE with a scaled inference network. In experiments, we first demonstrate that RDP algorithms estimate the partition function and entropy for chains and hypertrees with lower mean square errors than baselines. Then, we show their joint effectiveness for learning the scaled VAE. RDP outperforms baselines in terms of test likelihood and successfully prevents posterior collapse. Our implementation is at https://github.com/FranxYao/RDP.

## 2. Background and Preliminaries

**Problem Statement.** We start with a widely-used structured VAE framework. Let $y$ be an observed variable and $X$ be any latent structure (sequences of latent tags, parse trees, or general latent graphs; see Table 1) that generates $y$. Let $p_\psi$ denote the generative model and $q_\theta$ the inference model.
We optimize the evidence lower bound:

$$\mathcal{L} = \mathbb{E}_{q_\theta(X|y)}\big[\log p_\psi(X, y) - \log q_\theta(X|y)\big] \tag{1}$$

where the inference model $q_\theta$ takes the form of a discrete undirected exponential family (e.g., linear-chain CRFs in Fu et al., 2020). Successful applications of this framework include sequence tagging (Mensch & Blondel, 2018), text generation (Li & Rush, 2020), constituency parsing (Kim et al., 2019), dependency parsing (Corro & Titov, 2018), latent tree induction (Paulus et al., 2020), and so on. Our goal is to learn this model with a scaled latent space (e.g., a CRF encoder with tens of thousands of latent states).

**Sum-Product Recap.** Learning the model in Eq. 1 usually requires classical sum-product inference (and its variants) for $q_\theta$, which we now review. Suppose $q_\theta$ is in the standard overcomplete parameterization of a discrete exponential family (Wainwright & Jordan, 2008):

$$q_\theta(X) \propto \exp\Big\{ \sum_{s \in V} \phi_s(x_s) + \sum_{(s,t) \in E} \phi_{st}(x_s, x_t) \Big\} \tag{2}$$

where $X = [X_1, ..., X_M]$ is a random vector of nodes in a graphical model with each node taking discrete values $X_i \in \{1, 2, ..., N\}$, $N$ is the number of states, $\phi_{st}$ is the edge potential, and $\phi_s$ is the node potential. Here we use a general notation of nodes $V$ and edges $E$; we discuss specific structures later. Suppose we want to compute the marginals and log partition function of $q_\theta$. The solution is the sum-product algorithm, which recursively updates the message at each edge:

$$\mu_{ts}(x_s) \leftarrow \sum_{x_t'} \Big\{ \phi_{st}(x_s, x_t')\, \phi_t(x_t') \prod_{u \in V_{t \setminus s}} \mu_{ut}(x_t') \Big\} \tag{3}$$

where $\mu_{ts}(x_s)$ denotes the message from node $t$ to node $s$ evaluated at $X_s = x_s$, and $V_{t \setminus s}$ denotes the set of neighbor nodes of $t$ except $s$. Upon convergence, this gives us the Bethe approximation (if the graph is loopy) of the marginals, partition function, and entropy (as functions of the edge marginals $\mu_{ts}$). These quantities are then used in gradient-based updates of the potentials $\phi$ (detailed in Sec. 4). When the underlying graph is a tree, this approach becomes exact and convergence is linear in the number of edges.

**Challenges in Scaling.** The challenge is how to scale the sum-product computation for $q_\theta$. Specifically, gradient-based learning requires three types of inference: (a) partition function estimation (for maximizing likelihood); (b) reparameterized sampling (for gradient estimation); and (c) entropy estimation (for regularization). Existing sum-product variants (Table 1) provide exact solutions, but only for a small latent space (e.g., a linear-chain CRF with fewer than 100 states). Since the complexity of DP-based inference is usually at least quadratic in the size of the latent space, it would induce memory overflow if we wanted to scale to tens of thousands of states.

**Restrictions from Automatic Differentiation.** At first sight, one may think the memory requirement of some algorithms is not so large. For example, the memory complexity of a batched forward sum-product (on chains) is $O(BTN^2)$. Consider batch size $B = 10$, sequence length $T = 100$, and number of states $N = 1000$; then the complexity is $O(10^9)$, which seems acceptable, especially given multiple implementation tricks (e.g., dropping intermediate variables). However, this is not the case under automatic differentiation, primarily because AD requires storing all intermediate computations¹ for building the adjoint gradient graph. This not only invalidates tricks like dropping intermediate variables (because they must be stored) but also multiplies memory complexity when building the adjoint gradient graph (Eisner, 2016).
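As a sanity check of this point, the following toy measurement (our illustration, not from the paper) contrasts peak GPU memory with and without gradient tracking for the same $O(N^2)$ recursion:

```python
import torch

def chain_sum(log_phi, track_grad):
    """Run T steps of an O(N^2) recursion and report peak GPU memory.

    Illustrative only: names and sizes are assumptions, but the memory
    pattern mirrors the Forward algorithm under autodiff.
    """
    torch.cuda.reset_peak_memory_stats()
    with torch.set_grad_enabled(track_grad):
        alpha = log_phi[0].logsumexp(dim=0)
        for t in range(1, log_phi.shape[0]):
            # Each step creates a new alpha; autograd keeps all of them
            # (and the N x N broadcast sums) alive for the backward pass.
            alpha = torch.logsumexp(alpha.unsqueeze(1) + log_phi[t], dim=0)
        alpha.logsumexp(dim=0)
    return torch.cuda.max_memory_allocated() / 2**20  # MiB

if torch.cuda.is_available():
    log_phi = torch.randn(100, 1000, 1000, device="cuda", requires_grad=True)
    print("with autograd:", chain_sum(log_phi, True), "MiB")
    print("no_grad:      ", chain_sum(log_phi, False), "MiB")
```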
This problem is more severe if we want to compute higher-order gradients, or if the underlying DP has higher-order complexity (e.g., the Inside algorithm of cubic complexity $O(T^2N^3)$). Since gradient-based optimization requires AD-compatibility, scaling techniques face the same restriction, namely storing all computations to construct the adjoint gradient graph.

**Previous Efforts.** There certainly exist several excellent scaling techniques for structured models, though most come with intrinsic limitations that we hope to alleviate. In addition to AD-compatibility restrictions, many existing techniques either require additional assumptions (e.g., sparsity in Lavergne et al., 2010; Sokolovska et al., 2010; Correia et al., 2020; pre-clustering in Chiu & Rush, 2020; or low-rank structure in Chiu et al., 2021), rely on handcrafted heuristics for bias correction (Jeong et al., 2009), or cannot be easily adapted to modern GPUs with tensorization and parallelization (Klein & Manning, 2003). As a result, existing methods apply to a limited range of models: Chiu & Rush (2020) only consider chains, and Yang et al. (2021) only consider probabilistic context-free grammars (PCFGs). One notably promising direction comes from Sun et al. (2019), who consider top-$K$ computation and drop all low-probability states. Pillutla et al. (2018) also considered top-$K$ inference in a smoothed setting (with more theoretical analysis but less focus on memory efficiency), and their top-$K$ inference also points to efficiency improvements. While intuitively sensible (indeed we build on this idea here), their deterministic truncation induces severe bias (shown later in our experiments). As we will discuss, this bias can be effectively mitigated by randomization.

¹ More recently, there are works improving the memory efficiency of AD, like Rajbhandari et al. (2020), by off-loading intermediate variables to CPUs. Yet applying them to sum-product inference requires modifying the AD library, which raises significant engineering challenges.

**Randomization of Sums in Machine Learning.** Randomization is a long-standing technique in machine learning (Mahoney, 2011; Halko et al., 2011) and was applied to dynamic programs before the advent of deep learning (Bouchard-Côté et al., 2009; Blunsom & Cohn, 2010). Notably, works like Koller et al. (1999) and Ihler & McAllester (2009) also consider sampling techniques for sum-product inference. However, since these methods were proposed before deep learning, their differentiability remains untested. More recently, randomization has been used in the context of deep learning, including density estimation (Rhodes et al., 2020), Gaussian processes (Potapczynski et al., 2021), automatic differentiation (Oktay et al., 2020), gradient estimation (Correia et al., 2020), and optimization (Beatson & Adams, 2019).

The foundation of our randomized DP also lies in speeding up summation by randomization. To see the simplest case, consider the sum of a sorted list $a$ of positive numbers: $S = \sum_{i=1}^{N} a_i$. This requires $N - 1$ additions, which could be expensive when $N$ is large. Suppose one would like to reduce the number of summands to $K$; Liu et al. (2019) discuss the following sum-and-sample estimator for gradient estimation:

$$\hat{S} = \sum_{i=1}^{K_1} a_i + \frac{1}{K_2} \sum_{j=1}^{K_2} \frac{a_{\delta_j}}{q_{\delta_j}} \tag{4}$$

where $K_1 + K_2 = K$, $\delta_j \sim q = [q_{K_1+1}, ..., q_N]$, and $q$ is a proposal distribution over the tail summands $[a_{K_1+1}, ..., a_N]$. This estimator is visualized in Fig. 1A. One can show that it is unbiased, $\mathbb{E}_q[\hat{S}] = S$, irrespective of how we choose the proposal $q$ (a code sketch is given below).
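To make Eq. 4 concrete, here is a minimal NumPy sketch of the sum-and-sample estimator (our illustration; variable names and the Pareto test data are assumptions). The demo uses the normalized tail as the proposal, anticipating the oracle choice discussed next:

```python
import numpy as np

def sum_and_sample(a, q_tail, k1, k2, rng):
    """Sum-and-sample estimator of S = sum(a) (Eq. 4).

    a:      descending-sorted positive numbers, length N
    q_tail: proposal over the tail a[k1:], length N - k1, sums to 1
    Keeps the top-k1 summands exactly and importance-samples k2 of the
    rest, dividing by the proposal probability to correct the bias.
    """
    head = a[:k1].sum()
    idx = rng.choice(len(q_tail), size=k2, p=q_tail)  # i.i.d. tail indices
    tail = (a[k1 + idx] / q_tail[idx]).mean()
    return head + tail

rng = np.random.default_rng(0)
a = np.sort(rng.pareto(2.0, size=10000))[::-1]  # long-tailed, sorted list
q = a[100:] / a[100:].sum()                     # normalized tail proposal
print(a.sum(), sum_and_sample(a, q, k1=100, k2=50, rng=rng))
```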
The oracle proposal $q^*$ is given by the normalized tail summands, $q^*_i = a_i / \sum_{j=K_1+1}^{N} a_j$, under which the estimate becomes exact: $\hat{S} \equiv S$. Note that the bias is corrected by dividing by the proposal probability.

## 3. Scaling Inference with Randomization

In this section, we introduce our randomized dynamic programming algorithms. We first introduce a generic randomized sum-product framework for approximating the partition function of any graph structure. Then we instantiate this framework on two classical structures, chains and hypertrees, respectively. Running the randomized sum-product in a left-to-right (or bottom-up) order gives a randomized version of the classical Forward (or Inside) algorithm for estimating the partition function of chains like HMMs and linear-chain CRFs (or of hypertrees like PCFGs and Dependency CRFs in Kim et al., 2019). We call these two algorithms first-order RDPs because they run the computation graph in one pass. Next, we generalize RDP to two more inference operations (beyond partition function estimation): entropy estimation (for regularized training) and reparameterization (for gradient-based optimization). We collectively use the term second-order RDPs for the entropy and reparameterization algorithms because they call first-order RDPs as a subroutine and run a second pass of computation upon the first-order graphs. We will compose these algorithms into a structured VAE in Sec. 4.

Figure 1. Scaling inference by randomization. (A): Randomized summation (Eq. 4). (B): Randomized Forward (Alg. 1) recursively applies randomized summation at each DP step. Memory reduction is achieved by restricting computation to edges linking the sampled nodes. Second-order randomized DPs (Entropy and Reparameterization, Alg. 3 and 4) reuse this graph and share the same path as the gradients. (C): Randomized Inside (Alg. 2) applies randomized summation twice at each DP step. All algorithms in our randomized DP family are compatible with automatic differentiation; the direction of the gradient is shown by red dashed arrows.

### 3.1. Framework: Randomized Sum-Product

Our key technique for scaling the sum-product is a randomized message estimator. Specifically, for each node $X_t$, given a pre-constructed proposal $q_t = [q_t(1), ..., q_t(N)]$ (we discuss how to construct a correlated proposal later, because it depends on the specific neural parameterization), we retrieve the top-$K_1$ indices $\{\sigma_{t,i}\}_{i=1}^{K_1}$ from $q_t$ and draw a $K_2$-sized sample $\{\delta_{t,j}\}_{j=1}^{K_2}$ from the rest of $q_t$:

$$[\sigma_{t,1}, ..., \sigma_{t,K_1}, ..., \sigma_{t,N}] = \operatorname{argsort}\{q_t(i)\}_{i=1}^{N} \quad \text{(descending)} \tag{5}$$

$$\delta_{t,j} \overset{\text{i.i.d.}}{\sim} \mathrm{Categorical}\{q_t(\sigma_{t,K_1+1}), ..., q_t(\sigma_{t,N})\} \tag{6}$$

$$\Omega^{K_1}_t = \{\sigma_{t,i}\}_{i=1}^{K_1}, \qquad \Omega^{K_2}_t = \{\delta_{t,j}\}_{j=1}^{K_2} \tag{7}$$

Then, we substitute the full index set $\{1, ..., N\}$ with the top indices $\Omega^{K_1}_t$ and the sampled indices $\Omega^{K_2}_t$:

$$\hat{\mu}_{ts}(x_s) \leftarrow \sum_{\sigma_t \in \Omega^{K_1}_t} \Big\{ \phi_{st}(x_s, \sigma_t)\, \phi_t(\sigma_t) \prod_{u \in V_{t \setminus s}} \hat{\mu}_{ut}(\sigma_t) \Big\} + \sum_{\delta_t \in \Omega^{K_2}_t} \frac{1}{K_2\, q_t(\delta_t)} \Big\{ \phi_{st}(x_s, \delta_t)\, \phi_t(\delta_t) \prod_{u \in V_{t \setminus s}} \hat{\mu}_{ut}(\delta_t) \Big\} \tag{8}$$

where the oracle proposal is proportional to the actual summands (Eq. 3). The indices here are pre-computed outside the DP. This means the computation of Eq. 5–7 can be moved outside the AD engine (e.g., in PyTorch, under the `no_grad` statement) to further save GPU memory.
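A minimal PyTorch sketch of this index construction (our illustration; we assume the renormalized tail probabilities are what enters the $1/(K_2\, q_t(\delta_t))$ correction):

```python
import torch

@torch.no_grad()  # Eqs. 5-7 need no gradients, so keep them off the AD tape
def choose_indices(q_t, k1, k2):
    """Pick top-k1 indices and sample k2 tail indices from proposal q_t.

    q_t: (N,) proposal probabilities for one node. Returns the index sets
    (Omega_k1, Omega_k2) and the proposal probabilities of the sampled
    indices, needed later for the 1 / (K2 * q_t(delta)) bias correction.
    """
    order = torch.argsort(q_t, descending=True)             # Eq. 5
    top, tail = order[:k1], order[k1:]
    tail_q = q_t[tail] / q_t[tail].sum()                    # renormalized tail
    pick = torch.multinomial(tail_q, k2, replacement=True)  # Eq. 6, i.i.d.
    return top, tail[pick], tail_q[pick]                    # Eq. 7
```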
We now show that the estimator in Eq. 8 can be interpreted as a combination of Rao-Blackwellization (RB) and importance sampling (IS). First, RB says that a function $J(X, Y)$ of two random variables $X, Y$ has larger variance than its conditional expectation $\hat{J}(X) = \mathbb{E}[J(X, Y) \mid X]$ (Ranganath et al., 2014):

$$\mathbb{V}[\hat{J}(X)] = \mathbb{V}[J(X, Y)] - \mathbb{E}\big[(J(X, Y) - \hat{J}(X))^2\big] \tag{9}$$

$$\leq \mathbb{V}[J(X, Y)] \tag{10}$$

This is because $Y$ has been integrated out and the source of randomness is reduced to $X$. Now let $\delta_1, \delta_2$ be uniform random indices, with $\delta_1$ taking values in $\Omega^{K_1}_t$ and $\delta_2$ in the remaining $N - K_1$ indices. Consider a simple unbiased estimator $S$ of the message $\mu_{st}$:

$$m_{st}(i) = \phi_{st}(x_s, i)\, \phi_t(i) \prod_{u \in V_{t \setminus s}} \mu_{ut}(i) \tag{11}$$

$$S(\delta_1, \delta_2) = K_1\, m_{st}(\delta_1) + (N - K_1)\, m_{st}(\delta_2) \tag{12}$$

The conditional expectation $\hat{S}(\delta_2) = \mathbb{E}[S(\delta_1, \delta_2) \mid \delta_2]$ is:

$$\hat{S}(\delta_2) = \sum_{\sigma_t \in \Omega^{K_1}_t} m_{st}(\sigma_t) + (N - K_1)\, m_{st}(\delta_2) \tag{13}$$

Plugging $S(\delta_1, \delta_2)$ and $\hat{S}(\delta_2)$ into Eq. 10, we get variance reduction: $\mathbb{V}[\hat{S}(\delta_2)] \leq \mathbb{V}[S(\delta_1, \delta_2)]$. Note that the variance becomes smaller as $K_1$ becomes larger. Now change the distribution of $\delta_2$ in Eq. 13 from uniform to the proposal in Eq. 6 and take the expectation:

$$\mathbb{E}_{\delta_2 \sim \mathrm{Uniform}}\big[(N - K_1)\, m_{st}(\delta_2)\big] \tag{14}$$

$$= \mathbb{E}_{\delta_2 \sim q_t}\Big[\frac{1}{q_t(\delta_2)(N - K_1)}\, (N - K_1)\, m_{st}(\delta_2)\Big] \tag{15}$$

$$= \mathbb{E}_{\delta_2 \sim q_t}\Big[\frac{1}{q_t(\delta_2)}\, m_{st}(\delta_2)\Big] \tag{16}$$

This is an instance of importance sampling where the importance weight is $\frac{1}{q_t(\delta_2)(N - K_1)}$. Variance is reduced if $q_t$ is correlated with the summands (Eq. 11). Finally, combining Eq. 16 with Eq. 13 and increasing the number of sampled indices $\delta_2$ to $K_2$ recovers our full message estimator in Eq. 8.

### 3.2. First-Order Randomized DP

We now instantiate this general randomized message-passing principle for chains (on which the sum-product becomes the Forward algorithm) and tree-structured hypergraphs (on which the sum-product becomes the Inside algorithm). When the number of states $N$ is large, the exact sum-product requires large GPU memory when implemented with AD libraries like PyTorch. This is where randomization comes to the rescue.

#### 3.2.1. Randomized Forward

Algorithm 1 shows our randomized Forward algorithm for approximating the partition function of chain-structured graphs.

**Algorithm 1: Randomized Forward**

Input: potentials $\phi(x_{t-1}, x_t, y_t)$, top-$K_1$ index sets $\Omega^{K_1}_t$, sampled $K_2$ index sets $\Omega^{K_2}_t$.
Initialize $\hat{\alpha}_1(i) = \phi(x_0 = \varnothing, x_1 = i, y_1)$.
For $t = 2$ to $T$, compute the recursion:

$$\hat{\alpha}_t(i) = \sum_{\sigma \in \Omega^{K_1}_{t-1}} \hat{\alpha}_{t-1}(\sigma)\, \phi(\sigma, i, y_t) + \sum_{\delta \in \Omega^{K_2}_{t-1}} \frac{1}{K_2\, q_{t-1}(\delta)}\, \hat{\alpha}_{t-1}(\delta)\, \phi(\delta, i, y_t) \tag{17}$$

Return $\hat{Z} = \sum_{\sigma \in \Omega^{K_1}_T} \hat{\alpha}_T(\sigma) + \sum_{\delta \in \Omega^{K_2}_T} \frac{1}{K_2\, q_T(\delta)}\, \hat{\alpha}_T(\delta)$.

The core recursion in Eq. 17 estimates the alpha variable $\hat{\alpha}_t(i)$, the sum over all possible sequences up to step $t$ ending in state $i$; it corresponds to Eq. 8 applied to chains (Fig. 1B). Note how the term $K_2\, q_{t-1}(\delta)$ is divided in Eq. 17 to correct the estimation bias. Also note that all computations in Alg. 1 are differentiable w.r.t. the potentials $\phi$. We can recover the classical Forward algorithm by replacing the chosen index sets $\Omega^K_t$ with the full index set $\{1, ..., N\}$. In Appendix A, we prove unbiasedness by induction.

As discussed in the previous section, we reduce the variance of the randomized Forward by (1) Rao-Blackwellization (increasing $K_1$ to reduce randomness); and (2) importance sampling (constructing a proposal correlated with the actual summands). The variance comes from the gap between the proposal and the oracle (only accessible with a full DP), shown as the green bars in Fig. 1B. Variance is also closely related to how long-tailed the underlying distribution is: the longer the tail, the more effective Rao-Blackwellization will be. A more detailed variance analysis is presented in Appendix B.
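For concreteness, here is a runnable sketch of Alg. 1 in the original (non-log) space, assuming dense $(T{-}1) \times N \times N$ potentials and pre-chosen index sets from the procedure above (our illustration, not the paper's implementation; a practical version would batch over sequences and work in log space, as noted next):

```python
import torch

def randomized_forward(phi, omega1, omega2, q2, k2):
    """Sketch of Randomized Forward (Alg. 1) in the original space.

    phi:    (T-1, N, N) positive potentials; phi[t, i, j] scores moving from
            state i at step t to state j at step t+1, emissions folded in
            (these shapes are our assumption for the sketch).
    omega1: omega1[t] = LongTensor of top-K1 state indices at step t
    omega2: omega2[t] = LongTensor of K2 sampled state indices at step t
    q2:     q2[t][m]  = proposal probability of state omega2[t][m]
    Returns an estimate of the partition function Z on the chosen subgraph.
    """
    T = len(omega1)
    idx = torch.cat([omega1[0], omega2[0]])
    alpha = torch.ones(len(idx))  # unit start potentials, for simplicity
    for t in range(1, T):
        # Weights of the previous step's states: 1 for top-K1 entries,
        # 1 / (K2 * q) for sampled entries (the bias correction in Eq. 17).
        w = torch.cat([torch.ones(len(omega1[t - 1])), 1.0 / (k2 * q2[t - 1])])
        nxt = torch.cat([omega1[t], omega2[t]])
        trans = phi[t - 1][idx][:, nxt]  # (K1+K2) x (K1+K2) block of phi
        alpha = (w * alpha) @ trans      # Eq. 17, restricted to chosen states
        idx = nxt
    w = torch.cat([torch.ones(len(omega1[-1])), 1.0 / (k2 * q2[-1])])
    return (w * alpha).sum()  # Z_hat; differentiable w.r.t. phi
```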
In practice, we implement the algorithm in log space (for numerical stability)², which has two implications: (a) the estimate becomes a lower bound due to Jensen's inequality; (b) the variance is exponentially reduced by the $\log(\cdot)$ function (in a rather trivial way). Essentially, the randomized Forward restricts the DP computation from the full graph to a subgraph with the chosen nodes ($\Omega^{K_1}_t$ and $\Omega^{K_2}_t$ for all $t$), quadratically reducing memory complexity from $O(TN^2)$ to $O(TK^2)$. Since all computations in Alg. 1 are differentiable, one can directly compute gradients of the estimated partition function with any AD library, thus enabling gradient-based optimization. Backpropagation shares the same DP graph as the Randomized Forward, but reverses its direction (Fig. 1B).

² Common machine learning practice works with log probabilities, and thus log partition functions. One can also implement Alg. 1 in the original space for unbiasedness.

#### 3.2.2. Randomized Inside

We now turn to our Randomized Inside algorithm for approximating the partition function of tree-structured hypergraphs (Alg. 2). It recursively estimates the inside variables $\hat{\alpha}(i, j, k)$, which sum over all possible tree branchings (index $m$ in Eq. 18) and chosen state combinations (indices $\sigma_1, \sigma_2, \delta_1, \delta_2$; Fig. 1C). Indices $i, j$ denote a subtree spanning locations $i$ to $j$ in a given sequence, and $k$ denotes the state.

**Algorithm 2: Randomized Inside**

Input: potentials $\phi(i, j, k)$, top-$K_1$ index sets $\Omega^{K_1}_{i,j}$, sampled $K_2$ index sets $\Omega^{K_2}_{i,j}$.
Initialize $\hat{\alpha}(i, i, k) = \phi(i, i, k)$.
For $l = 1$ to $T - 1$, let $j = i + l$, and compute the recursion:

$$\hat{\alpha}(i, j, k) = \phi(i, j, k) \sum_{m=i}^{j-1} \Big[ \sum_{\sigma_1 \in \Omega^{K_1}_{i,m},\; \sigma_2 \in \Omega^{K_1}_{m+1,j}} \hat{\alpha}(i, m, \sigma_1)\, \hat{\alpha}(m{+}1, j, \sigma_2) + \sum_{\delta_1 \in \Omega^{K_2}_{i,m},\; \delta_2 \in \Omega^{K_2}_{m+1,j}} \frac{\hat{\alpha}(i, m, \delta_1)}{K_2\, q_{i,m}(\delta_1)}\, \frac{\hat{\alpha}(m{+}1, j, \delta_2)}{K_2\, q_{m+1,j}(\delta_2)} \Big] \tag{18}$$

Return $\hat{Z} = \sum_{\sigma \in \Omega^{K_1}_{1,T}} \hat{\alpha}(1, T, \sigma) + \sum_{\delta \in \Omega^{K_2}_{1,T}} \frac{1}{K_2\, q_{1,T}(\delta)}\, \hat{\alpha}(1, T, \delta)$, and $\{\hat{\alpha}_{i,j}\}$ for $i, j \in \{1, ..., T\}$.

Differently from the Forward case, this algorithm computes the product of two randomized sums that represent the two subtrees $(i, m)$ and $(m + 1, j)$. The proposals are constructed per subtree, i.e., $q_{i,m}$ and $q_{m+1,j}$, and both are divided in Eq. 18 to correct the estimation bias. The Randomized Inside restricts computation to the sampled states of subtrees, reducing the complexity from $O(T^2N^3)$ to $O(T^2K^3)$. The output partition function is again differentiable, and the gradients traverse the computation graph in reverse. Similar to the Forward case, unbiasedness can be proved by induction, and variance can be decreased by increasing $K_1$ (to reduce randomness) and by constructing a correlated proposal. In Fig. 1C, green bars represent gaps to the oracle proposal, which are a major source of estimation error.
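Under the same caveats, here is a compact sketch of Alg. 2 with the simplified span potentials $\phi(i, j, k)$; the dictionary bookkeeping and shapes are our assumptions, and we follow the two-term structure of Eq. 18 (a practical version would tensorize over spans and work in log space):

```python
import torch

def randomized_inside(phi, omega1, omega2, q2, k2, T):
    """Sketch of Randomized Inside (Alg. 2) in the original space.

    phi:    (T, T, N) positive span potentials phi[i, j, k] (our assumed
            simplified parameterization, following Alg. 2's phi(i, j, k)).
    omega1[(i, j)]: LongTensor of top-K1 states for span (i, j)
    omega2[(i, j)]: LongTensor of K2 sampled states for span (i, j)
    q2[(i, j)]:     proposal probs of the sampled states of span (i, j)
    Returns an estimate of Z over the chosen subgraph.
    """
    a1, a2 = {}, {}  # inside values at top / sampled states, per span
    for i in range(T):
        a1[i, i] = phi[i, i][omega1[i, i]]
        a2[i, i] = phi[i, i][omega2[i, i]]
    for l in range(1, T):
        for i in range(T - l):
            j = i + l
            s = 0.0
            for m in range(i, j):  # sum over split points, Eq. 18
                # Top-top product plus reweighted sampled-sampled product.
                top = a1[i, m].sum() * a1[m + 1, j].sum()
                samp = (a2[i, m] / (k2 * q2[i, m])).sum() * \
                       (a2[m + 1, j] / (k2 * q2[m + 1, j])).sum()
                s = s + top + samp
            a1[i, j] = phi[i, j][omega1[i, j]] * s
            a2[i, j] = phi[i, j][omega2[i, j]] * s
    return a1[0, T - 1].sum() + (a2[0, T - 1] / (k2 * q2[0, T - 1])).sum()
```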
### 3.3. Second-Order Randomized DP

We now generalize RDP from partition function estimation to two more inference operations: entropy and reparameterized sampling. When composing graphical models with neural networks, entropy is usually required for regularization, and reparameterized sampling is required for Monte Carlo gradient estimation (Mohamed et al., 2020). Again, we call these methods second-order because they reuse the computational graph and intermediate outputs of first-order RDPs. We focus on chain structures (HMMs and linear-chain CRFs) for simplicity.

#### 3.3.1. Randomized Entropy DP

Algorithm 3 shows the randomized entropy DP³. It reuses the chosen index sets $\Omega^K_t$ (and thus the computation graph) and the intermediate variables $\hat{\alpha}_t$ of the randomized Forward, and recursively estimates the conditional entropy $\hat{H}_t(j)$, which represents the entropy of the chain ending at step $t$ in state $j$.

**Algorithm 3: Randomized Entropy DP**

Input: potentials $\phi(x_{t-1}, x_t, y_t)$, top-$K_1$ index sets $\Omega^{K_1}_t$, sampled $K_2$ index sets $\Omega^{K_2}_t$.
Initialize $\hat{H}_1(i) = 0$; call Randomized Forward to get $\hat{Z}, \hat{\alpha}$.
For $t = 1$ to $T - 1$, compute the recursion:

$$\hat{p}_t(i, j) = \phi(i, j, y_t)\, \hat{\alpha}_t(i) / \hat{\alpha}_{t+1}(j) \tag{19}$$

$$\hat{H}_{t+1}(j) = \sum_{\sigma \in \Omega^{K_1}_t} \hat{p}_t(\sigma, j)\big[\hat{H}_t(\sigma) - \log \hat{p}_t(\sigma, j)\big] + \sum_{\delta \in \Omega^{K_2}_t} \frac{1}{K_2\, q_t(\delta)}\, \hat{p}_t(\delta, j)\big[\hat{H}_t(\delta) - \log \hat{p}_t(\delta, j)\big] \tag{20}$$

Return $\hat{H}$, where:

$$\hat{p}_T(i) = \hat{\alpha}_T(i) / \hat{Z} \tag{21}$$

$$\hat{H} = \sum_{\sigma \in \Omega^{K_1}_T} \hat{p}_T(\sigma)\big[\hat{H}_T(\sigma) - \log \hat{p}_T(\sigma)\big] + \sum_{\delta \in \Omega^{K_2}_T} \frac{1}{K_2\, q_T(\delta)}\, \hat{p}_T(\delta)\big[\hat{H}_T(\delta) - \log \hat{p}_T(\delta)\big] \tag{22}$$

While the first-order estimates are unbiased, the entropy estimate here is biased because of the $\log(\cdot)$ in Eq. 20 (yet in the experiments we show its bias is significantly smaller than the baselines'). Note that the proposal probability $q_t(\delta)$ is again divided out for bias correction (Eq. 20). As before, all computation is differentiable, which means the output entropy can be directly differentiated by AD engines.

³ See Mann & McCallum (2007) and Li & Eisner (2009) for more details on deriving entropy DPs.

#### 3.3.2. Randomized Gumbel Backward Sampling

When training VAEs with a structured inference network, one usually requires differentiable samples from the inference network. Randomized Gumbel backward sampling (Alg. 4) provides differentiable relaxed samples from HMMs and linear-chain CRFs. Our algorithm is based on the recently proposed Gumbel Forward-Filtering Backward-Sampling (FFBS) algorithm for reparameterized gradient estimation of CRFs (see Fu et al., 2020, for details), and scales it to CRFs with tens of thousands of states. It reuses the DP paths of the randomized Forward and recursively computes a hard sample $\hat{x}_t$ and a soft sample $\tilde{x}_t$ (a relaxed one-hot vector) based on the chosen index sets $\Omega^K_t$. When these soft samples are differentiated for training structured VAEs, they induce biased but low-variance reparameterized gradients.

**Algorithm 4: Randomized Gumbel Backward Sampling**

Input: potentials $\phi(x_{t-1}, x_t, y_t)$, top-$K_1$ index sets $\Omega^{K_1}_t$, sampled $K_2$ index sets $\Omega^{K_2}_t$, Gumbel noise $g_t(i) \sim \mathrm{Gumbel}(0, 1)$.
Initialize: call Randomized Forward to get $\hat{Z}, \hat{\alpha}$, then:

$$\hat{p}_T(i) = \hat{\alpha}_T(i) / \hat{Z}, \quad i \in \Omega^{K_1}_T \cup \Omega^{K_2}_T \tag{23}$$

$$\tilde{x}_T = \mathrm{softmax}_i\big(\log \hat{p}_T(i) + g_T(i)\big) \tag{24}$$

$$\hat{x}_T = \arg\max_i\, \tilde{x}_T(i) \tag{25}$$

For $t = T - 1$ down to $1$, compute the recursion:

$$\hat{p}_t(i, j) = \phi(i, j, y_t)\, \hat{\alpha}_t(i) / \hat{\alpha}_{t+1}(j), \quad i \in \Omega^{K_1}_t \cup \Omega^{K_2}_t,\; j \in \Omega^{K_1}_{t+1} \cup \Omega^{K_2}_{t+1} \tag{26}$$

$$\tilde{x}_t = \mathrm{softmax}_i\big(\log \hat{p}_t(i, \hat{x}_{t+1}) + g_t(i)\big) \tag{27}$$

$$\hat{x}_t = \arg\max_i\, \tilde{x}_t(i) \tag{28}$$

Return relaxed samples $\{\tilde{x}_t\}_{t=1}^T$ and hard samples $\{\hat{x}_t\}_{t=1}^T$.
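A sketch of Alg. 4 in the same assumed tensor layout as before; the soft samples are relaxed one-hot vectors over the chosen states, and the temperature `tau` is our addition (with `tau = 1` matching Eqs. 24 and 27):

```python
import torch

def gumbel_backward_sample(phi, alpha, Z_hat, states, tau=1.0):
    """Sketch of Alg. 4: relaxed backward sampling on the chosen states.

    phi:    (T-1, N, N) potentials (same assumed layout as before)
    alpha:  alpha[t] = estimated forward values at the chosen states of step t
    Z_hat:  estimated partition function from Randomized Forward
    states: states[t] = LongTensor of chosen states (top-K1 plus sampled K2)
    Returns relaxed one-hot samples and hard indices for steps T down to 1.
    """
    T = len(states)
    soft, hard = [None] * T, [None] * T
    log_p = torch.log(alpha[T - 1] / Z_hat)                 # Eq. 23
    for t in range(T - 1, -1, -1):
        g = -torch.log(-torch.log(torch.rand_like(log_p)))  # Gumbel(0, 1)
        soft[t] = torch.softmax((log_p + g) / tau, dim=0)   # Eqs. 24 / 27
        hard[t] = soft[t].argmax()                          # Eqs. 25 / 28
        if t > 0:
            j = states[t][hard[t]]  # position in the chosen set -> state id
            # Eq. 26: backward conditional restricted to the chosen states.
            p = phi[t - 1][states[t - 1], j] * alpha[t - 1] / alpha[t][hard[t]]
            log_p = torch.log(p)
    return soft, hard
```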
## 4. Scaling Structured VAE with RDP

We study a concrete structured VAE example that uses our RDP algorithms for scaling. We focus on the language domain; however, our method generalizes to other types of sequences and structures. We use the randomized Gumbel backward sampling algorithm for gradient estimation and the randomized Entropy DP for regularization.

Table 2. Mean square error comparison between RDP algorithms and the top-$K$ approximation. D = dense, I = intermediate, L = long-tailed distributions; $\log Z$ denotes the log partition function; LC = linear chain, HT = hypertree. Our method outperforms the baseline on all unit cases with significantly less memory.

| $N = 2000$ | LC $\log Z$, D | LC $\log Z$, I | LC $\log Z$, L | HT $\log Z$, D | HT $\log Z$, I | HT $\log Z$, L | LC Entropy, D | LC Entropy, I | LC Entropy, L |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| TOPK $20\%N$ | 3.874 | 1.015 | 0.162 | 36.127 | 27.435 | 21.78 | 443.7 | 84.35 | 8.011 |
| TOPK $50\%N$ | 0.990 | 0.251 | 0.031 | 2.842 | 2.404 | 2.047 | 131.8 | 22.100 | 1.816 |
| RDP $1\%N$ (ours) | 0.146 | 0.066 | 0.076 | 26.331 | 37.669 | 48.863 | 5.925 | 1.989 | 0.691 |
| RDP $10\%N$ (ours) | 0.067 | 0.033 | 0.055 | 1.193 | 1.530 | 1.384 | 2.116 | 1.298 | 0.316 |
| RDP $20\%N$ (ours) | 0.046 | 0.020 | 0.026 | 0.445 | 0.544 | 0.599 | 1.326 | 0.730 | 0.207 |

| $N = 10000$ | LC $\log Z$, D | LC $\log Z$, I | LC $\log Z$, L | HT $\log Z$, D | HT $\log Z$, I | HT $\log Z$, L | LC Entropy, D | LC Entropy, I | LC Entropy, L |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| TOPK $20\%N$ | 6.395 | 6.995 | 6.381 | 78.632 | 63.762 | 43.556 | 227.36 | 171.97 | 141.91 |
| TOPK $50\%N$ | 2.134 | 2.013 | 1.647 | 35.929 | 26.677 | 17.099 | 85.063 | 59.877 | 46.853 |
| RDP $1\%N$ (ours) | 0.078 | 0.616 | 0.734 | 3.376 | 5.012 | 7.256 | 6.450 | 6.379 | 4.150 |
| RDP $10\%N$ (ours) | 0.024 | 0.031 | 0.024 | 0.299 | 0.447 | 0.576 | 0.513 | 1.539 | 0.275 |
| RDP $20\%N$ (ours) | 0.004 | 0.003 | 0.003 | 0.148 | 0.246 | 0.294 | 0.144 | 0.080 | 0.068 |

Table 3. Bias-variance decomposition of the Randomized Forward.

| Method | Bias (Dense) | Var. (Dense) | Bias (Long-tail) | Var. (Long-tail) |
| --- | --- | --- | --- | --- |
| TOPK $20\%N$ | -1.968 | 0 | -0.403 | 0 |
| TOPK $50\%N$ | -0.995 | 0 | -0.177 | 0 |
| RDP $1\%N$ (ours) | -0.066 | 0.141 | -0.050 | 0.074 |
| RDP $10\%N$ (ours) | -0.030 | 0.066 | -0.027 | 0.054 |
| RDP $20\%N$ (ours) | -0.013 | 0.046 | -0.003 | 0.026 |

**Generative Model.** Let $x = [x_1, ..., x_T]$ be a sequence of discrete latent states, and $y = [y_1, ..., y_T]$ a sequence of observed words (a word is an observed categorical variable). We consider an autoregressive generative model parameterized by an LSTM decoder: $p_\psi(x, y) = \prod_t p_\psi(x_t \mid x$