# Simple and Effective Masked Diffusion Language Models

Subham Sekhar Sahoo (ssahoo@cs.cornell.edu), Marianne Arriola (ma2238@cornell.edu), Yair Schiff (yzs2@cornell.edu), Aaron Gokaslan (akg87@cs.cornell.edu), Edgar Marroquin (emm392@cornell.edu), Justin T Chiu (jtc257@cornell.edu), Alexander Rush (ar459@cornell.edu), Volodymyr Kuleshov (kuleshov@cornell.edu); Cornell Tech, NYC, USA.

38th Conference on Neural Information Processing Systems (NeurIPS 2024).

While diffusion models excel at generating high-quality images, prior work reports a significant performance gap between diffusion and autoregressive (AR) methods in language modeling. In this work, we show that simple masked discrete diffusion is more performant than previously thought. We apply an effective training recipe that improves the performance of masked diffusion models and derive a simplified, Rao-Blackwellized objective that results in additional improvements. Our objective has a simple form: it is a mixture of classical masked language modeling losses, and it can be used to train encoder-only language models that admit efficient samplers, including ones that can generate arbitrary lengths of text semi-autoregressively like a traditional language model. On language modeling benchmarks, a range of masked diffusion models trained with modern engineering practices achieves a new state-of-the-art among diffusion models and approaches AR perplexity. We provide the code (https://github.com/kuleshov-group/mdlm), along with a blog post and video tutorial (http://youtu.be/WjAUX23vgfg), on the project page: https://s-sahoo.com/mdlm

Figure 1: (Left) Our proposed masked diffusion language model (MDLM) is trained using a weighted average of masked cross-entropy losses. (Top Right) In comparison to masked language models (MLM), MDLM's objective corresponds to a principled variational lower bound and supports generation via ancestral sampling. (Bottom Right) Perplexity (PPL) on the One Billion Words (LM1B) benchmark.

## 1 Introduction

Diffusion models excel at producing realistic, high-quality images and have received significant attention as potential tools for generating discrete data, such as text [1, 31, 33], biological sequences [2, 47], and graphs [60, 63]. Unlike autoregressive (AR) approaches, diffusion-based methods are not constrained to generate data sequentially, and therefore have the potential to improve long-term planning, controllable generation, and sampling speed. However, discrete diffusion methods exhibit a performance gap relative to AR models [1, 23, 26, 33], especially in language modeling. The standard measure of language modeling performance is log-likelihood: when controlling for parameter count, prior work reports a sizable log-likelihood gap between AR and diffusion models.

In this work, we show that simple masked diffusion language modeling (MDLM) combined with effective training recipes is more performant than previously thought [1, 26]. We develop a well-engineered MDLM implementation that significantly improves discrete diffusion log-likelihood; we further improve likelihood using a simple substitution-based parameterization of the reverse diffusion process that enables deriving a Rao-Blackwellized continuous-time variational lower bound (ELBO) with improved tightness [49].
Interestingly, our objective has a simple form: it is a weighted average of masked language modeling (MLM) losses [15], and it can be used to endow BERT-style, encoder-only models with principled generation capabilities. We complement this framework with efficient samplers, including ones that can generate semi-autoregressively like a typical language model. Our masked diffusion models achieve a new state-of-the-art among diffusion models on language modeling benchmarks and approach the perplexity of AR models within 15-25%. Surprisingly, simple engineering choices significantly improve performance in both our models and simple baselines that were previously thought to perform poorly. Our framework also extends to non-language domains, including biological sequence modeling. We pre-train DNA sequence models and observe similar or higher downstream performance compared to classical BERT-style training, while also introducing generative capabilities that classical masked DNA language models lack.

**Contributions.** We describe (1) a simple masked diffusion language modeling (MDLM) framework with a well-engineered implementation that outperforms all existing diffusion models across language modeling benchmarks (LM1B [8], OWT [18], DNA [12]) and that significantly improves the performance of existing baselines [1, 26]. Our MDLM framework implements (2a) a substitution-based parameterization (SUBS) of the reverse unmasking diffusion process; SUBS allows us to derive (2b) a simple, continuous-time, Rao-Blackwellized objective that improves the tightness and variance of the ELBO, further increasing performance. We complement MDLM with (3) fast samplers that support semi-autoregressive (SAR) generation and outperform previous SAR models.

## 2 Background

### 2.1 Diffusion Models

Diffusion models are trained to iteratively undo a forward corruption process $q$ that takes clean data $x$ drawn from the data distribution $q(x)$ and defines latent variables $z_t$ for $t \in [0,1]$ that represent progressively noisy versions of $x$ [27, 54, 56, 66, 48, 19]. The standard forward process for continuous $x$ is

$$z_t = \alpha_t x + \sqrt{1 - \alpha_t^2}\,\epsilon, \qquad (1)$$

where $\epsilon \sim \mathcal{N}(0, I)$ and $(\alpha_t)_{t \in [0,1]}$ is a noise schedule, monotonically decreasing in $t$. The parameterized reverse diffusion model $p_\theta$ over $x$ and $z_t$ is trained to maximize a variational lower bound on the log-likelihood (ELBO). Given a number of discretization steps $T$, defining $s(i) = (i-1)/T$ and $t(i) = i/T$, and using $D_{\mathrm{KL}}[\cdot\,\|\,\cdot]$ to denote the Kullback-Leibler divergence, the negative ELBO (NELBO) equals [54]:

$$
\underbrace{-\,\mathbb{E}_q \log p_\theta(x \mid z_{t(0)})}_{\mathcal{L}_{\text{recons}}}
+ \underbrace{\sum_{i=1}^{T} \mathbb{E}_q\, D_{\mathrm{KL}}\!\big[q(z_{s(i)} \mid z_{t(i)}, x)\,\|\,p_\theta(z_{s(i)} \mid z_{t(i)})\big]}_{\mathcal{L}_{\text{diffusion}}}
+ \underbrace{D_{\mathrm{KL}}\!\big[q(z_{t(T)} \mid x)\,\|\,p_\theta(z_{t(T)})\big]}_{\mathcal{L}_{\text{prior}}}
\qquad (2)
$$

For brevity, we drop $i$ from $t(i)$ and $s(i)$ below; in general, $s$ will denote the time step before $t$.

### 2.2 Discrete Diffusion Models

Applications of diffusion modeling to discrete data can be broken into two broad categories. First are works that embed discrete structures in continuous space and then perform the Gaussian diffusion defined above on these continuous representations [9, 16, 23, 24, 30, 34, 57]. More related to our method are works that define a diffusion process directly on discrete structures. D3PM [1] introduces a framework with a Markov forward process $q(z_t \mid z_{t-1}) = \mathrm{Cat}(z_t; Q_t z_{t-1})$ defined by the multiplication of matrices $Q_t$ over $T$ discrete time steps. This process induces the marginals

$$q(z_t \mid x) = \mathrm{Cat}(z_t; \bar{Q}_t x) = \mathrm{Cat}(z_t; Q_t Q_{t-1} \cdots Q_1 x), \qquad (3)$$

which represent the discrete-state form of (1). Extending this formalism to continuous time (as in (1)) relies on continuous-time Markov chain (CTMC) theory [5].
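To make the transition-matrix view concrete, the short sketch below builds absorbing-state matrices $Q_t$ and checks that their cumulative product reproduces the closed-form marginal $\alpha_t x + (1-\alpha_t)m$ that Sec. 3 works with directly. This is an illustration only (toy vocabulary size, an arbitrary $\beta_t$ schedule, made-up variable names), not code from the paper's repository:

```python
import torch

K = 6          # toy vocabulary size; index K - 1 plays the role of [MASK] (illustrative choice)
T = 8          # number of discrete diffusion steps
mask = torch.zeros(K)
mask[K - 1] = 1.0

# Per-step masking probabilities beta_t (an arbitrary schedule, for illustration only).
betas = torch.linspace(0.1, 0.5, T)

# Column-stochastic absorbing-state transitions: Q_t = (1 - beta_t) I + beta_t m 1^T,
# i.e. q(z_t | z_{t-1}) = Cat(z_t; Q_t z_{t-1}) keeps a token w.p. 1 - beta_t and masks it w.p. beta_t.
Qs = [(1 - b) * torch.eye(K) + b * mask[:, None] * torch.ones(1, K) for b in betas]

x = torch.zeros(K)
x[2] = 1.0                                 # a one-hot clean token
Q_bar = torch.eye(K)
for t, Q in enumerate(Qs, start=1):
    Q_bar = Q @ Q_bar                      # \bar{Q}_t = Q_t Q_{t-1} ... Q_1, as in (3)
    marginal = Q_bar @ x                   # q(z_t | x) = Cat(z_t; \bar{Q}_t x)
    alpha_t = torch.prod(1 - betas[:t])    # absorbing case collapses to alpha_t x + (1 - alpha_t) m
    assert torch.allclose(marginal, alpha_t * x + (1 - alpha_t) * mask, atol=1e-6)
```

Materializing full $K \times K$ matrices in this way is exactly what the masking-specific derivation in Sec. 3 lets us avoid.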
The CTMC framework in turn leads to generalizations of the score-matching perspective on diffusion modeling [55] to discrete data [33, 59]. Notably, SEDD [33] connects score-based approaches with ELBO maximization, enabling performant likelihood-based training of score-based models.

## 3 Simple Masked Diffusion Models

While previous work on discrete diffusion supports general forward processes (e.g., general $Q_t$ in D3PM), absorbing-state (i.e., masking) diffusion consistently achieves the best performance [1, 33]. In this work, instead of supporting general noise processes, we focus on masking and derive tight Rao-Blackwellized objectives that outperform general approaches and do not require CTMC theory. In this section, we first define the diffusion process for a categorical random variable. Later, in Sec. 3.5, we extend this process to sequences containing multiple such categorical variables. We denote our overall approach as Masked Diffusion Language Models (MDLM).

**Notation.** We denote scalar discrete random variables with $K$ categories as one-hot column vectors and define $\mathcal{V} = \{x \in \{0,1\}^K : \sum_{i=1}^K x_i = 1\}$ as the set of all such vectors. Define $\mathrm{Cat}(\cdot;\pi)$ as the categorical distribution over $K$ classes with probabilities given by $\pi \in \Delta^K$, where $\Delta^K$ denotes the $K$-simplex. We also assume that the $K$-th category corresponds to a special [MASK] token and let $m \in \mathcal{V}$ be the one-hot vector for this mask, i.e., $m_K = 1$. Additionally, let $\mathbf{1} = \{1\}^K$, and let $\langle a, b\rangle$ and $a \odot b$ respectively denote the dot and Hadamard products between two vectors $a$ and $b$.

### 3.1 Interpolating Discrete Diffusion

We restrict our attention to forward processes $q$ that interpolate between clean data $x \in \mathcal{V}$ and a target distribution $\mathrm{Cat}(\cdot;\pi)$, forming a direct extension of Gaussian diffusion in (1). Let $q$ define a sequence of increasingly noisy latent variables $z_t \in \mathcal{V}$, where the time step $t$ runs from $t=0$ (least noisy) to $t=1$ (most noisy). The marginal of $z_t$ conditioned on $x$ at time $t$ is

$$q(z_t \mid x) = \mathrm{Cat}(z_t;\, \alpha_t x + (1-\alpha_t)\pi), \qquad (4)$$

where $\alpha_t \in [0,1]$ is a strictly decreasing function in $t$, with $\alpha_0 \approx 1$ and $\alpha_1 \approx 0$; see Suppl. E.1 for details. This implies transition probabilities $q(z_t \mid z_s) = \mathrm{Cat}(z_t;\, \alpha_{t|s} z_s + (1-\alpha_{t|s})\pi)$, where $\alpha_{t|s} = \alpha_t/\alpha_s$, indicating that during each diffusion step from $s$ to $t$, a fraction $(1-\alpha_{t|s})$ of the probability mass is transferred to the prior distribution $\pi$. The reverse posterior is given as (see Suppl., Eq. (16), for details):

$$
q(z_s \mid z_t, x) = \mathrm{Cat}\!\left(z_s;\; \frac{\big[\alpha_{t|s} z_t + (1-\alpha_{t|s})\,\mathbf{1}\,\langle \pi, z_t\rangle\big] \odot \big[\alpha_s x + (1-\alpha_s)\pi\big]}{\alpha_t\, z_t^\top x + (1-\alpha_t)\, z_t^\top \pi}\right)
\qquad (5)
$$

While (4) and (5) represent a special case of the more general diffusion processes proposed in D3PM [1], we show below that they yield a simplified variational lower bound objective and admit straightforward continuous-time extensions.

### 3.2 Masked Diffusion

Next, we focus on masking processes and derive a simple Rao-Blackwellized objective for this choice of $q$. This objective incurs lower variance during training and improves tightness.

#### 3.2.1 Forward Masking Process

In masked (i.e., absorbing-state) diffusion, we set $\pi = m$. At each noising step $t$, the input $x$ transitions to the masked state $m$ with some probability. If an input transitions to $m$ at time $t$, it remains in this state for all $t' > t$: $q(z_{t'} \mid z_t = m) = \mathrm{Cat}(z_{t'}; m)$. At $t = 1$, all inputs are masked with probability 1. The marginal of the forward process (4) is given by $q(z_t \mid x) = \mathrm{Cat}(z_t;\, \alpha_t x + (1-\alpha_t)m)$.
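In code, this forward process is simply independent random masking. The following sketch is a toy illustration (a hypothetical token-id layout with the last vocabulary index as [MASK]; not the reference implementation) of sampling $z_t \sim \mathrm{Cat}(z_t;\, \alpha_t x + (1-\alpha_t)m)$ for a batch of sequences:

```python
import torch

def mask_forward(x_ids: torch.Tensor, alpha_t: float, mask_id: int) -> torch.Tensor:
    """Sample z_t ~ q(z_t | x) = Cat(z_t; alpha_t x + (1 - alpha_t) m).

    For the absorbing (masking) process this is just independent masking:
    each token is replaced by [MASK] with probability 1 - alpha_t.
    """
    keep = torch.rand(x_ids.shape) < alpha_t
    return torch.where(keep, x_ids, torch.full_like(x_ids, mask_id))

# Toy usage with a hypothetical vocabulary whose last index is [MASK].
vocab_size, mask_id = 16, 15
x = torch.randint(0, vocab_size - 1, (2, 10))   # a batch of clean token-id sequences
t = torch.rand(()).item()                       # t ~ U[0, 1]
alpha_t = 1.0 - t                               # one valid schedule; the loss is schedule-invariant (Sec. 3.4)
z_t = mask_forward(x, alpha_t, mask_id)
```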
Using the properties of the masking process, the posterior $q(z_s \mid z_t, x)$ in (5) simplifies (see Suppl. A.2) to

$$
q(z_s \mid z_t, x) =
\begin{cases}
\mathrm{Cat}(z_s; z_t), & z_t \neq m,\\[4pt]
\mathrm{Cat}\!\left(z_s;\; \dfrac{(1-\alpha_s)m + (\alpha_s - \alpha_t)x}{1-\alpha_t}\right), & z_t = m.
\end{cases}
\qquad (6)
$$

#### 3.2.2 Reverse Unmasking Process

The reverse process inverts the noise process defined by $q$. We consider both a finite number of steps $T$ and a continuous-time model corresponding to $T \to \infty$. We begin with the discrete-time case, for which the generative model is expressed as

$$p_\theta(x) = \int_z p_\theta(z_1)\, p_\theta(x \mid z_0) \prod_{i=1}^{T} p_\theta(z_{s(i)} \mid z_{t(i)})\, dz_{0:1}.$$

The optimal form for $p_\theta(z_s \mid z_t)$ matches the true posterior in (6): this follows immediately from the definition of the diffusion objective in (2), which is a sum of terms of the form $D_{\mathrm{KL}}(q(z_s \mid z_t, x)\,\|\,p_\theta(z_s \mid z_t))$. However, (6) is conditioned on $x$, which we do not know. Therefore, we introduce a model $x_\theta(z_t, t): \mathcal{V} \times [0,1] \to \Delta^K$ that approximates $x$ with a neural network. We can also omit the explicit dependence of $x_\theta$ on time $t$, which simplifies sampling, yielding a 2x inference speed-up (see Suppl. E.2).

#### 3.2.3 SUBS Parameterization

The specific parameterization for $p_\theta(z_s \mid z_t)$ that we use is

$$
p_\theta(z_s \mid z_t) = q\big(z_s \mid z_t, x = x_\theta(z_t, t)\big) =
\begin{cases}
\mathrm{Cat}(z_s; z_t), & z_t \neq m,\\[4pt]
\mathrm{Cat}\!\left(z_s;\; \dfrac{(1-\alpha_s)m + (\alpha_s - \alpha_t)\,x_\theta(z_t, t)}{1-\alpha_t}\right), & z_t = m.
\end{cases}
\qquad (7)
$$

Furthermore, we induce two key properties of the absorbing-state diffusion process into our denoising model $x_\theta(z_t, t)$: an unmasked token remains unchanged during reverse diffusion, and the clean input is never masked. We implement these as substitutions applied to the output of $x_\theta(z_t, t)$; hence we call our parameterization SUBS.

**Zero Masking Probabilities.** First, notice that by definition $\langle x, m\rangle = 0$. For this reason, we design the denoising network such that $\langle x_\theta(z_t, t), m\rangle = 0$, i.e., we substitute the logit corresponding to the [MASK] token with $-\infty$.

**Carry-Over Unmasking.** Second, if $z_t$ is unmasked, then we desire $x_\theta(z_t, t) = z_t$, i.e., unmasked latents are "carried over". We accomplish this by substituting the output of our network so that it simply copies unmasked inputs. In Suppl. B.1, we show that the Zero Masking Probabilities property simplifies D3PM's NELBO (39) to (41), and Carry-Over Unmasking further simplifies (41) to (43), whose continuous-time equivalent is the simplified NELBO (10). Table 8 shows that each simplification leads to an improved likelihood.

### 3.3 Rao-Blackwellized Likelihood Bounds

Recall from (2) that the diffusion training objective has the form $\mathcal{L}_{\text{recons}} + \mathcal{L}_{\text{diffusion}} + \mathcal{L}_{\text{prior}}$. For the simplified reverse process in (7), the discrete-time diffusion loss for finite $T$ simplifies to (Suppl. B.1.3):

$$
\mathcal{L}_{\text{diffusion}} = \sum_{i=1}^{T} \mathbb{E}_q\Big[D_{\mathrm{KL}}\big(q(z_{s(i)} \mid z_{t(i)}, x)\,\|\,p_\theta(z_{s(i)} \mid z_{t(i)})\big)\Big]
= \sum_{i=1}^{T} \mathbb{E}_q\!\left[\frac{\alpha_{t(i)} - \alpha_{s(i)}}{1 - \alpha_{t(i)}}\, \log\big\langle x_\theta(z_{t(i)}, t(i)), x\big\rangle\right]
\qquad (8)
$$

Note that this objective is simpler and more well-behaved than the expression one would obtain for $D_{\mathrm{KL}}(q(z_s \mid z_t, x)\,\|\,p_\theta(z_s \mid z_t))$ under the parameterization induced by using $p_\theta(z_s \mid z_t) = q(z_s \mid z_t, x = x_\theta(z_t, t))$ from (5), which is similar to what is used by D3PM [1] (see Suppl. A.2.4):

$$
\frac{\alpha_s - \alpha_t}{1-\alpha_t}\,\log\frac{\alpha_t\langle x_\theta(z_t, t), m\rangle + (1-\alpha_t)}{(1-\alpha_t)\,\langle x_\theta(z_t, t), x\rangle}
\;+\;
\frac{1-\alpha_s}{1-\alpha_t}\,\log\frac{(1-\alpha_s)\big(\alpha_t\langle x_\theta(z_t, t), m\rangle + (1-\alpha_t)\big)}{(1-\alpha_t)\big(\alpha_s\langle x_\theta(z_t, t), m\rangle + (1-\alpha_s)\big)}
\qquad (9)
$$

We refer to the process of obtaining (8) in lieu of (9) as a form of Rao-Blackwellization. Specifically, we analytically compute expectations such as $\langle x_\theta(z_t, t), m\rangle = 0$ in order to simplify objective (9) and obtain (8). Without these analytical simplifications, a model must learn parameters $\theta$ such that $\langle x_\theta(z_t, t), m\rangle = 0$ holds. Unlike in regular Rao-Blackwellization, the simplifications are possible because of modeling choices for $x_\theta(z_t, t)$ (zero masking probabilities and carry-over unmasking). In that sense, our approach has similarities to graphical modeling, where incorporating conditional independencies into $p_\theta$ sets certain log-likelihood terms to zero. However, our approach also empirically helps reduce variance, hence we refer to it as Rao-Blackwellization, somewhat abusing the usual terminology.
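Both substitutions are simple post-processing steps on the denoiser's output. The sketch below (an illustration under assumed tensor shapes, with a hypothetical helper name `subs_probs`; not the authors' code) sets the [MASK] logit to $-\infty$ and copies unmasked tokens through unchanged:

```python
import torch
import torch.nn.functional as F

def subs_probs(logits: torch.Tensor, z_t: torch.Tensor, mask_id: int) -> torch.Tensor:
    """Apply the two SUBS substitutions to raw denoiser outputs (illustrative sketch).

    logits: (B, L, K) outputs of x_theta;  z_t: (B, L) token ids of the noised sequence.
    Returns per-token probabilities over the vocabulary.
    """
    K = logits.shape[-1]
    # Zero masking probabilities: force <x_theta, m> = 0 by setting the [MASK] logit to -inf.
    logits = logits.masked_fill(torch.arange(K, device=logits.device) == mask_id, float("-inf"))
    probs = logits.softmax(dim=-1)
    # Carry-over unmasking: wherever z_t is already unmasked, output a one-hot copy of z_t.
    carry = F.one_hot(z_t, num_classes=K).to(probs.dtype)
    return torch.where((z_t != mask_id)[..., None], carry, probs)
```

After these substitutions, unmasked positions place probability one on their current token, which is what makes the loss terms at those positions vanish in the objectives below.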
### 3.4 Continuous-Time Likelihood Bounds

Previous works have shown empirically and mathematically that increasing the number of steps $T$ yields a tighter approximation to the ELBO [29]. Following a similar argument, we form a continuous-time extension of (8) by taking $T \to \infty$ (see Suppl. B.2), which yields the following NELBO, $\mathcal{L}_{\text{NELBO}}^{\infty}$:

$$\mathcal{L}_{\text{NELBO}}^{\infty} = \mathbb{E}_q \int_{t=0}^{t=1} \frac{\alpha_t'}{1-\alpha_t}\, \log\big\langle x_\theta(z_t, t), x\big\rangle\, dt \qquad (10)$$

**Invariance to the noise schedule.** The function $\alpha_t$ is invertible due to the monotonicity assumption in Sec. 3.1, and so we can perform the change of variables $\gamma = \log(1-\alpha_t)$ in (10). The diffusion loss can then be equivalently expressed as $\mathcal{L}_{\text{NELBO}}^{\infty} = \mathbb{E}_q \int_{\gamma=0}^{\gamma=-\infty} \log\langle x_\theta(z_\gamma, \gamma), x\rangle\, d\gamma$; see Suppl. E.1.1 for details. This formulation demonstrates that the diffusion loss is invariant to the functional form of $\alpha_t$, which we verify empirically in Suppl. E.1.

### 3.5 Masked Diffusion Language Models

Next, we apply masked diffusion to language modeling over sequences $x^{1:L}$ of $L$ tokens, with $x^\ell$ denoting the $\ell$-th token. We assume that the forward noising process is applied independently across the sequence and that, conditioned on a sequence of latents $z_t^{1:L}$, the denoising process factorizes independently across tokens, i.e., $p_\theta(z_s^{1:L} \mid z_t^{1:L}) = \prod_{\ell=1}^{L} p_\theta(z_s^\ell \mid z_t^{1:L})$. Thus, we use a single model to compute $x_\theta^\ell(z_t^{1:L}, t)$ for each $\ell$ from a masked sequence $z_t$, optimizing:

$$\mathcal{L}_{\text{NELBO}}^{\infty} = \mathbb{E}_q \int_{t=0}^{t=1} \frac{\alpha_t'}{1-\alpha_t} \sum_{\ell=1}^{L} \log\big\langle x_\theta^\ell(z_t^{1:L}, t), x^\ell\big\rangle\, dt \qquad (11)$$

Interestingly, our objective has a simple form: it is the weighted average of masked language modeling (MLM) losses [15]. Our work thus establishes a connection between generative diffusion models and encoder-only BERT models. The objective enables principled selection of a (randomized) masking rate and also endows BERT-style models with principled generation capabilities; see Sec. 6. The full training algorithm is provided in Suppl. B.3. Note: although (11) imposes a loss on all tokens, unmasked tokens do not contribute to the loss, as they are copied over by the denoising network due to carry-over unmasking (Sec. 3.2.3), effectively reducing $\log\langle x_\theta^\ell(z_t^{1:L}, t), x^\ell\rangle$ to zero.

#### 3.5.1 Training Considerations for Masked Diffusion

One of the key contributions of our work is a well-engineered implementation of masked diffusion models. Our experiments demonstrate that these improvements greatly boost performance, even for methods previously thought to perform poorly, e.g., Austin et al. [1]. Below, we briefly summarize these implementation details.

First, we find that tokenization is critical to performance. Small vocabularies, such as the 8k vocabulary in Austin et al. [1], result in longer-range dependencies that decrease the performance of both diffusion and AR models. Additionally, by focusing on masked diffusion, we are able to provide a numerically stable implementation of the objective function. Previous formulations of discrete diffusion were constructed to accommodate a wide range of limiting distributions [1], so their objectives were implemented by materializing the full transition matrices $Q_t$ and posterior probabilities. In contrast, we evaluate $D_{\mathrm{KL}}[q(z_s \mid z_t, x)\,\|\,p_\theta(z_s \mid z_t)]$ by examining only the masked token indices rather than comparing the full true and approximate posterior distributions.
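Concretely, a Monte-Carlo estimate of (11) is a reweighted cross-entropy computed only at masked positions. The sketch below is a hypothetical, self-contained illustration, not the released training code: it assumes a linear schedule $\alpha_t = 1 - t$, a generic `model` that maps token ids to logits, and a small `eps` to keep the $1/t$ weight finite; it also uses a stratified draw of time steps in the spirit of the low-discrepancy sampler discussed below:

```python
import torch
import torch.nn.functional as F

def mdlm_nelbo(model, x, mask_id, eps=1e-3):
    """Monte-Carlo estimate of the continuous-time NELBO (11) for a batch of token ids x: (B, L).

    Assumes a linear schedule alpha_t = 1 - t, so -alpha'_t / (1 - alpha_t) = 1 / t.
    `model` is any network mapping (B, L) token ids to (B, L, K) logits. Illustrative only.
    """
    B, L = x.shape
    # Stratified ("low-discrepancy") time sampling: one time step per stratum of [0, 1].
    t = (torch.rand(1) + torch.arange(B) / B) % 1.0
    t = eps + (1.0 - eps) * t                        # keep t away from 0, where 1/t blows up
    alpha_t = 1.0 - t
    # Forward masking: each token is masked independently with probability 1 - alpha_t.
    masked = torch.rand(B, L) < (1.0 - alpha_t)[:, None]
    z_t = torch.where(masked, torch.full_like(x, mask_id), x)
    logits = model(z_t)
    # Zero masking probabilities (Sec. 3.2.3): the network never predicts [MASK].
    logits = logits.masked_fill(
        torch.arange(logits.shape[-1], device=logits.device) == mask_id, float("-inf"))
    # Cross entropy -log <x_theta, x> at every position; only masked positions contribute.
    ce = F.cross_entropy(logits.transpose(1, 2), x, reduction="none")     # (B, L)
    weight = 1.0 / t                                                      # -alpha'_t / (1 - alpha_t)
    return (weight[:, None] * ce * masked).sum(dim=-1).mean()

# Example with a toy denoiser (embedding + linear head); K - 1 is the [MASK] id here.
K = 32
model = torch.nn.Sequential(torch.nn.Embedding(K, 64), torch.nn.Linear(64, K))
loss = mdlm_nelbo(model, torch.randint(0, K - 1, (8, 16)), mask_id=K - 1)
```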
Furthermore, we modernize the architecture of the denoising network relative to D3PM [1]. In lieu of the T5 architecture used in D3PM, we use the diffusion transformer (DiT) introduced in Peebles & Xie [42], which integrates time-step conditioning into a standard encoder-only transformer [62] and uses rotary positional embeddings [58]. In addition, we implement a low-discrepancy sampler, similar to Kingma et al. [29], that reduces the variance of the ELBO by drawing correlated samples $t_i$ rather than sampling them i.i.d.

## 4 Inference and Sampling in Masked Diffusion Language Models

### 4.1 Efficient Ancestral Sampling

To generate a sequence of length $L$, the reverse diffusion process starts with the sequence $z_{t=1}^{1:L}$, where $z_{t=1}^\ell = m$ for all $\ell \in \{1, \ldots, L\}$. The subsequent latents $z_t^{1:L}$ are then generated by discretizing the reverse diffusion process with some finite $T$. Given $z_t^{1:L}$, we construct $z_s^{1:L}$ by sampling each token $z_s^\ell$ independently from the distribution $p_\theta(z_s^\ell \mid z_t^{1:L})$ given in (7). Note that in the reverse process, unmasked tokens remain unchanged. Thus, if no new tokens in $z_s^{1:L}$ become unmasked (which can occur often in early denoising stages for large $T$), then $z_s^{1:L} = z_t^{1:L}$. Additionally, if the denoising model $x_\theta(z_t^{1:L})$ is not conditioned on time, then we can simply draw a new sample from $p_\theta(z_{s-1/T}^{1:L} \mid z_s^{1:L})$ using the previously computed and cached value $x_\theta(z_t^{1:L})$. This means we have effectively skipped over the time step $s$, saving a function call to the denoising network. Note that SEDD [33] does not support this caching because its denoising network models time-dependent rates, which requires conditioning on time.

### 4.2 Semi-Autoregressive Masked Diffusion Language Models

Our method also admits an effective semi-autoregressive (SAR) decoding method that allows the model to generate sequences of arbitrary length [24, 52, 53]. Let $x^{1:L}$ represent the output from sampling a sequence of $L$ tokens using the reverse diffusion process described above. To generate additional L
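For completeness, here is a minimal, hypothetical sketch of the ancestral sampler of Sec. 4.1 with the caching optimization, assuming a time-independent denoiser and a linear schedule $\alpha_t = 1 - t$; the two-stage draw (first decide whether to reveal each masked position, then sample its value from $x_\theta$) is equivalent to sampling from (7) because $x_\theta$ places zero mass on [MASK]. This is an illustration, not the authors' implementation:

```python
import torch

@torch.no_grad()
def ancestral_sample(model, length, num_steps, mask_id):
    """Ancestral sampling (Sec. 4.1) with caching, for a time-independent denoiser. Illustrative only.

    Uses a linear schedule alpha_t = 1 - t. At each step a masked position is revealed with
    probability (alpha_s - alpha_t) / (1 - alpha_t); revealed values are drawn from x_theta.
    """
    z = torch.full((1, length), mask_id, dtype=torch.long)
    cached_probs = None
    for i in range(num_steps, 0, -1):
        t, s = i / num_steps, (i - 1) / num_steps
        alpha_t, alpha_s = 1.0 - t, 1.0 - s
        if cached_probs is None:                       # only call the network if z has changed
            logits = model(z)
            logits = logits.masked_fill(
                torch.arange(logits.shape[-1]) == mask_id, float("-inf"))  # zero masking probs
            cached_probs = logits.softmax(dim=-1)
        p_reveal = (alpha_s - alpha_t) / (1.0 - alpha_t)
        reveal = (z == mask_id) & (torch.rand(z.shape) < p_reveal)
        if reveal.any():
            sampled = torch.distributions.Categorical(probs=cached_probs).sample()
            z = torch.where(reveal, sampled, z)
            cached_probs = None                        # z changed: the cache is stale
    return z
```

With the toy model from the previous sketch, `ancestral_sample(model, length=16, num_steps=128, mask_id=K - 1)` returns a `(1, 16)` tensor of sampled token ids; steps in which no position is revealed reuse the cached output of $x_\theta$, which is exactly the function-call saving described in Sec. 4.1.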