# Generative Marginalization Models

Sulin Liu, Peter J. Ramadge, Ryan P. Adams (Princeton University)

Proceedings of the 41st International Conference on Machine Learning (ICML 2024), Vienna, Austria. PMLR 235, 2024.

We introduce marginalization models (MAMs), a new family of generative models for high-dimensional discrete data. They offer scalable and flexible generative modeling by explicitly modeling all induced marginal distributions. Marginalization models enable fast approximation of arbitrary marginal probabilities with a single forward pass of the neural network, which overcomes a major limitation of arbitrary marginal inference models such as any-order autoregressive models. MAMs also address the scalability bottleneck encountered in training any-order generative models for high-dimensional problems in the context of energy-based training, where the goal is to match the learned distribution to a given desired probability (specified by an unnormalized log-probability function such as an energy or reward function). We propose scalable methods for learning the marginals, grounded in the concept of marginalization self-consistency. We demonstrate the effectiveness of the proposed model on a variety of discrete data distributions, including images, text, physical systems, and molecules, under both maximum likelihood and energy-based training. MAMs achieve orders-of-magnitude speedups in evaluating marginal probabilities in both settings. For energy-based training tasks, MAMs enable any-order generative modeling of high-dimensional problems beyond the scale of previous methods. Code is available at github.com/PrincetonLIPS/MaM.

## 1. Introduction

Deep generative models have enabled remarkable progress across diverse fields, including image generation, audio synthesis, natural language modeling, and scientific discovery. However, there remains a pressing need to better support efficient probabilistic inference for key questions involving marginal probabilities p(x_S) and conditional probabilities p(x_U | x_V), for appropriate subsets S, U, V of the variables. The ability to directly address such quantities is critical in applications such as outlier or machine-generated-content detection [59, 48], masked language modeling [15, 85], image inpainting [86], and constrained protein/molecule design [81, 65]. Furthermore, the capacity to conduct such inferences for arbitrary subsets of variables empowers users to leverage the model according to their specific needs and preferences. For instance, in protein design, scientists may want to manually guide the generation of a protein from a user-defined substructure along a particular path over the relevant variables. This requires the generative model to perform arbitrary marginal inferences.

[Figure 1 graphic: comparison of MaM (MLE/EB), ARM (MLE), ARM (EB), GFlowNet-TB (MLE/EB), and ARFlow (MLE/EB) on training cost and likelihood-inference time/memory per sample (in NN forward passes), arbitrary marginal inference, and any-order versus fixed/learned-order generation.]
Figure 1. Scalability of sequential discrete generative models. The y-axis unit is the number of NN forward passes required.

Toward this end, neural autoregressive models (ARMs) [3, 38] have shown strong performance in conditional/marginal inference, based on the idea of modeling a high-dimensional joint distribution as a factorization of univariate conditionals using the chain rule of probability.
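To make the cost of this factorization concrete, the sketch below contrasts chain-rule likelihood evaluation in an ARM, which needs one network call per variable, with the single-call marginal evaluation that MAMs provide. This is a minimal illustrative sketch rather than the paper's implementation: `conditional_net` and `marginal_net` are hypothetical networks, and the code assumes a single unbatched sequence of categorical variables.

```python
import torch

def arm_log_likelihood(x, conditional_net):
    """Chain-rule log-likelihood of an ARM:
        log p(x) = sum_{d=1}^{D} log p(x_d | x_{<d}).
    `conditional_net` (hypothetical) maps a prefix x[:d] to logits over the
    K possible values of the next variable, so one sequence of length D
    costs D network forward passes.
    """
    D = x.shape[0]
    log_p = torch.zeros(())
    for d in range(D):                               # O(D) network calls
        logits = conditional_net(x[:d])              # shape: (K,)
        log_p = log_p + torch.log_softmax(logits, dim=-1)[x[d]]
    return log_p

def mam_log_marginal(x_masked, marginal_net):
    """A marginalization model instead reads a partially observed input
    (unobserved entries masked) and outputs log p(x_S) directly, so any
    marginal or full joint likelihood costs a single forward pass."""
    return marginal_net(x_masked)                    # O(1) network calls
```

The loop is exactly the chain-rule factorization; the D-versus-1 gap in network calls is the scaling difference summarized in Figure 1.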
Many efforts have been made to scale up ARMs and enable any-order generative modeling in the maximum likelihood estimation (MLE) setting [38, 78, 24], and great progress has been made in applications such as masked language modeling [85] and image inpainting [24]. However, evaluating the marginal likelihood of a sequence of D variables still requires O(D) neural network forward passes with the most widely used modern architectures (e.g., Transformers [80] and U-Nets [62]). This scaling makes it difficult to evaluate likelihoods on the long sequences arising in data such as natural language and proteins.

[Figure 2 graphic: three SELFIES strings for the same ring system whose first token is [?], [Cl], or [F], with probability bars illustrating pθ([?]...) = pθ([Cl]...) + pθ([F]...); on the corresponding bit strings, p(?010??) = p(0010??) + p(1010??), where x1 is the variable marginalized out.]
Figure 2. Marginalization models (MAMs) enable estimation of any marginal probability with a neural network θ that learns to marginalize out variables. The figure illustrates marginalization of a single variable on bit strings (representing molecules) with two alternatives (versus K in general) for clarity. The bars represent probability masses.

In addition to MLE, the energy-based training (EB) setting has recently received growing interest because of its applications in the sciences [49, 12, 35]. Instead of empirical data samples, we only have access to an unnormalized (log-)probability function (specified by a reward or energy function) that can be evaluated pointwise and that the generative model must match. In such settings, ARMs are limited to fixed-order generative modeling and lack scalability in training: the subsampling techniques developed to scale the training of conditionals under MLE are no longer applicable when matching log-probabilities in energy-based training (see Section 4.3 for details).

To enhance scalability and flexibility in the generative modeling of discrete data, we propose a new family of generative models, marginalization models (MAMs), that directly model the marginal distribution p(x_S) for any subset of variables x_S of x. Direct access to marginals has two important advantages: 1) it significantly speeds up inference for any marginal, and 2) it enables scalable training of any-order generative models under both the MLE and EB settings. The unique structure of the model allows it to simultaneously represent the coupled collection of all marginal distributions of a given discrete joint probability mass function. For the model to be valid, it must be consistent with the sum rule of probability, a condition we refer to as marginalization self-consistency (see Figure 2 and the toy sketch below); learning to enforce this condition with scalable training objectives is one of the key contributions of this work.

We show that MAMs can be trained under both maximum likelihood and energy-based training with scalable learning objectives. We demonstrate the effectiveness of MAMs in both settings on a variety of discrete data distributions, including binary images, text, physical systems, and molecules. We empirically show that MAMs achieve orders-of-magnitude speedups in marginal likelihood evaluation. For energy-based training, MAMs scale the training of any-order generative models to high-dimensional problems that previous methods cannot reach.

## 2. Background

We first review two prevalent settings for training a generative model: maximum likelihood estimation and energy-based training. Then we introduce autoregressive models.
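As a concrete complement to Figure 2, the toy example below checks the marginalization self-consistency (sum-rule) constraint by brute force on a two-variable binary joint distribution. This is an illustrative sketch only: the probability values are made up, and a real MAM replaces the explicit summation with a learned, single-forward-pass estimate of each marginal that is trained to satisfy the same identity.

```python
import math

# Toy joint over two binary variables (probability values are illustrative).
p_joint = {(0, 0): 0.1, (0, 1): 0.2, (1, 0): 0.3, (1, 1): 0.4}

def marginal(observed):
    """Sum rule: p(x_S) is the sum of the joint over all settings of the
    unobserved variables. `observed` maps variable index -> observed value."""
    return sum(p for x, p in p_joint.items()
               if all(x[i] == v for i, v in observed.items()))

# Marginalization self-consistency for x1 (cf. Figure 2, where the first
# position of the string is marginalized out):
#     p(x2 = 0) = p(x1 = 0, x2 = 0) + p(x1 = 1, x2 = 0)
lhs = marginal({1: 0})
rhs = marginal({0: 0, 1: 0}) + marginal({0: 1, 1: 0})
assert math.isclose(lhs, rhs)
```

In a MAM, each side of this identity is estimated by a single network evaluation, and training penalizes the mismatch between the two for randomly chosen variables and orderings, which enforces the sum rule without ever summing over all completions.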
**Maximum likelihood (MLE)**  Given a dataset D = {x^{(i)}}_{i=1}^{N} drawn i.i.d. from a data distribution p = p_data, we aim to learn the distribution pθ(x) via maximum likelihood estimation:

$$\max_\theta \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log p_\theta(x)\big] \;\approx\; \frac{1}{N}\sum_{i=1}^{N} \log p_\theta\big(x^{(i)}\big), \tag{1}$$

which is equivalent to minimizing the Kullback-Leibler divergence under the empirical distribution, i.e., minimizing D_KL(p_data(x) || pθ(x)). This is the setting most commonly used for generating images (e.g., diffusion models [69, 22, 70]) and language (e.g., GPT [58]), where we can easily draw observed data from the distribution.

**Energy-based training (EB)**  In other cases, data from the distribution are not available. Instead, we have access to an unnormalized probability distribution f(·), typically specified as f(x) = exp(r(x)/τ), where r(x) is an energy (or reward) function and τ > 0 is a temperature parameter. In this setting, the objective is to match pθ(x) to f(x)/Z, where Z is the normalization constant of f. This can be done by minimizing the KL divergence [49, 84, 12]:

$$\min_\theta \; D_{\mathrm{KL}}\big(p_\theta(x) \,\|\, f(x)/Z\big). \tag{2}$$

The reward function r(x) can be defined either by human preferences or by the physical system from first principles. For example, (a) in aligning large language models, r(x) can represent human preferences [51, 50]; (b) in molecular/material design, it can specify the proximity of a sample's measured or calculated properties to some functional desiderata [2]; and (c) in modeling the thermodynamic equilibrium ensemble of physical systems, it is the (negative) energy function of a given sample [49, 84, 12, 35]. The training objective in Equation (2) can be optimized with a Monte Carlo estimate of its gradient via the REINFORCE algorithm [83]. A generative model pθ allows us to efficiently generate samples approximately from the target distribution, which would otherwise be much more expensive via MCMC with the energy function f(·).

**Autoregressive models**  Autoregressive models (ARMs) [3, 38] model a complex high-dimensional distribution p(x) by factorizing it into univariate conditionals using the chain rule:

$$\log p_\phi(x) = \sum_{d=1}^{D} \log p_\phi\big(x_d \mid x_{<d}\big).$$

Energy-based training of MAM (one iteration of the training loop):
- Sample x ∼ q(x) via persistent block Gibbs sampling.
- Sample d ∼ U(1, …, D) and an ordering σ ∼ U(S_D).
- L_penalty ← squared error of Equation (7), for d and σ with x.
- ∇_{θ,ϕ} D_KL ← REINFORCE estimate with x.
- ∇_{θ,ϕ} ← ∇_{θ,ϕ} D_KL + λ ∇_{θ,ϕ} L_penalty.
- Update θ and ϕ with the gradient.

## A.4. Connections between MAMs and GFlowNets

In this section, we identify an interesting connection between generative marginalization models and GFlowNets. The two types of models are designed with different motivations: GFlowNets are motivated by learning a policy that generates according to an energy function, whereas MAMs are motivated by any-order generation through learning to perform marginalization. Under certain conditions, however, the marginalization self-consistency condition derived from the definition of marginals in Equation (4) is equivalent to the detailed balance constraint in GFlowNets.

Observation 1. When the directed acyclic graph (DAG) used for generation in a GFlowNet is specified by the following conditions, there is an equivalence between the marginalization self-consistency condition in Equation (7) for MAM and the detailed balance constraint proposed for GFlowNets [4]. In particular, the pθ(xσ(d)|xσ(