# Diffusion bridges vector quantized variational autoencoders

Max Cohen (1, 2), Guillaume Quispe (3), Sylvain Le Corff (1), Charles Ollion (3), Eric Moulines (3)

(1) Samovar, Télécom SudParis, Département CITI, Institut Polytechnique de Paris, Palaiseau, France. (2) Oze Energies, Charenton-le-Pont, France. (3) Centre de Mathématiques Appliquées, École polytechnique, Institut Polytechnique de Paris, Palaiseau, France. Correspondence to: Max Cohen.

Proceedings of the 39th International Conference on Machine Learning, Baltimore, Maryland, USA, PMLR 162, 2022. Copyright 2022 by the author(s).

Vector Quantized-Variational AutoEncoders (VQ-VAE) are generative models based on discrete latent representations of the data, where inputs are mapped to a finite set of learned embeddings. To generate new samples, an autoregressive prior distribution over the discrete states must be trained separately. This prior is generally very complex and leads to slow generation. In this work, we propose a new model to train the prior and the encoder/decoder networks simultaneously. We build a diffusion bridge between a continuous coded vector and a non-informative prior distribution. The latent discrete states are then given as random functions of these continuous vectors. We show that our model is competitive with the autoregressive prior on the mini-ImageNet and CIFAR datasets and is efficient in both optimization and sampling. Our framework also extends the standard VQ-VAE and enables end-to-end training.

## 1. Introduction

Variational AutoEncoders (VAE) have emerged as important generative models based on latent representations of the data. While the latent states are usually continuous vectors, Vector Quantized Variational AutoEncoders (VQ-VAE) have demonstrated the usefulness of discrete latent spaces and have been successfully applied in image and speech generation (Oord et al., 2017; Esser et al., 2021; Ramesh et al., 2021).

In a VQ-VAE, the distribution of the inputs is assumed to depend on a hidden discrete state. Large-scale image generation VQ-VAEs, for instance, use multiple discrete latent states, typically organized as 2-dimensional lattices. In the original VQ-VAE, the authors propose a variational approach to approximate the posterior distribution of the discrete states given the observations. The variational distribution takes as input the observation, which is passed through an encoder. The discrete latent variable is then computed by a nearest-neighbour procedure that maps the encoded vector to the nearest discrete embedding.

It has been argued that the success of VQ-VAEs lies in the fact that they do not suffer from the usual posterior collapse of VAEs (Oord et al., 2017). However, the implementation of VQ-VAE involves many practical tricks and still suffers from several limitations. First, the quantization step leads the authors to propose a rough approximation of the gradient of the loss function by copying gradients from the decoder input to the encoder output. Second, the prior distribution of the discrete variables is initially assumed to be uniform when training the VQ-VAE. In a second training step, high-dimensional autoregressive models such as PixelCNN (van den Oord et al., 2016; Salimans et al., 2017; Chen et al., 2018) and WaveNet (Oord et al., 2016) are estimated to obtain a complex prior distribution. Joint training of the prior and the VQ-VAE is a challenging task for which no satisfactory solution exists yet. Our work addresses both problems by introducing a new mathematical framework that extends and generalizes the standard VQ-VAE. Our method enables end-to-end training and, in particular, bypasses the separate training of an autoregressive prior.

An autoregressive PixelCNN prior model has several drawbacks, which are the same in the pixel space or in the latent space.
The data is assumed to have a fixed sequential order, which forces the generation to start at a certain point, typically in the upper left corner, and to span the image or the 2-dimensional latent lattice in an arbitrary way. At each step, a new latent variable is sampled using the previously sampled pixels or latent variables. Inference may then accumulate prediction errors, while training provides the ground truth at each step. The runtime, which depends mainly on the number of network evaluations, is sequential and grows with the size of the image or of the 2-dimensional latent lattice, which can become very large for high-dimensional objects.

The influence of the prior is further explored in (Razavi et al., 2019), where VQ-VAE is used to sample images on a larger scale using two layers of discrete latent variables, and (Willetts et al., 2021) use hierarchical discrete VAEs with numerous layers of latent variables. Other works such as (Esser et al., 2021; Ramesh et al., 2021) have used Transformers to autoregressively model a sequence of latent variables: while these works benefit from the recent advances of Transformers for large language models, their autoregressive process still suffers from the same drawbacks as PixelCNN-like priors.

The main claim of our paper is that using diffusions in a continuous space, $\mathbb{R}^{d \times N}$ in our setting, is a very efficient way to learn complex discrete distributions with support on a large space (here with cardinality $K^N$). We only require an embedding space, an uninformative target distribution (here a Gaussian law), and a continuous bridge process to learn the discrete target distribution. In that direction, our contribution is inspired by the literature but also significantly different.
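For a sense of scale, the cardinality $K^N$ can be computed exactly for the latent-map sizes reported in Section 4.2 (mini-ImageNet: $K = 128$ on a $21 \times 21$ map; CIFAR10: $K = 256$ on an $8 \times 8$ map), taking $N$ as the number of latent positions. The short Python sketch below only performs this arithmetic; it is an illustration, not part of the method.

```python
# Exact size of the discrete latent space K**N for the latent-map sizes
# reported in Section 4.2 (N = number of latent positions).
SETTINGS = {
    "mini-ImageNet": (128, 21 * 21),  # K = 128 codebook vectors, 21 x 21 map
    "CIFAR10": (256, 8 * 8),          # K = 256 codebook vectors, 8 x 8 map
}


def support_digits(K: int, N: int) -> int:
    """Number of decimal digits of K**N, using Python's exact big integers."""
    return len(str(K ** N))


if __name__ == "__main__":
    for name, (K, N) in SETTINGS.items():
        print(f"{name}: {K}**{N} has {support_digits(K, N)} decimal digits")
```

For mini-ImageNet, $128^{441} = 2^{3087}$ is a 930-digit number, so the prior has to be modelled implicitly rather than as an explicit table over joint configurations.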
Our procedure departs from the diffusion probabilistic model approach of (Ho et al., 2020), which highlights the role of bridge processes in denoising continuous target laws, and from (Hoogeboom et al., 2021), where multinomial diffusions are used to noise and denoise; the latter cannot exploit the expressiveness of continuous bridges and, as remarked by its authors, does not scale well with $K$. Although we target a discrete distribution, our approach does not suffer from this limitation.

Our contributions are summarized as follows.

- We propose a new mathematical framework for VQ-VAEs. We introduce a two-stage prior distribution. Following the diffusion probabilistic model approach of (Ho et al., 2020), we first consider a continuous latent vector parameterized as a Markov chain. The discrete latent states are defined as random functions of this Markov chain. The transition kernels of the continuous latent variables are trained using diffusion bridges to gradually produce samples that match the data. To the best of our knowledge, this is the first probabilistic generative model to use denoising diffusion in a discrete latent space.
- This framework allows for end-to-end training of VQ-VAE. We focus on VQ-VAE as our framework enables simultaneous training of all components of these popular discrete models, which is not straightforward. However, our methodology is more general and allows the use of continuous embeddings and diffusion bridges to sample from any discrete law.
- We present our method on a toy dataset and then compare its efficiency to the PixelCNN prior of the original VQ-VAE on the mini-ImageNet dataset.

Figure 1 describes the complete architecture of our model.

Figure 1. Our proposed architecture, for a prior based on an Ornstein-Uhlenbeck bridge. The top pathway from input image to $z_e^0$, to $z_q^0$, to reconstructed image resembles the original VQ-VAE model.
The vertical pathway from $(z_e^0, z_q^0)$ to $(z_e^T, z_q^T)$ and backwards is based on a denoising diffusion process. See Section 3.2 and Algorithm 2 for the corresponding sampling procedure.

## 2. Related Works

Diffusion Probabilistic Models. A promising class of models that depart from autoregressive models are Diffusion Probabilistic Models (Sohl-Dickstein et al., 2015; Ho et al., 2020) and the closely related Score-Matching Generative Models (Song & Ermon, 2019; De Bortoli et al., 2021). The general idea is to apply a corrupting Markovian process to the data through $T$ corrupting steps and to learn a neural network that gradually denoises, i.e. reconstructs, the original samples from the noisy data. For example, when sampling images, an initial sample is drawn from an uninformative distribution and reconstructed iteratively using the trained Markov kernel. This process is applied to all pixels simultaneously, so no fixed order is required, and the sampling time depends on the number of steps $T$ rather than on a sequence of predictions whose length grows with the number of pixels. While this number of steps can be large ($T = 1000$ is typical), simple improvements make it possible to reduce it dramatically and obtain 50× speedups (Song et al., 2021). These properties have led diffusion probabilistic models to receive much attention in the context of continuous input modelling.

From Continuous to Discrete Generative denoising. In (Hoogeboom et al., 2021), the authors propose multinomial diffusion to gradually add categorical noise to discrete samples, for which the generative denoising process is learned. Unlike alternatives such as normalizing flows, the diffusion proposed by the authors for discrete variables does not require gradient approximations because the parameter of the diffusion is fixed. Such diffusion models are optimized using variational inference to learn the denoising process, i.e., the bridge that aims at inverting the multinomial diffusion.
In (Hoogeboom et al., 2021), the authors propose a variational distribution based on bridge sampling. In (Austin et al., 2021), the authors improve the idea by modifying the transition matrices of the corruption scheme with several tricks. The main one is the addition of absorbing states in the corruption scheme, replacing a discrete value with a MASK class, inspired by recent Masked Language Models like BERT. In this way, the corrupted dimensions can be distinguished from the original ones instead of being uniformly sampled. One drawback of their approach, mentioned by the authors, is that the transition matrix does not scale well to a large number of embedding vectors, which is typically the case in VQ-VAE. Compared to discrete generative denoising, our approach takes advantage of the fact that, in VQ-VAE, the discrete distribution depends solely on a continuous distribution. We derive a novel model based on continuous-discrete diffusion that we believe is simpler and more scalable than the models mentioned in this section.

From Data to Latent Generative denoising. Instead of modelling the data directly, (Vahdat et al., 2021) propose to perform score matching in a latent space. The authors propose a complete generative model and are able to train the encoder/decoder and the score matching end-to-end. Their method also achieves excellent visual results but relies on a number of optimization heuristics necessary for stable training. In (Mittal et al., 2021), the authors have also applied this idea to a generative music model. Instead of working in a continuous latent space, our method is specifically designed for a discrete latent space as in VQ-VAEs.

Using Generative denoising in discrete latent space. In the model proposed by (Gu et al., 2021), the autoregressive prior is replaced by a discrete generative denoising process, which is perhaps closer to our idea.
However, the authors focus more on a text-to-image synthesis task where the generative denoising model is trained based on an input text: it generates a set of discrete visual tokens given a sequence of text tokens. They also consider the VQ-VAE as a trained model and focus only on the generation of latent variables. This work focuses instead on deriving a full generative model with a sound probabilistic interpretation that allows it to be trained end-to-end.

## 3. Diffusion bridges VQ-VAE

### 3.1. Model and loss function

Assume that the distribution of the input $x \in \mathbb{R}^m$ depends on a hidden discrete state $z_q \in \mathcal{E} = \{e_1, \dots, e_K\}$ with $e_k \in \mathbb{R}^d$ for all $1 \le k \le K$. Let $p_\theta$ be the joint probability density of $(z_q, x)$:
$$(z_q, x) \mapsto p_\theta(z_q, x) = p_\theta(z_q)\,p_\theta(x|z_q)\,,$$
where $\theta \in \mathbb{R}^p$ are unknown parameters. Consider first an encoding function $f_\phi$ and write $z_e(x) = f_\phi(x)$ for the encoded data. In the original VQ-VAE, the authors proposed the following variational distribution to approximate $p_\theta(z_q|x)$:
$$q_\phi(z_q|x) = \delta_{e_{k^*_x}}(z_q)\,,$$
where $\delta$ is the Dirac mass,
$$k^*_x = \operatorname{argmin}_{1 \le k \le K}\left\{\|z_e(x) - e_k\|_2\right\},$$
and $\phi \in \mathbb{R}^r$ are all the variational parameters.

In this paper, we introduce a diffusion-based generative VQ-VAE. This model enables a VAE approach with an efficient joint training of the prior and the variational approximation. Assume that $z_q$ is a sequence, i.e. $z_q = z_q^{0:T}$, where the superscript refers to the time in the diffusion process and, for all sequences $(a_u)_{u \ge 0}$ and all $0 \le s \le t$, $a_{s:t}$ stands for $(a_s, \dots, a_t)$. Consider the following joint probability distribution:
$$p_\theta(z_q^{0:T}, x) = p^{z_q}_\theta(z_q^{0:T})\,p^x_\theta(x|z_q^0)\,.$$
The latent discrete state $z_q^0$ used as input in the decoder is the final state of the chain $(z_q^T, \dots, z_q^0)$. We further assume that $p^{z_q}_\theta(z_q^{0:T})$ is the marginal distribution of
$$p_\theta(z_q^{0:T}, z_e^{0:T}) = p^{z_e}_{\theta,T}(z_e^T)\,p^{z_q}_{\theta,T}(z_q^T|z_e^T)\prod_{t=0}^{T-1} p^{z_e}_{\theta,t|t+1}(z_e^t|z_e^{t+1})\,p^{z_q}_{\theta,t}(z_q^t|z_e^t)\,.$$
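The nearest-neighbour variational distribution $q_\phi(z_q|x) = \delta_{e_{k^*_x}}(z_q)$ of the original VQ-VAE, applied independently at each of $N$ latent positions, can be sketched in NumPy as follows. The conventions here are our own assumptions (codebook stored as a $(K, d)$ array, 0-based indices instead of $1, \dots, K$); this is an illustration, not the authors' code.

```python
import numpy as np


def nearest_neighbour_quantize(z_e: np.ndarray, codebook: np.ndarray):
    """Map each encoded vector to its closest codebook entry.

    z_e:      (N, d) array, one encoded vector per latent position.
    codebook: (K, d) array holding e_1, ..., e_K (stored 0-based).
    Returns the indices k*_x, shape (N,), and the quantized vectors, shape (N, d).
    """
    # Squared Euclidean distances ||z_e[n] - e_k||^2, shape (N, K); the argmin is
    # the same as for the unsquared norm used in the definition of k*_x.
    dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    k_star = dists.argmin(axis=1)
    # The Dirac mass puts all probability on the selected entry.
    return k_star, codebook[k_star]


if __name__ == "__main__":
    codebook = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
    z_e = np.array([[0.9, 0.1], [0.2, 0.7], [-0.1, 0.0]])
    k, z_q = nearest_neighbour_quantize(z_e, codebook)
    print(k)
    print(z_q)
```

Because this map is a deterministic argmin, it has no useful gradient with respect to $z_e$, which is why the original VQ-VAE copies gradients from the decoder input to the encoder output.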
In this model, $\{z_e^t\}_{0 \le t \le T}$ are continuous latent states in $\mathbb{R}^{d \times N}$ and, conditionally on $\{z_e^t\}_{0 \le t \le T}$, the $\{z_q^t\}_{0 \le t \le T}$ are independent with discrete distributions supported on $\mathcal{E}^N$. This means that we model $N$ latent states jointly, which is useful for many applications such as image generation. The continuous latent state is assumed to be a Markov chain and, at each time step $t$, the discrete variable $z_q^t$ is a random function of the corresponding $z_e^t$. Although the continuous states are modeled as a Markov chain, the discrete variables arising therefrom have a more complex statistical structure (in particular, they are not Markovian). The prior distribution of $z_e^T$ is assumed to be uninformative, and it is the sequence of denoising transition densities $\{p^{z_e}_{\theta,t|t+1}\}_{0 \le t \le T-1}$ which provides the final latent state $z_e^0$; this state is mapped to the embedding space and used in the decoder, i.e. in the conditional law of the data given the latent states. The final discrete state $z_q^0$ only depends on the continuous latent variable $z_e^0$, similar to the dependency between $z_q$ and $z_e$ in the original VQ-VAE.

Since the conditional law $p_\theta(z_q^{0:T}, z_e^{0:T}|x)$ is not available explicitly, this work focuses on variational approaches to provide an approximation. Consider the following variational family:
$$q_\phi(z_q^{0:T}, z_e^{0:T}|x) = \delta_{z_e(x)}(z_e^0)\,q^{z_q}_{\phi,0}(z_q^0|z_e^0)\prod_{t=1}^{T}\left\{q^{z_e}_{\phi,t|t-1}(z_e^t|z_e^{t-1})\,q^{z_q}_{\phi,t}(z_q^t|z_e^t)\right\}.$$
The family $\{q^{z_e}_{\phi,t|t-1}\}_{1 \le t \le T}$ of forward noising transition densities is chosen to be the transition densities of a continuous-time process $(Z_t)_{t \ge 0}$ with $Z_0 = z_e(x)$. Sampling the diffusion bridge $(\tilde Z_t)_{t \ge 0}$, i.e. the law of the process $(Z_t)_{t \ge 0}$ conditioned on $Z_0 = z_e(x)$ and $Z_T = z_e^T$, is a challenging problem for general diffusions, see for instance (Beskos et al., 2008; Lin et al., 2010; Bladt et al., 2016).
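By the Markov property, the marginal density at time $t$ of this conditioned process is given by:
$$q^{z_e}_{\phi,t|0,T}(z_e^t|z_e^0, z_e^T) = \frac{q^{z_e}_{\phi,t|0}(z_e^t|z_e^0)\,q^{z_e}_{\phi,T|t}(z_e^T|z_e^t)}{q^{z_e}_{\phi,T|0}(z_e^T|z_e^0)}\,. \quad (1)$$
The Evidence Lower BOund (ELBO) is then defined, for all $(\theta, \phi)$, as
$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log\frac{p_\theta(z_q^{0:T}, z_e^{0:T}, x)}{q_\phi(z_q^{0:T}, z_e^{0:T}|x)}\right],$$
where $\mathbb{E}_{q_\phi}$ is the expectation under $q_\phi(z_q^{0:T}, z_e^{0:T}|x)$.

Lemma 3.1. For all $(\theta, \phi)$, the ELBO $\mathcal{L}(\theta,\phi)$ is:
$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log p^x_\theta(x|z_q^0)\right] + \sum_{t=0}^{T}\mathcal{L}_t(\theta,\phi) + \sum_{t=0}^{T}\mathbb{E}_{q_\phi}\left[\log\frac{p^{z_q}_{\theta,t}(z_q^t|z_e^t)}{q^{z_q}_{\phi,t}(z_q^t|z_e^t)}\right],$$
where
$$\mathcal{L}_0(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log p^{z_e}_{\theta,0|1}(z_e^0|z_e^1)\right],\qquad \mathcal{L}_T(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log\frac{p^{z_e}_{\theta,T}(z_e^T)}{q^{z_e}_{\phi,T|0}(z_e^T|z_e^0)}\right],$$
and, for $1 \le t \le T-1$,
$$\mathcal{L}_t(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log\frac{p^{z_e}_{\theta,t|t+1}(z_e^t|z_e^{t+1})}{q^{z_e}_{\phi,t|0,t+1}(z_e^t|z_e^0, z_e^{t+1})}\right].$$

Proof. The proof is standard and postponed to Appendix A.

The three terms of the objective function can be interpreted as follows:
$$\mathcal{L}(\theta,\phi) = \mathcal{L}^{rec}(\theta,\phi) + \sum_{t=0}^{T}\mathcal{L}_t(\theta,\phi) + \sum_{t=0}^{T}\mathcal{L}^{reg}_t(\theta,\phi),$$
with $\mathcal{L}^{rec} = \mathbb{E}_{q_\phi}[\log p^x_\theta(x|z_q^0)]$ a reconstruction term, $\mathcal{L}_t$ the diffusion term, and an extra term
$$\mathcal{L}^{reg}_t = \mathbb{E}_{q_\phi}\left[\log\frac{p^{z_q}_{\theta,t}(z_q^t|z_e^t)}{q^{z_q}_{\phi,t}(z_q^t|z_e^t)}\right],$$
which may be seen as a regularization term, as discussed in the next sections.

### 3.2. Application to Ornstein-Uhlenbeck processes

Consider for instance the following Stochastic Differential Equation (SDE) to add noise to the normalized inputs:
$$dZ_t = -\vartheta(Z_t - z^*)\,dt + \eta\,dW_t\,, \quad (3)$$
where $\vartheta, \eta > 0$, $z^* \in \mathbb{R}^{d \times N}$ is the target state at the end of the noising process and $\{W_t\}_{0 \le t \le T}$ is a standard Brownian motion in $\mathbb{R}^{d \times N}$. We can define the variational density by integrating this SDE along small step sizes. Let $\delta t$ be the time step between the two consecutive latent variables $z_e^{t-1}$ and $z_e^t$. In this setting, $q^{z_e}_{\phi,t|t-1}(z_e^t|z_e^{t-1})$ is a Gaussian probability density function with mean $z^* + (z_e^{t-1} - z^*)e^{-\vartheta\delta t}$ in $\mathbb{R}^{d \times N}$ and covariance matrix $(2\vartheta)^{-1}\eta^2(1 - e^{-2\vartheta\delta t})\,\mathrm{I}_{dN}$, where, for all $n \ge 1$, $\mathrm{I}_n$ is the identity matrix of size $n \times n$. Asymptotically, the process is Gaussian with mean $z^*$ and variance $\eta^2(2\vartheta)^{-1}\mathrm{I}_{dN}$.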
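As a sanity check of these formulas, the sketch below samples the exact Ornstein-Uhlenbeck transition over one step $\delta t$ and compares $t$ chained steps with the closed-form marginal $z^* + \sqrt{\bar\alpha_t}\,(z_e^0 - z^*) + \sqrt{\eta^2(1 - \bar\alpha_t)/(2\vartheta)}\,\varepsilon$, where, for equal steps, $\bar\alpha_t = e^{-2\vartheta t\,\delta t}$ (this is the marginal used in the bridge derivation below). The function names and the parameter values in the usage lines are our own illustrative choices.

```python
import numpy as np


def ou_step(z, dt, theta, eta, z_star, rng):
    """Exact one-step transition of dZ = -theta (Z - z*) dt + eta dW over a step dt."""
    mean = z_star + (z - z_star) * np.exp(-theta * dt)
    var = eta**2 / (2.0 * theta) * (1.0 - np.exp(-2.0 * theta * dt))
    return mean + np.sqrt(var) * rng.standard_normal(np.shape(z))


def ou_marginal(z0, t, dt, theta, eta, z_star, rng):
    """Closed-form sample of Z at time t * dt given Z_0 = z0 (t equal steps of size dt)."""
    alpha_bar = np.exp(-2.0 * theta * dt * t)
    eps = rng.standard_normal(np.shape(z0))
    std = np.sqrt(eta**2 / (2.0 * theta) * (1.0 - alpha_bar))
    return z_star + np.sqrt(alpha_bar) * (z0 - z_star) + std * eps


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    theta, eta, dt, t = 1.0, np.sqrt(2.0), 0.05, 10  # eta**2 / (2 theta) = 1
    z = np.full(200_000, 2.0)
    for _ in range(t):
        z = ou_step(z, dt, theta, eta, 0.0, rng)
    w = ou_marginal(np.full(200_000, 2.0), t, dt, theta, eta, 0.0, rng)
    print("chained:     mean %.3f  var %.3f" % (z.mean(), z.var()))
    print("closed form: mean %.3f  var %.3f" % (w.mean(), w.var()))
```

Both lines should agree up to Monte Carlo error, with a mean near $2e^{-1/2} \approx 1.213$ and a variance near $1 - e^{-1} \approx 0.632$.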
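The denoising process then amounts to sampling from the bridge associated with the SDE, i.e. sampling $z_e^{t-1}$ given $z_e^0$ and $z_e^t$. The law of this bridge is explicit for the Ornstein-Uhlenbeck diffusion (3). Using (1),
$$q^{z_e}_{\phi,s|0,t}(z_e^s|z_e^0, z_e^t) \propto q^{z_e}_{\phi,s|0}(z_e^s|z_e^0)\,q^{z_e}_{\phi,t|s}(z_e^t|z_e^s)\,,$$
where $0 \le s \le t$, so that $q^{z_e}_{\phi,t-1|0,t}(z_e^{t-1}|z_e^0, z_e^t)$ is a Gaussian probability density function with mean
$$\mu_{\phi,t-1|0,t}(z_e^0, z_e^t) = z^* + \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}(z_e^0 - z^*) + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}(z_e^t - z^*)$$
and covariance matrix
$$\sigma^2_{\phi,t-1|0,t}\,\mathrm{I}_{dN} = \frac{\eta^2}{2\vartheta}\,\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t\,\mathrm{I}_{dN}\,,$$
where $\beta_t = 1 - \exp(-2\vartheta\delta t)$, $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$. Note that the bridge sampler proposed in (Ho et al., 2020) is a specific case of this setting with $\eta = \sqrt{2}$, $z^* = 0$ and $\vartheta = 1$.

Choice of denoising model $p_\theta$. Following (Ho et al., 2020), we propose a Gaussian distribution for $p^{z_e}_{\theta,t-1|t}(z_e^{t-1}|z_e^t)$ with mean $\mu_{\theta,t-1|t}(z_e^t, t)$ and variance $\sigma^2_{\theta,t-1|t}\,\mathrm{I}_{dN}$. In the following, we choose $\sigma^2_{\theta,t-1|t} = \sigma^2_{\phi,t-1|0,t}$, so that the diffusion term of Lemma 3.1 associated with the transition from $z_e^t$ to $z_e^{t-1}$ (that is, $\mathcal{L}_{t-1}$, for $2 \le t \le T$) writes
$$-2\sigma^2_{\theta,t-1|t}\,\mathcal{L}_{t-1}(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\left\|\mu_{\theta,t-1|t}(z_e^t, t) - \mu_{\phi,t-1|0,t}(z_e^0, z_e^t)\right\|_2^2\right].$$
In addition, under $q_\phi$, $z_e^t$ has the same distribution as
$$h_e^t(z_e^0, \varepsilon_t) = z^* + \sqrt{\bar\alpha_t}\,(z_e^0 - z^*) + \sqrt{\frac{\eta^2}{2\vartheta}(1 - \bar\alpha_t)}\,\varepsilon_t\,,$$
where $\varepsilon_t \sim \mathcal{N}(0, \mathrm{I}_{dN})$. Then, for instance in the case $z^* = 0$, $\mu_{\phi,t-1|0,t}$ can be reparameterised as follows:
$$\mu_{\phi,t-1|0,t}(z_e^0, z_e^t) = \frac{1}{\sqrt{\alpha_t}}\left(h_e^t(z_e^0, \varepsilon_t) - \sqrt{\frac{\eta^2}{2\vartheta(1 - \bar\alpha_t)}}\,\beta_t\,\varepsilon_t\right).$$
We therefore propose to use
$$\mu_{\theta,t-1|t}(z_e^t, t) = \frac{1}{\sqrt{\alpha_t}}\left(z_e^t - \sqrt{\frac{\eta^2}{2\vartheta(1 - \bar\alpha_t)}}\,\beta_t\,\varepsilon_\theta(z_e^t, t)\right),$$
which yields
$$\mathcal{L}_{t-1}(\theta,\phi) = -\frac{\beta_t}{2\alpha_t(1 - \bar\alpha_{t-1})}\,\mathbb{E}\left[\left\|\varepsilon_t - \varepsilon_\theta(h_e^t(z_e^0, \varepsilon_t), t)\right\|_2^2\right]. \quad (4)$$
Several choices can be proposed to model the function $\varepsilon_\theta$. The deep learning architectures considered in the numerical experiments are discussed in Appendices D and E. Similarly to (Ho et al., 2020), we use a stochastic version of our loss function: we sample $t$ uniformly in $\{0, \dots, T\}$ and consider the single corresponding term instead of the full sum over all $t$. The final training algorithm is described in Algorithm 1 and the sampling procedure in Algorithm 2.

Connections with the VQ-VAE loss function.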
In the special case where $T = 0$, our loss function reduces to a standard VQ-VAE loss function. In that case, write $z_q = z_q^0$ and $z_e = z_e^0$; the ELBO then becomes:
$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log p^x_\theta(x|z_q)\right] + \mathbb{E}_{q_\phi}\left[\log\frac{p^{z_q}_\theta(z_q|z_e)}{q^{z_q}_\phi(z_q|z_e)}\right].$$
Then, if we assume that $p^{z_q}_\theta(z_q|z_e) = \mathrm{Softmax}\{-\|z_e - e_k\|_2^2\}_{1 \le k \le K}$ and that $q^{z_q}_\phi(z_q|z_e)$ is as in (Oord et al., 2017), i.e. a Dirac mass at $\hat z_q = \operatorname{argmin}_{1 \le k \le K}\|z_e - e_k\|_2^2$, then, up to an additive constant, this yields the following random estimate of $\mathbb{E}_{q_\phi}[\log p^{z_q}_\theta(z_q|z_e)/q^{z_q}_\phi(z_q|z_e)]$:
$$\hat{\mathcal{L}}^{reg}_{z_q}(\theta,\phi) = -\|z_e - \hat z_q\|_2^2 - \log\left(\sum_{k=1}^{K}\exp\left\{-\|z_e - e_k\|_2^2\right\}\right).$$
The first term of this loss is the loss proposed in (Oord et al., 2017), which is then split into two parts using the stop-gradient operator. The last term is simply the additional normalizing term of $p^{z_q}_\theta(z_q|z_e)$.

Connecting diffusion and discretisation. Similar to the VQ-VAE case above, it is possible to consider only the term $\mathcal{L}^{reg}_0(\theta,\phi)$ in the case $T > 0$. However, our framework allows for a much more flexible parameterisation of $p^{z_q}_{\theta,t}(z_q^t|z_e^t)$ and $q^{z_q}_{\phi,t}(z_q^t|z_e^t)$. For instance, the Gumbel-Softmax trick provides an efficient and differentiable parameterisation. A sample $z_q^t \sim p^{z_q}_{\theta,t}(z_q^t|z_e^t)$ (resp. $z_q^t \sim q^{z_q}_{\phi,t}(z_q^t|z_e^t)$) can be obtained by sampling with probabilities proportional to $\{\exp\{(-\|z_e^t - e_k\|_2^2 + G_k)/\tau_t\}\}_{1 \le k \le K}$ (resp. $\{\exp\{(-\|z_e^t - e_k\|_2^2 + \tilde G_k)/\tau\}\}_{1 \le k \le K}$), where $\{(G_k, \tilde G_k)\}_{1 \le k \le K}$ are i.i.d. with distribution $\mathrm{Gumbel}(0, 1)$, $\tau > 0$, and $\{\tau_t\}_{0 \le t \le T}$ are positive time-dependent scaling parameters. In practice, the third part of the objective function can be computed efficiently by using a stochastic version of the ELBO, computing a single $\mathcal{L}^{reg}_t(\theta,\phi)$ instead of the sum (we use the same $t$ for both parts of the ELBO). The term reduces to:
$$\mathcal{L}^{reg}_t(\theta,\phi) = -\mathrm{KL}\left(q_\phi(z_q^t|z_e^t)\,\|\,p_\theta(z_q^t|z_e^t)\right). \quad (5)$$
This term connects the diffusion and quantization parts, as it creates a gradient pathway through a step $t$ of the diffusion process, acting as a regularisation on the codebooks and $z_e^t$.
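The Gumbel-Softmax parameterisation and the regularisation term (5) can be sketched in NumPy as follows. Two assumptions are ours: the relaxed sample is the usual Gumbel-Softmax vector $\mathrm{softmax}((-\|z_e^t - e_k\|_2^2 + G_k)/\tau)$, and, to obtain a closed-form KL, both discrete distributions are taken to be categorical with probabilities proportional to $\exp(-\|z_e^t - e_k\|_2^2/\tau)$ for $q_\phi$ and $\exp(-\|z_e^t - e_k\|_2^2/\tau_t)$ for $p_\theta$, independently over the $N$ positions. This is one plausible reading of the text, not necessarily the authors' exact parameterisation.

```python
import numpy as np


def codebook_logits(z_e, codebook):
    """Logits -||z_e[n] - e_k||^2, shape (N, K), for z_e of shape (N, d) and codebook (K, d)."""
    return -((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)


def log_softmax(a):
    """Row-wise log-softmax, shifted by the row maximum for numerical stability."""
    a = a - a.max(axis=-1, keepdims=True)
    return a - np.log(np.exp(a).sum(axis=-1, keepdims=True))


def gumbel_softmax_sample(logits, tau, rng):
    """Relaxed one-hot sample softmax((logits + G) / tau), with G i.i.d. Gumbel(0, 1)."""
    g = rng.gumbel(0.0, 1.0, size=logits.shape)
    return np.exp(log_softmax((logits + g) / tau))


def kl_q_p(z_e, codebook, tau_q, tau_p):
    """KL(q || p) summed over positions, with q_k proportional to exp(logit_k / tau_q)
    and p_k proportional to exp(logit_k / tau_p). Under the assumptions above,
    L^reg_t in (5) is minus this value."""
    logits = codebook_logits(z_e, codebook)
    log_q = log_softmax(logits / tau_q)
    log_p = log_softmax(logits / tau_p)
    return float((np.exp(log_q) * (log_q - log_p)).sum())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    codebook = rng.standard_normal((8, 2))  # K = 8 codebook vectors in R^2
    z_e = rng.standard_normal((5, 2))       # N = 5 latent positions
    y = gumbel_softmax_sample(codebook_logits(z_e, codebook), tau=0.5, rng=rng)
    print(y.shape, y.sum(axis=1))           # each relaxed sample lies on the simplex
    print("L_reg =", -kl_q_p(z_e, codebook, tau_q=1.0, tau_p=0.5))
```

Because the relaxed sample is a smooth function of $z_e^t$ and of the codebook, gradients flow through both, which is the pathway the text describes; with equal temperatures the KL vanishes.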
Intuitively, maximizing $\mathcal{L}^{reg}_t(\theta,\phi)$ amounts to pushing the codebooks and $z_e^t$ together or apart, depending on the choice of $\tau$ and $\tau_t$. The final end-to-end training algorithm is described in Algorithm 1, and further considerations are provided in Appendix C.

## 4. Experiments

### 4.1. Toy Experiment

In order to understand the proposed denoising procedure for VQ-VAE, consider a simple toy setting in which there is neither encoder nor decoder, and the codebooks $\{e_j\}_{0 \le j \le K-1}$ are fixed. In this case, with $d = 2$ and $N = 5$, $x = z_e^0 \in \mathbb{R}^{2 \times 5}$. We choose $K = 8$ and the codebooks $e_j = \mu_j \in \mathbb{R}^2$, $0 \le j \le K-1$, are fixed centers at regular angular intervals in $\mathbb{R}^2$, shown in Figure 2; the latent states $(z_q^t)_{1 \le t \le T}$ lie in $\{e_0, \dots, e_7\}^5$. Data generation proceeds as follows. First, sample a sequence $(q_1, \dots, q_5)$ in $\{0, \dots, 7\}$: $q_1$ has a uniform distribution and, for $s \in \{1, 2, 3, 4\}$, $q_{s+1} = q_s + b_s \bmod 8$, where the $b_s$ are independent Bernoulli samples with parameter $1/2$ taking values in $\{-1, 1\}$. Conditionally on $(q_1, \dots, q_5)$, $x$ is a Gaussian random vector with mean $(e_{q_1}, \dots, e_{q_5})$ and variance $\mathrm{I}_{2 \times 5}$.

Figure 2. Toy dataset, with $K = 8$ centroids, and two samples $x = (x_1, x_2, x_3, x_4, x_5)$ in $\mathbb{R}^{2 \times 5}$, each displayed as 5 points in $\mathbb{R}^2$ (blue and red points), corresponding to the discrete sequences (red) $(6, 5, 4, 3, 2)$ and (blue) $(7, 0, 1, 0, 1)$.

Algorithm 1: Training procedure.

1. Compute $z_e^0 = f_\phi(x)$.
2. Sample $\hat z_q^0 \sim q_\phi(z_q^0|z_e^0)$.
3. Compute $\hat{\mathcal{L}}^{rec}(\theta,\phi) = \log p^x_\theta(x|\hat z_q^0)$.
4. Sample $t \sim \mathrm{Uniform}(\{0, \dots, T\})$.
5. Sample $\varepsilon_t \sim \mathcal{N}(0, \mathrm{I}_{dN})$.
6. Sample $z_e^t \sim q_\phi(z_e^t|z_e^0)$ (using $\varepsilon_t$).
7. Compute $\hat{\mathcal{L}}_{t-1}(\theta,\phi)$ from $\varepsilon_\theta(z_e^t, t)$ and $\varepsilon_t$ using (4).
8. Compute $\hat{\mathcal{L}}^{reg}_t(\theta,\phi)$ from $z_e^t$ (see text).
9. Set $\hat{\mathcal{L}}(\theta,\phi) = \hat{\mathcal{L}}^{rec}(\theta,\phi) + \hat{\mathcal{L}}_{t-1}(\theta,\phi) + \hat{\mathcal{L}}^{reg}_t(\theta,\phi)$.
10. Perform an SGD step on $\hat{\mathcal{L}}(\theta,\phi)$, and repeat from step 1 until convergence.

Algorithm 2: Sampling procedure (for $z^* = 0$).

1. Sample $z_e^T \sim \mathcal{N}(0, (2\vartheta)^{-1}\eta^2\,\mathrm{I}_{dN})$.
2. For $t = T$ down to $1$, set $z_e^{t-1} = \alpha_t^{-1/2}\left(z_e^t - \sqrt{\frac{\eta^2}{2\vartheta(1 - \bar\alpha_t)}}\,\beta_t\,\varepsilon_\theta(z_e^t, t)\right)$.
3. Sample $z_q^0 \sim p^{z_q}_{\theta,0}(z_q^0|z_e^0)$ (quantization).
4. Sample $x \sim p^x_\theta(x|z_q^0)$ (decoder).
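To make the diffusion part of Algorithms 1 and 2 concrete, here is a NumPy sketch for $z^* = 0$ with a constant step $\delta t$. Assumptions: `eps_model` is a hypothetical stand-in for the trained network $\varepsilon_\theta$ (any callable `(z, t) -> array`); the value of $\delta t$ is not given in the text and is chosen arbitrarily; the sketch only evaluates the diffusion term (4), which requires $t \ge 2$, so $t$ is drawn in $\{2, \dots, T\}$ here; and it stops at $z_e^0$, leaving quantization and decoding aside.

```python
import numpy as np


def ou_schedule(T, dt, theta):
    """beta_t = 1 - exp(-2 theta dt), alpha_t = 1 - beta_t, alpha_bar_t = prod_{s<=t} alpha_s.

    Index 0 is padding (beta_0 = 0), so arrays are indexed by t directly and alpha_bar[0] = 1.
    """
    beta = np.concatenate([[0.0], np.full(T, 1.0 - np.exp(-2.0 * theta * dt))])
    alpha = 1.0 - beta
    return beta, alpha, np.cumprod(alpha)


def bridge_mean(z0, zt, t, beta, alpha, alpha_bar):
    """Mean of the OU bridge q(z^{t-1} | z^0, z^t) for z* = 0 and t >= 1."""
    c0 = np.sqrt(alpha_bar[t - 1]) * beta[t] / (1.0 - alpha_bar[t])
    ct = np.sqrt(alpha[t]) * (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t])
    return c0 * z0 + ct * zt


def diffusion_term(z0, t, eps_model, beta, alpha, alpha_bar, eta, theta, rng):
    """One-sample estimate of the diffusion term (4); requires t >= 2."""
    eps = rng.standard_normal(z0.shape)
    # z^t = h_e^t(z^0, eps) for z* = 0.
    zt = np.sqrt(alpha_bar[t]) * z0 + np.sqrt(eta**2 / (2 * theta) * (1 - alpha_bar[t])) * eps
    weight = beta[t] / (2.0 * alpha[t] * (1.0 - alpha_bar[t - 1]))
    return -weight * np.sum((eps - eps_model(zt, t)) ** 2)


def sample_z0(shape, T, eps_model, beta, alpha, alpha_bar, eta, theta, rng):
    """Continuous part of Algorithm 2: draw z_e^T, then apply the update for t = T, ..., 1."""
    z = rng.normal(0.0, np.sqrt(eta**2 / (2 * theta)), size=shape)
    for t in range(T, 0, -1):
        coef = np.sqrt(eta**2 / (2 * theta * (1 - alpha_bar[t]))) * beta[t]
        z = (z - coef * eps_model(z, t)) / np.sqrt(alpha[t])
    return z


def zero_eps_model(z, t):
    """Hypothetical placeholder for a trained eps_theta (always predicts zero noise)."""
    return np.zeros_like(z)


if __name__ == "__main__":
    T, THETA, ETA = 50, 2.0, 0.1  # toy settings of Section 4.1
    DT = 0.02                     # hypothetical step size (not given in the text)
    beta, alpha, alpha_bar = ou_schedule(T, DT, THETA)
    rng = np.random.default_rng(0)
    z0 = rng.standard_normal((2, 5))  # a toy latent in R^{2 x 5}
    t = int(rng.integers(2, T + 1))
    print("diffusion term at t =", t, ":",
          diffusion_term(z0, t, zero_eps_model, beta, alpha, alpha_bar, ETA, THETA, rng))
    print("z_e^0 shape:",
          sample_z0((2, 5), T, zero_eps_model, beta, alpha, alpha_bar, ETA, THETA, rng).shape)
```

The update follows Algorithm 2 as printed, i.e. it uses the mean $\mu_{\theta,t-1|t}$ without adding Gaussian noise. With the zero placeholder the loop only rescales $z_e^T$ by $\bar\alpha_T^{-1/2}$, so only shapes are meaningful; if `eps_model` returned the true noise, the update inside `sample_z0` would coincide with `bridge_mean`, which is the reparameterisation used to derive (4).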
We train our bridge procedure with $T = 50$ timesteps, $\vartheta = 2$ and $\eta = 0.1$; other architecture details and the neural network $\varepsilon_\theta(z_e^t, t)$ are described in Appendix E. The forward noise process and the denoising using $\varepsilon_\theta(z_e^t, t)$ are showcased in Figure 3, and more illustrations and experiments can be found in Appendix E.

Figure 3. (Left) Forward noise process for one sample. First, one data point is drawn ($z_e^0(x) = x$ in the toy example) and then $\{z_e^t\}_{1 \le t \le T}$ are sampled under $q_\phi$ and displayed. (Right) Reverse process for one sample $z_e^T \sim \mathcal{N}(0, (2\vartheta)^{-1}\eta^2\,\mathrm{I}_{dN})$. As expected, the last sample $z_e^0$ reaches the neighborhood of 5 codebooks.

End-to-end training. Contrary to VQ-VAE procedures in which the encoder/decoder/codebooks are trained separately from the prior, we can train the bridge prior alongside the codebooks. Consider a new setup, in which the $K = 8$ codebooks are randomly initialized and considered as parameters of our model (they are no longer fixed to the centers $\mu_j$ of the data generation process). The first part of our loss function, in conjunction with the Gumbel-Softmax trick, makes it possible to train all the parameters of the model end-to-end. Details of the procedure and results are shown in Appendix E.

### 4.2. Image Synthesis

In this section, we focus on image synthesis using the CIFAR10 and mini-ImageNet datasets. The goal is to evaluate the efficiency and properties of our model compared to the original PixelCNN. Note that, for fair comparisons, the encoder, decoder and codebooks are pretrained and fixed for all models; only the prior is trained and evaluated here. As our goal is the comparison of priors, we did not focus on building the most efficient VQ-VAE, but rather a reasonable model in terms of size and efficiency.

CIFAR10. The CIFAR dataset consists of inputs $x$ of dimensions $32 \times 32$ with 3 channels. The encoder projects the input into a grid of continuous values $z_e^0$ of dimension $8 \times 8 \times 128$.
After discretisation, $\{z_q^t\}_{0 \le t \le T}$ are in a discrete latent space induced by the VQ-VAE, which consists of values in $\{1, \dots, K\}^{8 \times 8}$ with $K = 256$. The pre-trained VQ-VAE reconstructions can be seen in Figure 13 in Appendix F.

mini-ImageNet. mini-ImageNet was introduced by (Vinyals et al., 2016) to offer more complexity than CIFAR10, while still fitting in the memory of modern machines. 600 images were sampled for each of 100 different classes from the original ImageNet dataset, then scaled down, to obtain 60,000 images of dimension $84 \times 84$. In our experiments, we trained a VQ-VAE model to project those input images into a grid of continuous values $z_e^0$ of dimensions $21 \times 21 \times 32$, see Figure 15 in Appendix F. The associated codebook contains $K = 128$ vectors of dimension 32.

Prior models. Once the VQ-VAE is trained on the mini-ImageNet and CIFAR datasets, the $84 \times 84 \times 3$ and $32 \times 32 \times 3$ images respectively are passed to the encoder and result in $21 \times 21$ and $8 \times 8$ feature maps respectively. From this model, we extract the discrete latent states from training samples to train a PixelCNN prior, and the continuous latent states for our diffusion. Concerning our diffusion prior, we choose the Ornstein-Uhlenbeck process setting $\eta = \sqrt{2}$, $z^* = 0$ and $\vartheta = 1$, with $T = 1000$.

End-to-End Training. As an additional experiment, we propose an end-to-end training of the VQ-VAE and the diffusion process. To speed up training, we first pretrain the VQ-VAE, then learn the parameters of our diffusion prior alongside all the VQ-VAE parameters (encoder, decoder and codebooks). Note that in this setup, we cannot directly compare the NLL to PixelCNN or to our previous diffusion model, as the VQ-VAE has changed, but we can compare image generation metrics such as FID and sample quality.

### 4.3. Quantitative results

We benchmarked our model using three metrics, in order to highlight the performance of the proposed prior, the quality of the produced samples and the associated computation costs. Results are given as a comparison to the original PixelCNN prior for both the mini-ImageNet (see Table 2) and CIFAR10 (see Table 3) datasets.

Negative Log Likelihood. Unlike most related papers, we are interested in computing the Negative Log Likelihood (NLL) directly in the latent space, so as to evaluate the capacity of the priors to generate coherent latent maps. To this end, we mask a patch of the original latent space and reconstruct the missing part, similar to image inpainting, following for instance (van den Oord et al., 2016). In the case of our prior, for each sample $x$, we mask an area of the continuous latent state $z_e^0$, i.e. we mask some components of $z_e^0$, and aim at sampling the missing components given the observed ones using the prior model. Let $\bar z_q^0$ and $\bar z_e^0$ (resp. $z_q^0$ and $z_e^0$) be the masked (resp. observed) discrete and continuous latent variables. The target conditional likelihood is
$$p_\theta(\bar z_q^0|z_e^0) = \int p_\theta(\bar z_q^0|\bar z_e^0)\,p_\theta(\bar z_e^0|z_e^0)\,\mathrm{d}\bar z_e^0\,.$$
This likelihood is intractable and is replaced by a simple Monte Carlo estimate $\hat p_\theta(\bar z_q^0|z_e^0)$, where $\bar z_e^0 \sim p_\theta(\bar z_e^0|z_e^0)$. Note that, conditionally on $\bar z_e^0$, the components of $\bar z_q^0$ are assumed to be independent, but the components of $\bar z_e^0$ are sampled jointly under $p_\theta(\bar z_e^0|z_e^0)$. As there are no continuous latent data in PixelCNN, $p_\theta(\bar z_q^0|z_q^0)$ can be directly evaluated.

Fréchet Inception Distance. We report Fréchet Inception Distance (FID) scores by sampling a latent discrete state $z_q \in \mathcal{E}^N$ from the prior and computing the associated image through the VQ-VAE decoder. In order to evaluate each prior independently from the encoder and decoder networks, these samples are compared to VQ-VAE reconstructions of the dataset images.

Kullback-Leibler divergence.
In this experiment, we draw $M = 1000$ samples from the test set and encode them using the trained VQ-VAE, and then draw as many samples from the PixelCNN prior and from our diffusion prior. We then compute the empirical Kullback-Leibler (KL) divergence between the original and sampled distributions at each pixel of the latent map. Figure 4 highlights that PixelCNN performs poorly on the last pixels of its generation order (at the bottom), while our method remains consistent. This is explained by our denoising process in the continuous space, which uses all pixels jointly, while PixelCNN is based on an autoregressive model.

Figure 4. KL divergence between the true empirical distribution and both prior distributions in the latent space. Darker squares indicate lower (better) values.

Table 1. Averaged KL metric on the feature map.

| Prior | Averaged KL |
|---|---|
| Ours | 0.713 |
| PixelCNN | 0.809 |

Table 2. Results on mini-ImageNet. Metrics are computed on the validation dataset. The means are displayed along with the standard deviation in parentheses.

| Prior | NLL | FID | s/sample |
|---|---|---|---|
| PixelCNN (Oord et al., 2017) | 1.00 (±0.05) | 98 | 10.6s (±28ms) |
| Ours | 0.94 (±0.02) | 99 | 1.7s (±10ms) |

Table 3. Results on CIFAR10. Metrics are computed on the validation dataset. The means are displayed along with the standard deviation in parentheses.

| Prior | NLL | FID | s/sample |
|---|---|---|---|
| PixelCNN (Oord et al., 2017) | 1.41 (±0.06) | 109 | 0.21s (±0.8ms) |
| Ours | 1.33 (±0.18) | 104 | 0.05s (±0.5ms) |
| Ours end-to-end | 1.59 (±0.27)¹ | 92 | 0.11s (±0.5ms) |

Computation times. We evaluated the computation cost of sampling a batch of 32 images on a GTX TITAN Xp GPU card. Note that the computational bottleneck of our model consists of the $T = 1000$ sequential diffusion steps (rather than the encoder/decoder, which are very fast in comparison). Therefore, a diffusion speed-up technique such as the one described in (Song et al., 2021) would be straightforward to apply and would likely provide a 50× speedup, as mentioned in that paper.

### 4.4. Qualitative results

Figure 5. Comparison between samples from our diffusion-based prior (top) and the PixelCNN prior (bottom). (a) Samples from our diffusion prior. (b) Samples from the PixelCNN prior.

Sampling from the prior. Samples from the PixelCNN prior are shown in Figure 5b and samples from our prior in Figure 5a. Additional samples are given in Appendix F. Note that, contrary to the original VQ-VAE prior, the prior is not conditioned on a class, which makes the generation less specific and more difficult. However, the produced samples illustrate that our prior can generate a wide variety of images which show a large-scale spatial coherence in comparison with samples from PixelCNN.

Figure 6. Sampling denoising chain from $t = 500$ up to $t = 0$, shown at regular intervals, conditioned on the outer part of the picture. We show only the last 500 steps of this process, as the first 500 steps are not visually informative. The sampling procedure is described in Appendix B.

Figure 7. Sampling denoising chain from $t = 500$ up to $t = 0$, shown at regular intervals, unconditional. We show only the last 500 steps of this process, as the first 500 steps are not visually informative. The sampling procedure is described in Algorithm 2.

Conditional sampling. As explained in Section 4.3, for each sample $x$, we mask some components of $z_e^0(x)$ and aim at sampling the missing components given the observed ones using the prior models. This conditional denoising process is further explained for our model in Appendix B. To illustrate this setting, we show different conditional samples for 3 images in Figure 8 and Figure 9, for both the PixelCNN prior and ours. In Figure 8, the mask corresponds to a $9 \times 9$ centered square over the $21 \times 21$ feature map. In Figure 9, the mask corresponds to a $9 \times 9$ top-left square. These figures illustrate that our diffusion model is much less sensitive to the selected masked region than PixelCNN.
This may be explained by the use of our denoising function $\varepsilon_\theta$, which depends on all conditioning pixels, while PixelCNN uses a hierarchy of masked convolutions to enforce a specific conditioning order. Additional conditional sampling experiments are given in Appendix F.

Figure 8. Conditional sampling with a centered mask: for each of the 3 different images, samples from our diffusion model are on top and samples from PixelCNN on the bottom. In each row, the image on the left is the masked VQ-VAE reconstruction and the image on the right is the full VQ-VAE reconstruction; the images in between are independent conditional samples from the models.

Figure 9. Conditional sampling with a top-left mask: for each of the 3 different images, samples from our diffusion model are on top and samples from PixelCNN on the bottom. In each row, the image on the left is the masked VQ-VAE reconstruction and the image on the right is the full VQ-VAE reconstruction; the images in between are independent conditional samples from the models.

Denoising chain. In addition to the conditional samples, Figure 6 shows the conditional denoising process at regularly spaced intervals, and Figure 7 shows the unconditional denoising process. Each image of the chain is generated by passing the predicted $z_q^t$ through the VQ-VAE decoder.

5. Conclusion

This work introduces a new mathematical framework for VQ-VAEs which includes a diffusion probabilistic model to learn the dependencies between the continuous latent variables alongside the encoding and decoding parts of the model. We showed conceptual improvements of our model over the VQ-VAE prior, as well as first numerical results on medium-scale image generation. We believe that these first numerical experiments open up many research avenues: scaling to larger models, optimal scaling of the hyperparameters, including standard tricks from other diffusion methods, studying the influence of the regularisation loss for end-to-end training, etc.
We hope that this framework will serve as a sound and stable foundation to derive future generative models.

Acknowledgements

The work of Max Cohen was supported by grants from Région Île-de-France. Charles Ollion and Guillaume Quispe benefited from the support of the Chair "New Gen RetAIl" led by l'X – École Polytechnique and the Fondation de l'École Polytechnique, sponsored by Carrefour.

References

Austin, J., Johnson, D. D., Ho, J., Tarlow, D., and van den Berg, R. Structured denoising diffusion models in discrete state-spaces. Advances in Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=h7-XixPCAL.

Beskos, A., Roberts, G., Stuart, A., and Voss, J. MCMC methods for diffusion bridges. Stochastics and Dynamics, 8(03):319–350, 2008.

Bladt, M., Finch, S., and Sørensen, M. Simulation of multivariate diffusion bridges. Journal of the Royal Statistical Society: Series B: Statistical Methodology, pp. 343–369, 2016.

Chen, X., Mishra, N., Rohaninejad, M., and Abbeel, P. PixelSNAIL: An improved autoregressive generative model. In International Conference on Machine Learning, pp. 864–872. PMLR, 2018.

De Bortoli, V., Doucet, A., Heng, J., and Thornton, J. Simulating diffusion bridges with score matching. arXiv preprint arXiv:2111.07243, 2021.

Esser, P., Rombach, R., and Ommer, B. Taming transformers for high-resolution image synthesis. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12873–12883, 2021.

Gu, S., Chen, D., Bao, J., Wen, F., Zhang, B., Chen, D., Yuan, L., and Guo, B. Vector quantized diffusion model for text-to-image synthesis. arXiv preprint, 2021.

Ho, J., Jain, A., and Abbeel, P. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems (NeurIPS 2020), 33, 2020.

Hoogeboom, E., Nielsen, D., Jaini, P., Forré, P., and Welling, M. Argmax flows and multinomial diffusion: Learning categorical distributions.
Advances in Neural Information Processing Systems (NeurIPS 2021), 34, 2021.

Lin, M., Chen, R., and Mykland, P. On generating Monte Carlo samples of continuous diffusion bridges. Journal of the American Statistical Association, 105(490):820–838, 2010.

Mittal, G., Engel, J., Hawthorne, C., and Simon, I. Symbolic music generation with diffusion models. arXiv preprint arXiv:2103.16091, 2021.

Oord, A. v. d., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, N., Senior, A., and Kavukcuoglu, K. WaveNet: A generative model for raw audio. arXiv preprint arXiv:1609.03499, 2016.

Oord, A. v. d., Vinyals, O., and Kavukcuoglu, K. Neural discrete representation learning. Advances in Neural Information Processing Systems (NeurIPS 2017), 2017.

Ramesh, A., Pavlov, M., Goh, G., Gray, S., Voss, C., Radford, A., Chen, M., and Sutskever, I. Zero-shot text-to-image generation. In Proceedings of the 38th International Conference on Machine Learning, PMLR 139:8821–8831, 18–24 Jul 2021.

Razavi, A., van den Oord, A., and Vinyals, O. Generating diverse high-fidelity images with VQ-VAE-2. In Advances in Neural Information Processing Systems (NeurIPS 2019), pp. 14866–14876, 2019.

Salimans, T., Karpathy, A., Chen, X., and Kingma, D. P. PixelCNN++: Improving the PixelCNN with discretized logistic mixture likelihood and other modifications. 2017.

Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N., and Ganguli, S. Deep unsupervised learning using nonequilibrium thermodynamics. In Proceedings of the 32nd International Conference on Machine Learning, PMLR 37:2256–2265, 2015. URL https://proceedings.mlr.press/v37/sohl-dickstein15.html.

Song, J., Meng, C., and Ermon, S. Denoising diffusion implicit models. 2021. URL https://openreview.net/forum?id=St1giarCHLP.

Song, Y. and Ermon, S. Generative modeling by estimating gradients of the data distribution. Advances in Neural Information Processing Systems, 32, 2019. URL https://proceedings.neurips.cc/paper/2019/file/3001ef257407d5a371a96dcd947c7d93-Paper.pdf.

Vahdat, A., Kreis, K., and Kautz, J. Score-based generative modeling in latent space. 2021.
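van den Oord, A., Kalchbrenner, N., Vinyals, O., Espeholt, L., Graves, A., and Kavukcuoglu, K. Conditional image generation with PixelCNN decoders. 2016.

Van Oord, A., Kalchbrenner, N., and Kavukcuoglu, K. Pixel recurrent neural networks. In International Conference on Machine Learning, pp. 1747–1756. PMLR, 2016.

Vinyals, O., Blundell, C., Lillicrap, T., Kavukcuoglu, K., and Wierstra, D. Matching networks for one shot learning. In Lee, D., Sugiyama, M., Luxburg, U., Guyon, I., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016. URL https://proceedings.neurips.cc/paper/2016/file/90e1357833654983612fb05e3ec9148c-Paper.pdf.

Willetts, M., Miscouridou, X., Roberts, S., and Holmes, C. Relaxed-responsibility hierarchical discrete VAEs. arXiv:2007.07307, 2021.

A. Details on the loss function

Proof of Lemma 3.1. By definition,

$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log \frac{p_\theta(z_q^{0:T}, z_e^{0:T}, x)}{q_\phi(z_q^{0:T}, z_e^{0:T} \mid x)}\right],$$

which yields

$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log p^x_\theta(x \mid z_q^0)\right] + \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_q}_\theta(z_q^{0:T} \mid z_e^{0:T})}{q^{z_q}_\phi(z_q^{0:T} \mid z_e^{0:T})}\right] + \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_\theta(z_e^{0:T})}{q^{z_e}_\phi(z_e^{0:T} \mid x)}\right].$$

The last term may be decomposed as

$$\mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_\theta(z_e^{0:T})}{q^{z_e}_\phi(z_e^{0:T} \mid x)}\right] = \mathbb{E}_{q_\phi}\left[\log p^{z_e}_{\theta,T}(z_e^T)\right] + \sum_{t=1}^{T} \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_{\theta,t-1|t}(z_e^{t-1} \mid z_e^t)}{q^{z_e}_{\phi,t|t-1}(z_e^t \mid z_e^{t-1})}\right].$$

Isolating the term $t = 1$ gives

$$\mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_\theta(z_e^{0:T})}{q^{z_e}_\phi(z_e^{0:T} \mid x)}\right] = \mathbb{E}_{q_\phi}\left[\log p^{z_e}_{\theta,T}(z_e^T)\right] + \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_{\theta,0|1}(z_e^0 \mid z_e^1)}{q^{z_e}_{\phi,1|0}(z_e^1 \mid z_e^0)}\right] + \sum_{t=2}^{T} \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_{\theta,t-1|t}(z_e^{t-1} \mid z_e^t)}{q^{z_e}_{\phi,t|t-1}(z_e^t \mid z_e^{t-1})}\right].$$

For $t \ge 2$, the Markov property of the forward chain and Bayes' rule give $q^{z_e}_{\phi,t|t-1}(z_e^t \mid z_e^{t-1}) = q^{z_e}_{\phi,t-1|0,t}(z_e^{t-1} \mid z_e^0, z_e^t)\, q^{z_e}_{\phi,t|0}(z_e^t \mid z_e^0) / q^{z_e}_{\phi,t-1|0}(z_e^{t-1} \mid z_e^0)$. The last two factors telescope over $t$, so that

$$\mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_\theta(z_e^{0:T})}{q^{z_e}_\phi(z_e^{0:T} \mid x)}\right] = \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_{\theta,T}(z_e^T)}{q^{z_e}_{\phi,T|0}(z_e^T \mid z_e^0)}\right] + \sum_{t=2}^{T} \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_e}_{\theta,t-1|t}(z_e^{t-1} \mid z_e^t)}{q^{z_e}_{\phi,t-1|0,t}(z_e^{t-1} \mid z_e^0, z_e^t)}\right] + \mathbb{E}_{q_\phi}\left[\log p^{z_e}_{\theta,0|1}(z_e^0 \mid z_e^1)\right],$$

which concludes the proof.

B. Inpainting diffusion sampling

We consider the case in which we know a sub-part $X$ of the picture and want to predict the complementary pixels $\bar X$. Knowing the corresponding $n$ latent vectors $z_e^0$, which result from $X$ through the encoder, we sample the $N - n$ remaining vectors $z_e^T$ from the uninformative distribution $z_e^T \sim \mathcal{N}\left(0, (2\vartheta)^{-1}\eta^2 I_{d(N-n)}\right)$.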
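The initialisation step just described can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names `centered_square_mask` and `init_missing_latents` are hypothetical, as are the latent dimension `d = 64` and the values `vartheta = eta = 1.0`. The unknown positions are taken to be the 9×9 centered square of the 21×21 feature map used for the conditional samples of Figure 8, so that $N - n = 81$.

```python
import numpy as np


def centered_square_mask(grid: int = 21, size: int = 9) -> np.ndarray:
    """Boolean (grid, grid) mask that is True on the centered size x size square to inpaint."""
    mask = np.zeros((grid, grid), dtype=bool)
    start = (grid - size) // 2
    mask[start:start + size, start:start + size] = True
    return mask


def init_missing_latents(mask, d, vartheta, eta, rng=None):
    """Draw z_e^T for the unknown (masked) positions from N(0, (2*vartheta)^-1 * eta^2 * I).

    Returns an array of shape (mask.sum(), d): one row per unknown position,
    in the row-major order given by np.argwhere(mask).
    """
    rng = np.random.default_rng() if rng is None else rng
    std = eta / np.sqrt(2.0 * vartheta)  # per-coordinate standard deviation
    num_missing = int(mask.sum())        # N - n unknown latent vectors
    return rng.normal(loc=0.0, scale=std, size=(num_missing, d))


# Hypothetical settings: d, vartheta and eta are placeholders, not the paper's values.
mask = centered_square_mask()
z_T = init_missing_latents(mask, d=64, vartheta=1.0, eta=1.0, rng=np.random.default_rng(0))
print(mask.sum(), z_T.shape)  # prints: 81 (81, 64)
```

Only the $N - n$ unknown vectors are drawn here; the way the chain is then propagated is given by the procedure that follows.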
In order to produce the chain of samples $z_e^{t-1}$ from $z_e^t$, we then follow the procedure below.

- $z_e^{t-1}$ is predicted from $z_e^t$ using the neural network predictor, as in the unconditional case.
- $z_e^{t-1}$ is sampled using the forward bridge noising process.

C. Additional regularisation considerations

We give here details about the parameterisation of $p^{z_q}_{\theta}(z_q^t \mid z_e^t)$ and $q^{z_q}_{\phi}(z_q^t \mid z_e^t)$ used to compute $\mathcal{L}^{\mathrm{reg}}_t(\theta,\phi)$. The Gumbel-Softmax formulation provides an efficient and differentiable parameterisation:

$$p^{z_q}_{\theta,t}(z_q^t = e_j \mid z_e^t) = \mathrm{Softmax}_{1\le k\le K}\left\{\left(-\|z_e^t - e_k\|_2^2 + G_k\right)/\tau_t\right\}_j,$$

$$q^{z_q}_{\phi,t}(z_q^t = e_j \mid z_e^t) = \mathrm{Softmax}_{1\le k\le K}\left\{\left(-\|z_e^t - e_k\|_2^2 + \tilde G_k\right)/\tau\right\}_j,$$

where $\{(G_k, \tilde G_k)\}_{1\le k\le K}$ are i.i.d. with distribution Gumbel(0, 1), $\tau > 0$, and $\{\tau_t\}_{0\le t\le T}$ are positive time-dependent scaling parameters. Then, up to additive normalizing terms,

$$\mathcal{L}^{\mathrm{reg}}_t(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log \frac{p^{z_q}_{\theta,t}(\hat z_q^t \mid z_e^t)}{q^{z_q}_{\phi,t}(\hat z_q^t \mid z_e^t)}\right] = \mathbb{E}_{q_\phi}\left[\frac{-\|z_e^t - \hat z_q^t\|_2^2 + G_k}{\tau_t} - \frac{-\|z_e^t - \hat z_q^t\|_2^2 + \tilde G_k}{\tau}\right],$$

where $\hat z_q^t \sim q^{z_q}_{\phi,t}(z_q^t \mid z_e^t)$ and $k$ is the index of the selected codebook vector. Keeping only the terms which depend on $z_e^t$ and therefore produce non-zero gradients, we get

$$\mathcal{L}^{\mathrm{reg}}_t(\theta,\phi) = \gamma_t \|z_e^t - \hat z_q^t\|_2^2, \qquad \gamma_t = -\frac{1}{\tau_t} + \frac{1}{\tau},$$

where $\gamma_t$ drives the behaviour of the regulariser. By choosing $\gamma_t$ negative for large $t$, the regulariser pushes the codebooks away from $z_e^t$, which prevents too early specialization, or matching of codebooks with noise, since $z_e^t$ is close to the uninformative distribution when $t$ is close to $T$. Finally, for small $t$, choosing $\gamma_t$ positive helps match the codebooks with $z_e^t$ when the corruption is small. In practice, we set $\tau = 1$ and use a simple schedule for $\tau_t$ ranging from 10 to 0.1.

D. Neural networks

For $\varepsilon_\theta(z_e^t, t)$, we use a U-Net-like architecture similar to the one described by Ho et al. (2020). It is a deep convolutional neural network with 57M parameters, fewer than the PixelCNN architecture (95.8M parameters). The VQ-VAE encoder and decoder are also deep convolutional networks, totalling 65M parameters.

E.
Toy Example Appendix

Parameterisation. We consider a neural network to model $\varepsilon_\theta(z_e^t, t)$. The network shown in Figure 10 consists of a time embedding similar to that of Ho et al. (2020), followed by a few linear or 1D-convolutional layers, totalling around 5000 parameters.

Figure 10. Graphical representation of the neural network used for the toy dataset.

For the parameterisation of the quantization part, we choose $p^{z_q}_{\theta,t}(z_q^t = e_j \mid z_e^t) = \mathrm{Softmax}_{1\le k\le K}\left\{-\|z_e^t - e_k\|_2^2\right\}_j$, and the same parameterisation for $q^{z_q}_{\phi,t}(z_q^t \mid z_e^t)$. Therefore our loss simplifies to

$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{q_\phi}\left[\log p^x_\theta(x \mid z_q^0) + \mathcal{L}_t(\theta,\phi)\right],$$

where $t$ is sampled uniformly in $\{0, \ldots, T\}$.

Table 4. Discrete samples during the diffusion process. The discrete sequence is obtained by computing the nearest neighbour centroid $\mu_j$ for each $X_t$. At the start of the denoising process (t = 50 in the table), the initial state is sampled from a centered Gaussian distribution with small covariance matrix $(2\vartheta)^{-1}\eta^2 I_{2\times 5}$, resulting in a uniformly random discrete sequence, as all centroids have a similar unit norm.

| t | NN sequence |
|---|---|
| 50 | (0, 7, 3, 6, 2) |
| 40 | (6, 5, 5, 5, 3) |
| 30 | (5, 5, 5, 4, 2) |
| 20 | (6, 6, 5, 4, 3) |
| 10 | (5, 6, 5, 4, 3) |
| 0 | (5, 6, 5, 4, 3) |

Discrete samples during the diffusion process. The discrete sequences corresponding to the denoising diffusion process shown in Figure 3 are given in Table 4.

End-to-end training. In order to train the codebooks alongside the diffusion process, we need to backpropagate the gradient of the likelihood of the data $z_e$ given a $z_e^0$ reconstructed by the diffusion process (corresponding to $\mathcal{L}^{\mathrm{rec}}(\theta,\phi)$). We use the Gumbel-Softmax parameterisation in order to obtain a differentiable process and update the codebooks $e_j$. In this toy example, the third part of the loss, $\sum_{t=0}^{T} \mathcal{L}^{\mathrm{reg}}_t(\theta,\phi)$, is not mandatory: we obtain good results with $\mathcal{L}^{\mathrm{reg}}_t(\theta,\phi) = 0$, which means parametrising $p^{z_q}_{\theta,t}(z_q^t \mid z_e^t) = q^{z_q}_{\phi,t}(z_q^t \mid z_e^t)$. However, we noticed that $\mathcal{L}^{\mathrm{reg}}_t(\theta,\phi)$ is useful to improve the learning of the codebooks. If we choose $\gamma_t$ to be decreasing with time $t$, we have the following.
- When $t$ is low, the denoising process is almost over and $\mathcal{L}^{\mathrm{reg}}_t(\theta,\phi)$ pushes $z_e$ and the selected $z_q$ close together: as $\|z_e^t\| \approx 1$, $z_e^t$ is likely to be near a specific $e_j$ and far from the others, so only a single codebook is selected and receives gradient.
- When $t$ is high, $z_e^t \approx 0$ and the Gumbel-Softmax makes all codebooks equidistant from $z_e^t$, so that all of them receive non-zero gradient.

This naturally addresses the training problem associated with dead codebooks in VQ-VAEs. Joint training of the denoising process and the codebooks yields excellent codebook positioning, as shown in Figure 11.

Figure 11. Left: initial random codebook positions. Right: positions of the codebook vectors after training. Note that the codebook indexes do not match the indexes of the Gaussians: the model learnt to associate neighbouring centroids in a different order.

Toy diffusion inpainting. We consider a case in which we want to reconstruct an $x$ while knowing only one (or a few) of its dimensions, and sample the others. Consider that $x$ is generated using a sequence $q = (q_1, q_2, q_3, q_4, q_5)$ whose first and last elements are fixed: $q_1 = 0$, $q_5 = 4$. Then, knowing $q_1$ and $q_5$, we sample $q_2, q_3, q_4$, as shown in Figure 12.

Figure 12. Three independent samples of $X$ using a trained diffusion bridge, with fixed $q_1 = 0$, $q_5 = 4$. The three corresponding sequences are (0, 7, 6, 5, 4), (0, 1, 2, 3, 4) and (0, 7, 6, 5, 4), all of which are valid sequences.

F. Additional visuals

Figure 13. Reconstruction of the VQ-VAE model used in the following benchmarks.

F.2. Mini-ImageNet

Figure 14. Samples from the PixelCNN prior (left) and from our diffusion prior (right) on CIFAR10.

Figure 15. Reconstruction of the trained VQ-VAE on the mini-ImageNet dataset. Original images are encoded, discretised, and decoded.

Figure 16. Samples from our model for the mini-ImageNet dataset.

Figure 17.
Conditional sampling. Top: VQ-VAE reconstructions of the original images. Middle: conditional sampling with the left side of the image as condition, for our model. Bottom 1 and 2: conditional sampling in the same context with the PixelCNN prior.

Figure 18. Sampling denoising chain up to t = 0, shown at regular intervals, conditioned on the left part of the picture. The sampling procedure is described in Appendix B.

Figure 19. Conditional sampling with the PixelCNN prior. Left: original images. Right: conditional sampling with the left side of the image as condition. Each row represents a class of the validation set of the mini-ImageNet dataset.